diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index a02c97a96..de8c2b1c4 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -40,6 +40,7 @@ jobs: -DSEQUANT_USE_SYSTEM_BOOST_HASH=OFF -DCMAKE_CXX_STANDARD=20 -DCMAKE_CXX_EXTENSIONS=OFF + -DSEQUANT_BUILD_BENCHMARKS=ON steps: - uses: actions/checkout@v4 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3530db66e..5ab939237 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -188,7 +188,6 @@ add_library(SeQuant-bliss set(SeQuant_src ${PROJECT_BINARY_DIR}/SeQuant/version.hpp SeQuant/version.cpp - SeQuant/core/abstract_tensor.cpp SeQuant/core/abstract_tensor.hpp SeQuant/core/algorithm.hpp SeQuant/core/any.hpp @@ -210,6 +209,7 @@ set(SeQuant_src SeQuant/core/expr.hpp SeQuant/core/expr_algorithm.hpp SeQuant/core/expr_operator.hpp + SeQuant/core/hash.cpp SeQuant/core/hash.hpp SeQuant/core/hugenholtz.hpp SeQuant/core/index.cpp @@ -239,8 +239,12 @@ set(SeQuant_src SeQuant/core/tag.hpp SeQuant/core/tensor.cpp SeQuant/core/tensor.hpp + SeQuant/core/tensor_canonicalizer.cpp + SeQuant/core/tensor_canonicalizer.hpp SeQuant/core/tensor_network.cpp SeQuant/core/tensor_network.hpp + SeQuant/core/tensor_network_v2.cpp + SeQuant/core/tensor_network_v2.hpp SeQuant/core/timer.hpp SeQuant/core/utility/context.hpp SeQuant/core/utility/indices.hpp @@ -249,6 +253,8 @@ set(SeQuant_src SeQuant/core/utility/singleton.hpp SeQuant/core/utility/string.hpp SeQuant/core/utility/string.cpp + SeQuant/core/vertex_painter.cpp + SeQuant/core/vertex_painter.hpp SeQuant/core/wick.hpp SeQuant/core/wick.impl.hpp SeQuant/core/wolfram.hpp @@ -410,6 +416,7 @@ if (BUILD_TESTING) include(FindOrFetchCatch2) set(utests_src + tests/unit/catch2_sequant.hpp tests/unit/test_space.cpp tests/unit/test_index.cpp tests/unit/test_op.cpp @@ -559,6 +566,11 @@ if (BUILD_TESTING) target_link_libraries(${example${i}} SeQuant) endforeach () + set(example12 "tensor_network_graphs") + add_executable(${example12} EXCLUDE_FROM_ALL + examples/${example12}/${example12}.cpp) + target_link_libraries(${example12} SeQuant) + # add tests for running examples foreach (i RANGE ${lastexample}) if (TARGET ${example${i}}) diff --git a/SeQuant/core/abstract_tensor.cpp b/SeQuant/core/abstract_tensor.cpp deleted file mode 100644 index 75c6b0ad2..000000000 --- a/SeQuant/core/abstract_tensor.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// -// Created by Eduard Valeyev on 2019-03-24. -// - -#include -#include -#include -#include - -#include - -#include - -namespace sequant { - -TensorCanonicalizer::~TensorCanonicalizer() = default; - -std::pair>*, - std::unique_lock> -TensorCanonicalizer::instance_map_accessor() { - static container::map> - map_; - static std::recursive_mutex mtx_; - return std::make_pair(&map_, std::unique_lock{mtx_}); -} - -container::vector& -TensorCanonicalizer::cardinal_tensor_labels_accessor() { - static container::vector ctlabels_; - return ctlabels_; -} - -std::shared_ptr -TensorCanonicalizer::nondefault_instance_ptr(std::wstring_view label) { - auto&& [map_ptr, lock] = instance_map_accessor(); - // look for label-specific canonicalizer - auto it = map_ptr->find(std::wstring{label}); - if (it != map_ptr->end()) { - return it->second; - } else - return {}; -} - -std::shared_ptr TensorCanonicalizer::instance_ptr( - std::wstring_view label) { - auto result = nondefault_instance_ptr(label); - if (!result) // not found? look for default - result = nondefault_instance_ptr(L""); - return result; -} - -std::shared_ptr TensorCanonicalizer::instance( - std::wstring_view label) { - auto inst_ptr = instance_ptr(label); - if (!inst_ptr) - throw std::runtime_error( - "must first register canonicalizer via " - "TensorCanonicalizer::register_instance(...)"); - return inst_ptr; -} - -void TensorCanonicalizer::register_instance( - std::shared_ptr can, std::wstring_view label) { - auto&& [map_ptr, lock] = instance_map_accessor(); - (*map_ptr)[std::wstring{label}] = can; -} - -bool TensorCanonicalizer::try_register_instance( - std::shared_ptr can, std::wstring_view label) { - auto&& [map_ptr, lock] = instance_map_accessor(); - if (!map_ptr->contains(std::wstring{label})) { - (*map_ptr)[std::wstring{label}] = can; - return true; - } else - return false; -} - -void TensorCanonicalizer::deregister_instance(std::wstring_view label) { - auto&& [map_ptr, lock] = instance_map_accessor(); - auto it = map_ptr->find(std::wstring{label}); - if (it != map_ptr->end()) { - map_ptr->erase(it); - } -} - -std::function - TensorCanonicalizer::index_comparer_ = std::less{}; - -const std::function& -TensorCanonicalizer::index_comparer() { - return index_comparer_; -} - -void TensorCanonicalizer::index_comparer( - std::function comparer) { - index_comparer_ = std::move(comparer); -} - -ExprPtr NullTensorCanonicalizer::apply(AbstractTensor&) { return {}; } - -ExprPtr DefaultTensorCanonicalizer::apply(AbstractTensor& t) { - // tag all indices as ext->true/ind->false - auto braket_view = braket(t); - ranges::for_each(braket_view, [this](auto& idx) { - auto it = external_indices_.find(std::wstring(idx.label())); - auto is_ext = it != external_indices_.end(); - idx.tag().assign( - is_ext ? 0 : 1); // ext -> 0, int -> 1, so ext will come before - }); - - auto result = this->apply(t, this->index_comparer_); - reset_tags(t); - - return result; -} - -} // namespace sequant diff --git a/SeQuant/core/abstract_tensor.hpp b/SeQuant/core/abstract_tensor.hpp index e49d621af..5d4363561 100644 --- a/SeQuant/core/abstract_tensor.hpp +++ b/SeQuant/core/abstract_tensor.hpp @@ -29,33 +29,50 @@ namespace sequant { -/// This interface class defines a Tensor concept. Object @c t of a type that -/// meets the concept must satisfy the following: -/// - @c bra(t) , @c ket(t) , and @c braket(t) are valid expressions and -/// evaluate to a range of Index objects; -/// - @c bra_rank(t) and @c ket_rank(t) are valid expression and return -/// sizes of the @c bra(t) and @c ket(t) ranges, respectively; -/// - @c symmetry(t) is a valid expression and evaluates to a Symmetry -/// object that describes the symmetry of bra/ket of a -/// _particle-symmetric_ @c t ; -/// - @c braket_symmetry(t) is a valid expression and evaluates to a -/// BraKetSymmetry object that describes the bra-ket symmetry of @c t ; -/// - @c particle_symmetry(t) is a valid expression and evaluates to a -/// ParticleSymmetry object that describes the symmetry of @c t with -/// respect to permutations of particles; -/// - @c color(t) is a valid expression and returns whether a -/// nonnegative integer that identifies the type of a tensor; tensors -/// with different colors can be reordered in a Product at will -/// - @c is_cnumber(t) is a valid expression and returns whether t -/// commutes with other tensor of same color (tensors of different -/// colors are, for now, always assumed to commute) -/// - @c label(t) is a valid expression and its return is convertible to -/// a std::wstring; -/// - @c to_latex(t) is a valid expression and its return is convertible -/// to a std::wstring. -/// To adapt an existing class intrusively derive it from AbstractTensor and -/// implement all member functions. This allows to implememnt heterogeneous -/// containers of objects that meet the Tensor concept. +class TensorCanonicalizer; + +/// AbstractTensor is a [tensor](https://en.wikipedia.org/wiki/Tensor) over +/// general (i.e., not necessarily commutative) +/// rings. A tensor with \f$ k \geq 0 \f$ contravariant +/// (ket, [Dirac notation](https://en.wikipedia.org/wiki/Bra-ket_notation) ) and +/// \f$ b \geq 0 \f$ covariant (bra) modes +/// describes elements of a tensor product of \f$ b+k \f$ vector spaces. +/// Equivalently it represents a linear map between the tensor product +/// of \f$ k \f$ _primal_ vector spaces to the tensor product of \f$ b \f$ +/// _dual_ vector spaces. Tensor modes are 1-to-1 represented by unique +/// [indices](https://en.wikipedia.org/wiki/Abstract_index_notation), +/// represented by Index objects. +/// +/// It is also necessary to support modes that are "array-like" in that they +/// do not refer to a vector space or its dual; such modes can represent +/// ordinal indices (e.g. to treat a collection/sequence of tensors as +/// a single tensor) Thus each tensor has zero or more auxiliary (aux) modes. +/// The aux modes are invariant under the transformations of vector spaces and +/// do not contribute to the tensor symmetries. +/// +/// Tensors can have the following symmetries: +/// - Tensors can be symmetric or nonsymmetric with respect to the transposition +/// of corresponding (first, second, etc.) modes in bra/ket mode ranges. This +/// symmetry is used to treat particle as indistinguishable or distinguishable +/// in many-particle quantum mechanics context. +/// - Tensors can be symmetric, antisymmetric, and nonsymmetric +/// with respect to the transposition of modes within the bra or ket sets. +/// This symmetry is used to model the +/// distinguishable and indistinguishable (bosonic and fermionic) degrees +/// of freedom in many-body quantum mechanics context. More complicated +/// symmetries are not yet supported. +/// - Tensors can be symmetric, conjugate, or nonsymmetric with respect to +/// swap of bra with ket. This symmetry corresponds to time reversal in +/// physical simulation. +/// +/// Lastly, the supporting rings are not assumed to be scalars, hence tensor +/// product supporting the concept of tensor is not necessarily commutative. +/// +/// \note This interface class defines a Tensor _concept_. All Tensor objects +/// must fulfill the is_tensor trait (see below). To adapt an existing class +/// intrusively derive it from AbstractTensor and implement all member +/// functions. This allows to implement heterogeneous containers of objects +/// that meet the Tensor concept. class AbstractTensor { inline auto missing_instantiation_for(const char* fn_name) const { std::ostringstream oss; @@ -78,30 +95,54 @@ class AbstractTensor { ranges::any_view; - /// view of a contiguous range of Index objects + /// accessor bra (covariant) indices + /// @return view of a contiguous range of Index objects virtual const_any_view_randsz _bra() const { throw missing_instantiation_for("_bra"); } - /// view of a contiguous range of Index objects + /// accesses ket (contravariant) indices + /// @return view of a contiguous range of Index objects virtual const_any_view_randsz _ket() const { throw missing_instantiation_for("_ket"); } + /// accesses aux (invariant) indices + /// @return view of a contiguous range of Index objects + virtual const_any_view_randsz _aux() const { + throw missing_instantiation_for("_aux"); + } + /// accesses bra and ket indices /// view of a not necessarily contiguous range of Index objects virtual const_any_view_rand _braket() const { throw missing_instantiation_for("_braket"); } + /// accesses bra, ket, and aux indices + /// @return view of a not necessarily contiguous range of Index objects + virtual const_any_view_rand _indices() const { + throw missing_instantiation_for("_indices"); + } + /// @return the number of bra indices virtual std::size_t _bra_rank() const { throw missing_instantiation_for("_bra_rank"); } + /// @return the number of ket indices virtual std::size_t _ket_rank() const { throw missing_instantiation_for("_ket_rank"); } + /// @return the number of aux indices + virtual std::size_t _aux_rank() const { + throw missing_instantiation_for("_aux_rank"); + } + /// @return the permutational symmetry of the vector space indices of + /// the tensor virtual Symmetry _symmetry() const { throw missing_instantiation_for("_symmetry"); } + /// @return the symmetry of tensor under exchange of vectors space (bra) and + /// its dual (ket) virtual BraKetSymmetry _braket_symmetry() const { throw missing_instantiation_for("_braket_symmetry"); } + /// @return the symmetry of tensor under exchange of bra and ket virtual ParticleSymmetry _particle_symmetry() const { throw missing_instantiation_for("_particle_symmetry"); } @@ -143,6 +184,12 @@ class AbstractTensor { virtual any_view_randsz _ket_mutable() { throw missing_instantiation_for("_ket_mutable"); } + /// @return mutable view of aux indices + /// @warning this is used for mutable access, flush memoized state before + /// returning! + virtual any_view_randsz _aux_mutable() { + throw missing_instantiation_for("_aux_mutable"); + } friend class TensorCanonicalizer; }; @@ -151,8 +198,10 @@ class AbstractTensor { /// objects. /// @{ inline auto braket(const AbstractTensor& t) { return t._braket(); } +inline auto indices(const AbstractTensor& t) { return t._indices(); } inline auto bra_rank(const AbstractTensor& t) { return t._bra_rank(); } inline auto ket_rank(const AbstractTensor& t) { return t._ket_rank(); } +inline auto aux_rank(const AbstractTensor& t) { return t._aux_rank(); } inline auto symmetry(const AbstractTensor& t) { return t._symmetry(); } inline auto braket_symmetry(const AbstractTensor& t) { return t._braket_symmetry(); @@ -164,8 +213,59 @@ inline auto color(const AbstractTensor& t) { return t._color(); } inline auto is_cnumber(const AbstractTensor& t) { return t._is_cnumber(); } inline auto label(const AbstractTensor& t) { return t._label(); } inline auto to_latex(const AbstractTensor& t) { return t._to_latex(); } -/// @tparam IndexMap a {source Index -> target Index} map type; if it is not @c -/// container::map + +/// Type trait for checking whether a given class fulfills the Tensor interface +/// requirements Object @c t of a type that meets the concept must satisfy the +/// following: +/// - @c braket(t) and +/// @c indices(t) are valid expressions and evaluate to a range of Index +/// objects; +/// - @c bra_rank(t), @c ket_rank(t) and @c aux_rank(t) are valid +/// expression and return sizes of the @c bra(t), @c ket(t) and +/// @c aux(t) ranges, respectively; +/// - @c symmetry(t) is a valid expression and evaluates to a Symmetry +/// object that describes the symmetry of bra/ket of a +/// _particle-symmetric_ @c t ; +/// - @c braket_symmetry(t) is a valid expression and evaluates to a +/// BraKetSymmetry object that describes the bra-ket symmetry of @c t ; +/// - @c particle_symmetry(t) is a valid expression and evaluates to a +/// ParticleSymmetry object that describes the symmetry of @c t with +/// respect to permutations of particles; +/// - @c color(t) is a valid expression and returns whether a +/// nonnegative integer that identifies the type of a tensor; tensors +/// with different colors can be reordered in a Product at will +/// - @c is_cnumber(t) is a valid expression and returns whether t +/// commutes with other tensor of same color (tensors of different +/// colors are, for now, always assumed to commute) +/// - @c label(t) is a valid expression and its return is convertible to +/// a std::wstring; +/// - @c to_latex(t) is a valid expression and its return is convertible +/// to a std::wstring. +template +struct is_tensor + : std::bool_constant< + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v && + std::is_invocable_v< + decltype(static_cast(to_latex)), T>> { +}; +template +constexpr bool is_tensor_v = is_tensor::value; +static_assert(is_tensor_v, + "The AbstractTensor class does not fulfill the requirements of " + "the Tensor interface"); + +/// @tparam IndexMap a {source Index -> target Index} map type; if it is not +/// @c container::map /// will need to make a copy. /// @param[in,out] t an AbstractTensor object whose indices will be transformed /// @param[in] index_map a const reference to an IndexMap object that specifies @@ -198,203 +298,6 @@ inline void reset_tags(AbstractTensor& t) { t._reset_tags(); } using AbstractTensorPtr = std::shared_ptr; -/// @brief Base class for Tensor canonicalizers -/// To make custom canonicalizer make a derived class and register an instance -/// of that class with TensorCanonicalizer::register_instance -class TensorCanonicalizer { - public: - virtual ~TensorCanonicalizer(); - - /// @return ptr to the TensorCanonicalizer object, if any, that had been - /// previously registered via TensorCanonicalizer::register_instance() - /// with @c label , or to the default canonicalizer, if any - static std::shared_ptr instance_ptr( - std::wstring_view label = L""); - - /// @return ptr to the TensorCanonicalizer object, if any, that had been - /// previously registered via TensorCanonicalizer::register_instance() - /// with @c label - /// @sa instance_ptr - static std::shared_ptr nondefault_instance_ptr( - std::wstring_view label); - - /// @return a TensorCanonicalizer previously registered via - /// TensorCanonicalizer::register_instance() with @c label or to the default - /// canonicalizer - /// @throw std::runtime_error if no canonicalizer has been registered - static std::shared_ptr instance( - std::wstring_view label = L""); - - /// registers @c canonicalizer to be applied to Tensor objects with label - /// @c label ; leave the label empty if @c canonicalizer is to apply to Tensor - /// objects with any label - /// @note if a canonicalizer registered with label @c label exists, it is - /// replaced - static void register_instance( - std::shared_ptr canonicalizer, - std::wstring_view label = L""); - - /// tries to register @c canonicalizer to be applied to Tensor objects - /// with label @c label ; leave the label empty if @c canonicalizer is to - /// apply to Tensor objects with any label - /// @return false if there is already a canonicalizer registered with @c label - /// @sa regiter_instance - static bool try_register_instance( - std::shared_ptr canonicalizer, - std::wstring_view label = L""); - - /// deregisters canonicalizer (if any) registered previously - /// to be applied to tensors with label @c label - static void deregister_instance(std::wstring_view label = L""); - - /// @return a list of Tensor labels with lexicographic preference (in order) - static const auto& cardinal_tensor_labels() { - return cardinal_tensor_labels_accessor(); - } - - /// @param labels a list of Tensor labels with lexicographic - /// preference (in order) - static void set_cardinal_tensor_labels( - const container::vector& labels) { - cardinal_tensor_labels_accessor() = labels; - } - - /// @return a side effect of canonicalization (e.g. phase), or nullptr if none - /// @internal what should be returned if canonicalization requires - /// complex conjugation? Special ExprPtr type (e.g. ConjOp)? Or the actual - /// return of the canonicalization? - /// @note canonicalization compared indices returned by index_comparer - // TODO generalize for complex tensors - virtual ExprPtr apply(AbstractTensor&) = 0; - - /// @return reference to the object used to compare Index objects - /// @note the default is to use an object of type `std::less` - static const std::function& - index_comparer(); - - /// @param comparer the compare object to be used by this - static void index_comparer( - std::function comparer); - - protected: - inline auto bra_range(AbstractTensor& t) { return t._bra_mutable(); } - inline auto ket_range(AbstractTensor& t) { return t._ket_mutable(); } - - /// the object used to compare indices - static std::function index_comparer_; - - private: - static std::pair< - container::map>*, - std::unique_lock> - instance_map_accessor(); // map* + locked recursive mutex - static container::vector& cardinal_tensor_labels_accessor(); -}; - -/// @brief null Tensor canonicalizer does nothing -class NullTensorCanonicalizer : public TensorCanonicalizer { - public: - virtual ~NullTensorCanonicalizer() = default; - - ExprPtr apply(AbstractTensor&) override; -}; - -class DefaultTensorCanonicalizer : public TensorCanonicalizer { - public: - DefaultTensorCanonicalizer() = default; - - /// @tparam IndexContainer a Container of Index objects such that @c - /// IndexContainer::value_type is convertible to Index (e.g. this can be - /// std::vector or std::set , but not std::map) - /// @param external_indices container of external Index objects - /// @warning @c external_indices is assumed to be immutable during the - /// lifetime of this object - template - DefaultTensorCanonicalizer(IndexContainer&& external_indices) { - ranges::for_each(external_indices, [this](const Index& idx) { - this->external_indices_.emplace(idx.label(), idx); - }); - } - virtual ~DefaultTensorCanonicalizer() = default; - - /// Implements TensorCanonicalizer::apply - /// @note Canonicalizes @c t by sorting its bra (if @c - /// t.symmetry()==Symmetry::nonsymm ) or its bra and ket (if @c - /// t.symmetry()!=Symmetry::nonsymm ), - /// with the external indices appearing "before" (smaller particle - /// indices) than the internal indices - ExprPtr apply(AbstractTensor& t) override; - - /// Core of DefaultTensorCanonicalizer::apply, only does the canonicalization, - /// i.e. no tagging/untagging - template - ExprPtr apply(AbstractTensor& t, const Compare& comp) { - // std::wcout << "abstract tensor: " << to_latex(t) << "\n"; - auto s = symmetry(t); - auto is_antisymm = (s == Symmetry::antisymm); - const auto _bra_rank = bra_rank(t); - const auto _ket_rank = ket_rank(t); - const auto _rank = std::min(_bra_rank, _ket_rank); - - // nothing to do for rank-1 tensors - if (_bra_rank == 1 && _ket_rank == 1) return nullptr; - - using ranges::begin; - using ranges::end; - using ranges::views::counted; - using ranges::views::take; - using ranges::views::zip; - - bool even = true; - switch (s) { - case Symmetry::antisymm: - case Symmetry::symm: { - auto _bra = bra_range(t); - auto _ket = ket_range(t); - // std::wcout << "canonicalizing " << to_latex(t); - IndexSwapper::thread_instance().reset(); - // std::{stable_}sort does not necessarily use swap! so must implement - // sort ourselves .. thankfully ranks will be low so can stick with - // bubble - bubble_sort(begin(_bra), end(_bra), comp); - bubble_sort(begin(_ket), end(_ket), comp); - if (is_antisymm) - even = IndexSwapper::thread_instance().even_num_of_swaps(); - // std::wcout << " is " << (even ? "even" : "odd") << " and - // produces " << to_latex(t) << std::endl; - } break; - - case Symmetry::nonsymm: { - // sort particles with bra and ket functions first, - // then the particles with either bra or ket index - auto _bra = bra_range(t); - auto _ket = ket_range(t); - auto _zip_braket = zip(take(_bra, _rank), take(_ket, _rank)); - bubble_sort(begin(_zip_braket), end(_zip_braket), comp); - if (_bra_rank > _rank) { - auto size_of_rest = _bra_rank - _rank; - auto rest_of = counted(begin(_bra) + _rank, size_of_rest); - bubble_sort(begin(rest_of), end(rest_of), comp); - } else if (_ket_rank > _rank) { - auto size_of_rest = _ket_rank - _rank; - auto rest_of = counted(begin(_ket) + _rank, size_of_rest); - bubble_sort(begin(rest_of), end(rest_of), comp); - } - } break; - - default: - abort(); - } - - ExprPtr result = - is_antisymm ? (even == false ? ex(-1) : nullptr) : nullptr; - return result; - } - - private: - container::map external_indices_; -}; - } // namespace sequant #endif // SEQUANT_ABSTRACT_TENSOR_HPP diff --git a/SeQuant/core/algorithm.hpp b/SeQuant/core/algorithm.hpp index da19e0c98..ff4143e40 100644 --- a/SeQuant/core/algorithm.hpp +++ b/SeQuant/core/algorithm.hpp @@ -5,6 +5,8 @@ #ifndef SEQUANT_ALGORITHM_HPP #define SEQUANT_ALGORITHM_HPP +#include + #include #include #include @@ -13,6 +15,10 @@ namespace sequant { +template +using suitable_call_operator = + decltype(std::declval()(std::declval()...)); + /// @brief bubble sort that uses swap exclusively template void bubble_sort(ForwardIter begin, Sentinel end, Compare comp) { @@ -22,29 +28,45 @@ void bubble_sort(ForwardIter begin, Sentinel end, Compare comp) { swapped = false; auto i = begin; auto inext = i; - // iterators either dereference to a reference or to a composite of - // references - constexpr const bool iter_deref_to_ref = - std::is_reference_v()))>; + + using deref_type = decltype(*(std::declval())); + constexpr const bool comp_works_for_range_type = + meta::is_detected_v; + for (++inext; inext != end; ++i, ++inext) { - if constexpr (iter_deref_to_ref) { - auto& val0 = *inext; - auto& val1 = *i; + if constexpr (comp_works_for_range_type) { + const auto& val0 = *inext; + const auto& val1 = *i; if (comp(val0, val1)) { - using std::swap; - swap(val1, val0); + // current assumption: whenever iter_swap from below does not fall + // back to std::iter_swap, we are handling zipped ranges where the + // tuple sizes is two (even) -> thus using a non-std swap + // implementation won't mess with the information of whether or not an + // even amount of swaps has occurred. + using ranges::iter_swap; + iter_swap(i, inext); swapped = true; } } else { - auto val0 = *inext; - auto val1 = *i; - static_assert(std::tuple_size_v == 2, + const auto& val0 = *inext; + const auto& val1 = *i; + static_assert(std::tuple_size_v> == 2, "need to generalize comparer to handle tuples"); + using lhs_type = decltype(std::get<0>(val0)); + using rhs_type = decltype(std::get<1>(val0)); + constexpr const bool comp_works_for_tuple_entries = + meta::is_detected_v; + static_assert(comp_works_for_tuple_entries, + "Provided comparator not suitable for entries in " + "tuple-like objects (in zipped range?)"); + auto composite_compare = [&comp](auto&& c0, auto&& c1) { if (comp(std::get<0>(c0), std::get<0>(c1))) { // c0[0] < c1[0] return true; - } else if (!(comp(std::get<0>(c1), - std::get<0>(c0)))) { // c0[0] == c1[0] + } else if (!comp(std::get<0>(c1), + std::get<0>(c0))) { // c0[0] == c1[0] return comp(std::get<1>(c0), std::get<1>(c1)); } else { // c0[0] > c1[0] return false; diff --git a/SeQuant/core/attr.hpp b/SeQuant/core/attr.hpp index 4131c1d1f..f864ccbf4 100644 --- a/SeQuant/core/attr.hpp +++ b/SeQuant/core/attr.hpp @@ -5,6 +5,7 @@ #ifndef SEQUANT_ATTR_HPP #define SEQUANT_ATTR_HPP +#include #include namespace sequant { @@ -49,11 +50,33 @@ inline std::wstring to_wolfram(const Symmetry& symmetry) { return result; } +inline std::wstring to_wstring(Symmetry sym) { + switch (sym) { + case Symmetry::symm: + return L"symmetric"; + case Symmetry::antisymm: + return L"antisymmetric"; + case Symmetry::nonsymm: + return L"nonsymmetric"; + case Symmetry::invalid: + return L"invalid"; + } + + assert(false); + std::abort(); +} + enum class BraKetPos { bra, ket, none }; inline std::wstring to_wolfram(BraKetPos a) { - using namespace std::literals; - return L"indexType["s + (a == BraKetPos::bra ? L"bra" : L"ket") + L"]"; + switch (a) { + case BraKetPos::bra: + return L"indexType[bra]"; + case BraKetPos::ket: + return L"indexType[ket]"; + case BraKetPos::none: + return L"indexType[none]"; + } } enum class Statistics { diff --git a/SeQuant/core/eval_expr.cpp b/SeQuant/core/eval_expr.cpp index 6418d27ad..c84f6c7ce 100644 --- a/SeQuant/core/eval_expr.cpp +++ b/SeQuant/core/eval_expr.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -17,9 +18,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -35,7 +38,7 @@ size_t hash_imed(EvalExpr const&, EvalExpr const&, EvalOp) noexcept; ExprPtr make_imed(EvalExpr const&, EvalExpr const&, EvalOp) noexcept; bool is_tot(Tensor const& t) noexcept { - return ranges::any_of(t.const_braket(), &Index::has_proto_indices); + return ranges::any_of(t.const_indices(), &Index::has_proto_indices); } std::wstring_view const var_label = L"Z"; @@ -58,6 +61,16 @@ NestedTensorIndices::NestedTensorIndices(const sequant::Tensor& tnsr) { append_unique(outer, ix); } +std::string to_label_annotation(const Index& idx) { + using namespace ranges::views; + using ranges::to; + + return sequant::to_string(idx.label()) + + (idx.proto_indices() | transform(&Index::label) | + transform([](auto&& str) { return sequant::to_string(str); }) | + ranges::views::join | to); +} + std::string EvalExpr::braket_annot() const noexcept { if (!is_tensor()) return {}; @@ -68,12 +81,9 @@ std::string EvalExpr::braket_annot() const noexcept { auto annot = [](auto&& ixs) -> std::string { using namespace ranges::views; - auto full_labels = ixs // - | transform(&Index::full_label) // - | transform([](auto&& fl) { // - return sequant::to_string(fl); - }); - return full_labels // + auto annotations = ixs | transform(to_label_annotation); + + return annotations // | intersperse(std::string{","}) // | join // | ranges::to; @@ -190,9 +200,9 @@ namespace { /// the hash. /// template -size_t hash_braket(T const& bk) noexcept { +size_t hash_indices(T const& indices) noexcept { size_t h = 0; - for (auto const& idx : bk) { + for (auto const& idx : indices) { hash::combine(h, hash::value(idx.space().type().to_int32())); hash::combine(h, hash::value(idx.space().qns().to_int32())); if (idx.has_proto_indices()) { @@ -208,7 +218,7 @@ size_t hash_braket(T const& bk) noexcept { /// \return hash value to identify the connectivity between a pair of tensors. /// /// @note Let [(i,j)] be the list of ordered pair of index positions that are -/// connected. i is the position in the braket of the first tensor (T1) +/// connected. i is the position in the indices of the first tensor (T1) /// and j is that of the second tensor (T2). Then this function combines /// the hash values of the elements of this list. /// @@ -217,8 +227,8 @@ size_t hash_braket(T const& bk) noexcept { size_t hash_tensor_pair_topology(Tensor const& t1, Tensor const& t2) noexcept { using ranges::views::enumerate; size_t h = 0; - for (auto&& [pos1, idx1] : t1.const_braket() | enumerate) - for (auto&& [pos2, idx2] : t2.const_braket() | enumerate) + for (auto&& [pos1, idx1] : t1.const_indices() | enumerate) + for (auto&& [pos2, idx2] : t2.const_indices() | enumerate) if (idx1.label() == idx2.label()) hash::combine(h, hash::value(std::pair(pos1, pos2))); return h; @@ -227,7 +237,7 @@ size_t hash_tensor_pair_topology(Tensor const& t1, Tensor const& t2) noexcept { size_t hash_terminal_tensor(Tensor const& tnsr) noexcept { size_t h = 0; hash::combine(h, hash::value(tnsr.label())); - hash::combine(h, hash_braket(tnsr.const_braket())); + hash::combine(h, hash_indices(tnsr.const_indices())); return h; } @@ -325,11 +335,11 @@ Symmetry tensor_symmetry_prod(EvalExpr const& left, if (hash::value(left) == hash::value(right)) { // potential outer product of the same tensor auto const uniq_idxs = - ranges::views::concat(tnsr1.const_braket(), tnsr2.const_braket()) | + ranges::views::concat(tnsr1.const_indices(), tnsr2.const_indices()) | ranges::to; if (static_cast(ranges::distance(uniq_idxs)) == - tnsr1.const_braket().size() + tnsr2.const_braket().size()) { + tnsr1.const_indices().size() + tnsr2.const_indices().size()) { // outer product confirmed return Symmetry::antisymm; } @@ -376,7 +386,7 @@ ExprPtr make_sum(EvalExpr const& left, EvalExpr const& right) noexcept { auto ts = tensor_symmetry_sum(left, right); auto ps = particle_symmetry(ts); auto bks = get_default_context().braket_symmetry(); - return ex(L"I", t1.bra(), t1.ket(), ts, bks, ps); + return ex(L"I", t1.bra(), t1.ket(), t1.aux(), ts, bks, ps); } ExprPtr make_prod(EvalExpr const& left, EvalExpr const& right) noexcept { @@ -385,8 +395,8 @@ ExprPtr make_prod(EvalExpr const& left, EvalExpr const& right) noexcept { auto const& t1 = left.as_tensor(); auto const& t2 = right.as_tensor(); - auto [b, k] = target_braket(t1, t2); - if (b.empty() && k.empty()) { + auto [b, k, a] = get_uncontracted_indices(t1, t2); + if (b.empty() && k.empty() && a.empty()) { // dot product return ex(var_label); } else { @@ -394,7 +404,8 @@ ExprPtr make_prod(EvalExpr const& left, EvalExpr const& right) noexcept { auto ts = tensor_symmetry_prod(left, right); auto ps = particle_symmetry(ts); auto bks = get_default_context().braket_symmetry(); - return ex(L"I", bra(std::move(b)), ket(std::move(k)), ts, bks, ps); + return ex(L"I", bra(std::move(b)), ket(std::move(k)), + aux(std::move(a)), ts, bks, ps); } } @@ -415,7 +426,7 @@ ExprPtr make_imed(EvalExpr const& left, EvalExpr const& right, assert(op == EvalOp::Prod && "scalar + tensor not supported"); auto const& t = right.expr()->as(); - return ex(Tensor{L"I", t.bra(), t.ket(), t.symmetry(), + return ex(Tensor{L"I", t.bra(), t.ket(), t.aux(), t.symmetry(), t.braket_symmetry(), t.particle_symmetry()}); } else if (lres == ResultType::Tensor && rres == ResultType::Scalar) { diff --git a/SeQuant/core/export/itf.cpp b/SeQuant/core/export/itf.cpp index e0427392e..cfe63b28e 100644 --- a/SeQuant/core/export/itf.cpp +++ b/SeQuant/core/export/itf.cpp @@ -45,7 +45,7 @@ Tensor generateResultTensor(ExprPtr expr) { IndexGroups externals = get_unique_indices(expr); return Tensor(L"Result", bra(std::move(externals.bra)), - ket(std::move(externals.ket))); + ket(std::move(externals.ket)), aux(std::move(externals.aux))); } Result::Result(ExprPtr expression, bool importResultTensor) @@ -126,6 +126,9 @@ std::vector to_contractions(const Product &product, intermediateIndices.insert(intermediateIndices.end(), intermediateIndexGroups.ket.begin(), intermediateIndexGroups.ket.end()); + intermediateIndices.insert(intermediateIndices.end(), + intermediateIndexGroups.aux.begin(), + intermediateIndexGroups.aux.end()); std::sort(intermediateIndices.begin(), intermediateIndices.end(), [](const Index &lhs, const Index &rhs) { IndexTypeComparer cmp; @@ -254,6 +257,7 @@ void one_electron_integral_remapper( auto braIndices = tensor.bra(); auto ketIndices = tensor.ket(); + assert(tensor.aux().empty()); IndexTypeComparer cmp; @@ -269,7 +273,7 @@ void one_electron_integral_remapper( } expr = ex(tensor.label(), bra(std::move(braIndices)), - ket(std::move(ketIndices))); + ket(std::move(ketIndices)), tensor.aux()); } template @@ -314,6 +318,7 @@ void two_electron_integral_remapper( // Copy indices as we might have to mutate them auto braIndices = tensor.bra(); auto ketIndices = tensor.ket(); + assert(tensor.aux().empty()); IndexTypeComparer cmp; @@ -404,7 +409,7 @@ void two_electron_integral_remapper( } expr = ex(std::move(tensorLabel), bra(std::move(braIndices)), - ket(std::move(ketIndices))); + ket(std::move(ketIndices)), tensor.aux()); } void integral_remapper(ExprPtr &expr, std::wstring_view oneElectronIntegralName, diff --git a/SeQuant/core/expr.cpp b/SeQuant/core/expr.cpp index 6af8a9c38..728e71eed 100644 --- a/SeQuant/core/expr.cpp +++ b/SeQuant/core/expr.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include diff --git a/SeQuant/core/expr.hpp b/SeQuant/core/expr.hpp index f66a9004f..e0bccba1a 100644 --- a/SeQuant/core/expr.hpp +++ b/SeQuant/core/expr.hpp @@ -266,18 +266,18 @@ class Expr : public std::enable_shared_from_this, std::const_pointer_cast(this->shared_from_this())); } - /// Canonicalizes @c this and returns the biproduct of canonicalization (e.g. + /// Canonicalizes @c this and returns the byproduct of canonicalization (e.g. /// phase) - /// @return the biproduct of canonicalization, or @c nullptr if no biproduct + /// @return the byproduct of canonicalization, or @c nullptr if no byproduct /// generated virtual ExprPtr canonicalize() { return {}; // by default do nothing and return nullptr } /// Performs approximate, but fast, canonicalization of @c this and returns - /// the biproduct of canonicalization (e.g. phase) The default is to use + /// the byproduct of canonicalization (e.g. phase) The default is to use /// canonicalize(), unless overridden in the derived class. - /// @return the biproduct of canonicalization, or @c nullptr if no biproduct + /// @return the byproduct of canonicalization, or @c nullptr if no byproduct /// generated virtual ExprPtr rapid_canonicalize() { return this->canonicalize(); } @@ -1141,7 +1141,12 @@ class Product : public Expr { if (!scalar().is_zero()) { const auto scal = negate ? -scalar() : scalar(); if (!scal.is_identity()) { - result += sequant::to_latex(scal); + // replace -1 prefactor by - + if (!(negate ? scalar() : -scalar()).is_identity()) { + result += sequant::to_latex(scal); + } else { + result += L"{-}"; + } } for (const auto &i : factors()) { if (i->is()) diff --git a/SeQuant/core/expr_algorithm.hpp b/SeQuant/core/expr_algorithm.hpp index faccf6fd4..0a3baeef9 100644 --- a/SeQuant/core/expr_algorithm.hpp +++ b/SeQuant/core/expr_algorithm.hpp @@ -15,9 +15,9 @@ namespace sequant { /// _replaced_ (i.e. `&expr` may be mutated by call) /// @return \p expr to facilitate chaining inline ExprPtr& canonicalize(ExprPtr& expr) { - const auto biproduct = expr->canonicalize(); - if (biproduct && biproduct->is()) { - expr = biproduct * expr; + const auto byproduct = expr->canonicalize(); + if (byproduct && byproduct->is()) { + expr = byproduct * expr; } return expr; } @@ -27,9 +27,9 @@ inline ExprPtr& canonicalize(ExprPtr& expr) { /// @param[in] expr_rv rvalue-ref-to-expression to be canonicalized /// @return canonicalized form of \p expr_rv inline ExprPtr canonicalize(ExprPtr&& expr_rv) { - const auto biproduct = expr_rv->canonicalize(); - if (biproduct && biproduct->is()) { - expr_rv = biproduct * expr_rv; + const auto byproduct = expr_rv->canonicalize(); + if (byproduct && byproduct->is()) { + expr_rv = byproduct * expr_rv; } return std::move(expr_rv); } diff --git a/SeQuant/core/hash.cpp b/SeQuant/core/hash.cpp new file mode 100644 index 000000000..2b5f84836 --- /dev/null +++ b/SeQuant/core/hash.cpp @@ -0,0 +1,8 @@ +#include +#include + +namespace sequant { + +std::size_t hash_value(const ExprPtr &expr) { return hash_value(*expr); } + +} // namespace sequant diff --git a/SeQuant/core/hash.hpp b/SeQuant/core/hash.hpp index 7c2b48d60..7dd413b7a 100644 --- a/SeQuant/core/hash.hpp +++ b/SeQuant/core/hash.hpp @@ -22,6 +22,8 @@ namespace sequant_boost = boost; namespace sequant { +class ExprPtr; + namespace hash { /// the hashing versions known to SeQuant (N.B. hashing changed in Boost 1.81) @@ -89,6 +91,8 @@ auto hash_value(const T& obj) { return sequant_boost::hash_value(obj); } +std::size_t hash_value(const ExprPtr& expr); + // clang-format off // rationale: // boost::hash_combine is busted ... it dispatches to one of 3 implementations (all line numbers refer to boost 1.72.0): diff --git a/SeQuant/core/index.hpp b/SeQuant/core/index.hpp index 2865786cb..e3ddbbab0 100644 --- a/SeQuant/core/index.hpp +++ b/SeQuant/core/index.hpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -366,6 +367,20 @@ class Index : public Taggable { return make_split_label(this->label()); } + /// + /// \return The numeric suffix if present in the label. + /// + std::optional suffix() const { + auto &&[_, s_] = split_label(); + auto &&s = sequant::to_string(s_); + + int value{}; + if (std::from_chars(s.data(), s.data() + s.size(), value).ec == std::errc{}) + return value; + else + return std::nullopt; + } + /// @return A string label representable in ASCII encoding /// @warning not to be used with proto indices /// @brief Replaces non-ascii wstring characters with human-readable analogs, @@ -389,10 +404,12 @@ class Index : public Taggable { std::wstring_view full_label() const { if (!has_proto_indices()) return label(); if (full_label_) return *full_label_; - std::wstring result = label_; - ranges::for_each(proto_indices_, [&result](const Index &idx) { - result += idx.full_label(); - }); + std::wstring result = label_ + L"<"; + result += + ranges::views::transform(proto_indices_, + [](const Index &idx) { return idx.label(); }) | + ranges::views::join(L", ") | ranges::to(); + result += L">"; full_label_ = result; return *full_label_; } @@ -762,32 +779,39 @@ class Index : public Taggable { friend bool operator<(const Index &i1, const Index &i2) { // compare qns, tags and spaces in that sequence - auto i1_Q = i1.space().qns(); - auto i2_Q = i2.space().qns(); - auto compare_space = [&i1, &i2]() { - if (i1.space() == i2.space()) { - if (i1.label() == i2.label()) { - return i1.proto_indices() < i2.proto_indices(); - } else { - return i1.label() < i2.label(); - } - } else { + if (i1.space() != i2.space()) { return i1.space() < i2.space(); } + + if (i1.label() != i2.label()) { + // Note: Can't simply use label1 < label2 as that won't yield expected + // results for e.g. i2 < i11 (which will yield false) + if (i1.label().size() != i2.label().size()) { + return i1.label().size() < i2.label().size(); + } + + return i1.label() < i2.label(); + } + + return i1.proto_indices() < i2.proto_indices(); }; + const auto i1_Q = i1.space().qns(); + const auto i2_Q = i2.space().qns(); + if (i1_Q == i2_Q) { const bool have_tags = i1.tag().has_value() && i2.tag().has_value(); if (!have_tags || i1.tag() == i2.tag()) { + // Note that comparison of index spaces contains comparison of QNs return compare_space(); - } else { - return i1.tag() < i2.tag(); } - } else { - return i1_Q < i2_Q; + + return i1.tag() < i2.tag(); } + + return i1_Q < i2_Q; } }; // class Index diff --git a/SeQuant/core/logger.hpp b/SeQuant/core/logger.hpp index 3c68abf90..c2ba917ce 100644 --- a/SeQuant/core/logger.hpp +++ b/SeQuant/core/logger.hpp @@ -21,6 +21,7 @@ struct Logger : public Singleton { bool wick_stats = false; bool expand = false; bool canonicalize = false; + bool canonicalize_input_graph = false; bool canonicalize_dot = false; bool simplify = false; bool tensor_network = false; diff --git a/SeQuant/core/op.hpp b/SeQuant/core/op.hpp index 4001e9a21..8b63d5762 100644 --- a/SeQuant/core/op.hpp +++ b/SeQuant/core/op.hpp @@ -788,13 +788,20 @@ class NormalOperator : public Operator, ranges::views::transform( [](auto &&op) -> const Index & { return op.index(); }); } + AbstractTensor::const_any_view_randsz _aux() const override final { + return {}; + } AbstractTensor::const_any_view_rand _braket() const override final { return ranges::views::concat(annihilators(), creators()) | ranges::views::transform( [](auto &&op) -> const Index & { return op.index(); }); } + AbstractTensor::const_any_view_rand _indices() const override final { + return _braket(); + } std::size_t _bra_rank() const override final { return nannihilators(); } std::size_t _ket_rank() const override final { return ncreators(); } + std::size_t _aux_rank() const override final { return 0; } Symmetry _symmetry() const override final { return (S == Statistics::FermiDirac ? (get_default_context(S).spbasis() == SPBasis::spinorbital @@ -842,8 +849,18 @@ class NormalOperator : public Operator, ranges::views::transform( [](auto &&op) -> Index & { return op.index(); }); } + AbstractTensor::any_view_randsz _aux_mutable() override final { return {}; } }; +static_assert( + is_tensor_v>, + "The NormalOperator class does not fulfill the " + "requirements of the Tensor interface"); +static_assert( + is_tensor_v>, + "The NormalOperator class does not fulfill the " + "requirements of the Tensor interface"); + template bool operator==(const NormalOperator &op1, const NormalOperator &op2) { using base_type = Operator; diff --git a/SeQuant/core/optimize.hpp b/SeQuant/core/optimize.hpp index 2f5932d45..7491610cf 100644 --- a/SeQuant/core/optimize.hpp +++ b/SeQuant/core/optimize.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #if __cplusplus >= 202002L #include @@ -51,6 +52,23 @@ class Tensor; namespace opt { +/// +/// \param idxsz An invocable that returns size_t for Index argument. +/// \param idxs Index objects. +/// \return flops count +/// +template , + bool> = true> +double ops_count(IdxToSz const& idxsz, Idxs const& idxs) { + auto oixs = tot_indices(idxs); + double ops = 1.0; + for (auto&& idx : ranges::views::concat(oixs.outer, oixs.inner)) + ops *= std::invoke(idxsz, idx); + // ops == 1.0 implies zero flops. + return ops == 1.0 ? 0 : ops; +} + namespace { /// @@ -84,26 +102,6 @@ void biparts(I n, F const& func) { } } -/// -/// \tparam IdxToSz map-like {IndexSpace : size_t} -/// \param idxsz see @c IdxToSz -/// \param commons Index objects -/// \param diffs Index objects -/// \return flops count -/// @note @c commons and @c diffs have unique indices individually as well as -/// combined -template , - bool> = true> -double ops_count(IdxToSz const& idxsz, container::svector const& commons, - container::svector const& diffs) { - double ops = 1.0; - for (auto&& idx : ranges::views::concat(commons, diffs)) - ops *= std::invoke(idxsz, idx); - // ops == 1.0 implies both commons and diffs empty - return ops == 1.0 ? 0 : ops; -} - /// /// any element in the vector belongs to the integral range [-1,N) /// where N is the length of the [Expr] (ie. the iterable of expressions) @@ -138,20 +136,20 @@ struct OptRes { /// @note I1 and I2 containers are assumed to be sorted by using /// Index::LabelCompare{}; /// -template +template > container::svector common_indices(I1 const& idxs1, I2 const& idxs2) { using std::back_inserter; using std::begin; using std::end; using std::set_intersection; - assert(std::is_sorted(begin(idxs1), end(idxs1), Index::LabelCompare{})); - assert(std::is_sorted(begin(idxs2), end(idxs2), Index::LabelCompare{})); + assert(std::is_sorted(begin(idxs1), end(idxs1), Comp{})); + assert(std::is_sorted(begin(idxs2), end(idxs2), Comp{})); container::svector result; set_intersection(begin(idxs1), end(idxs1), begin(idxs2), end(idxs2), - back_inserter(result), Index::LabelCompare{}); + back_inserter(result), Comp{}); return result; } @@ -162,20 +160,20 @@ container::svector common_indices(I1 const& idxs1, I2 const& idxs2) { /// @note I1 and I2 containers are assumed to be sorted by using /// Index::LabelCompare{}; /// -template +template > container::svector diff_indices(I1 const& idxs1, I2 const& idxs2) { using std::back_inserter; using std::begin; using std::end; using std::set_symmetric_difference; - assert(std::is_sorted(begin(idxs1), end(idxs1), Index::LabelCompare{})); - assert(std::is_sorted(begin(idxs2), end(idxs2), Index::LabelCompare{})); + assert(std::is_sorted(begin(idxs1), end(idxs1), Comp{})); + assert(std::is_sorted(begin(idxs2), end(idxs2), Comp{})); container::svector result; set_symmetric_difference(begin(idxs1), end(idxs1), begin(idxs2), end(idxs2), - back_inserter(result), Index::LabelCompare{}); + back_inserter(result), Comp{}); return result; } @@ -190,21 +188,23 @@ template , bool> = true> eval_seq_t single_term_opt(TensorNetwork const& network, IdxToSz const& idxsz) { + using ranges::views::concat; + using IndexContainer = container::svector; // number of terms auto const nt = network.tensors().size(); if (nt == 1) return eval_seq_t{0}; if (nt == 2) return eval_seq_t{0, 1, -1}; - auto nth_tensor_indices = container::svector>{}; + auto nth_tensor_indices = container::svector{}; nth_tensor_indices.reserve(nt); for (std::size_t i = 0; i < nt; ++i) { auto const& tnsr = *network.tensors().at(i); - auto bk = container::svector{}; - bk.reserve(bra_rank(tnsr) + ket_rank(tnsr)); - for (auto&& idx : braket(tnsr)) bk.push_back(idx); - ranges::sort(bk, Index::LabelCompare{}); - nth_tensor_indices.emplace_back(std::move(bk)); + nth_tensor_indices.emplace_back(); + auto& ixs = nth_tensor_indices.back(); + for (auto&& j : indices(tnsr)) ixs.emplace_back(j); + + ranges::sort(ixs, std::less{}); } container::svector results((1 << nt), OptRes{{}, 0, {}}); @@ -230,9 +230,9 @@ eval_seq_t single_term_opt(TensorNetwork const& network, IdxToSz const& idxsz) { auto commons = common_indices(results[lpart].indices, results[rpart].indices); auto diffs = diff_indices(results[lpart].indices, results[rpart].indices); - auto new_cost = ops_count(idxsz, // - commons, diffs) // - + results[lpart].flops // + auto new_cost = ops_count(idxsz, // + concat(commons, diffs)) // + + results[lpart].flops // + results[rpart].flops; if (new_cost <= curr_cost) { curr_cost = new_cost; @@ -256,10 +256,9 @@ eval_seq_t single_term_opt(TensorNetwork const& network, IdxToSz const& idxsz) { auto const& first = results[curr_parts.first].sequence; auto const& second = results[curr_parts.second].sequence; - curr_result.sequence = - (first[0] < second[0] ? ranges::views::concat(first, second) - : ranges::views::concat(second, first)) | - ranges::to; + curr_result.sequence = (first[0] < second[0] ? concat(first, second) + : concat(second, first)) | + ranges::to; curr_result.sequence.push_back(-1); } } @@ -345,11 +344,14 @@ Sum reorder(Sum const& sum); /// /// \param expr Expression to be optimized. /// \param idxsz An invocable object that maps an Index object to size. +/// \param reorder_sum If true, the summands are reordered so that terms with +/// common sub-expressions appear closer to each other. /// \return Optimized expression for lower evaluation cost. template >> -ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size) { +ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size, + bool reorder_sum) { using ranges::views::transform; if (expr->is()) return expr->clone(); @@ -360,7 +362,7 @@ ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size) { return optimize(s, idx2size); }) | ranges::to_vector; auto sum = Sum{smands.begin(), smands.end()}; - return ex(opt::reorder(sum)); + return reorder_sum ? ex(opt::reorder(sum)) : ex(std::move(sum)); } else throw std::runtime_error{"Optimization attempted on unsupported Expr type"}; } @@ -372,8 +374,38 @@ ExprPtr optimize(ExprPtr const& expr, IdxToSize const& idx2size) { /// index extent. /// /// \param expr Expression to be optimized. +/// \param reorder_sum If true, the summands are reordered so that terms with +/// common sub-expressions appear closer to each other. +/// True by default. /// \return Optimized expression for lower evaluation cost. -ExprPtr optimize(ExprPtr const& expr); +ExprPtr optimize(ExprPtr const& expr, bool reorder_sum = true); + +/// +/// Converts the 4-center 'g' tensors into a product of two rank-3 tensors. +/// +/// \param expr The expression to be density-fit. +/// \param aux_label The label of the introduced auxilliary index. eg. 'x', 'p'. +/// \return The density-fit expression if 'g' of rank-4 present, otherwise the +/// input expression itself will be returned. +/// +ExprPtr density_fit(ExprPtr const& expr, std::wstring const& aux_label); + +/// +/// Converts the tensors in CSV basis into a product of full-basis +/// tensors times the CSV-transformation tensors. +/// +/// \param expr The expression to be CSV-transformed. +/// \param coeff_tensor_label The label of the CSV-tranformation tensors that +/// will be introduced. +/// \param csv_tensors The label of the CSV-basis tensors that will be +/// written as the transformed product. Eg. 'f', 'g'. +/// \return The CSV-transformed expression if CSV-tensors with labels present +/// in @c csv_tensors appear in @c expr. Otherwise returns the input +/// expression itself. +ExprPtr csv_transform(ExprPtr const& expr, + std::wstring const& coeff_tensor_label = L"C", + container::svector const& csv_tensors = { + L"f", L"g"}); } // namespace sequant diff --git a/SeQuant/core/optimize/optimize.cpp b/SeQuant/core/optimize/optimize.cpp index 86a7df5e2..b53236b62 100644 --- a/SeQuant/core/optimize/optimize.cpp +++ b/SeQuant/core/optimize/optimize.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -158,9 +159,143 @@ Sum reorder(Sum const& sum) { } // namespace opt -ExprPtr optimize(ExprPtr const& expr) { +ExprPtr optimize(ExprPtr const& expr, bool reorder_sum) { return opt::optimize( - expr, [](Index const& ix) { return ix.space().approximate_size(); }); + expr, [](Index const& ix) { return ix.space().approximate_size(); }, + reorder_sum); +} + +ExprPtr density_fit_impl(Tensor const& tnsr, Index const& aux_idx) { + assert(tnsr.bra_rank() == 2 // + && tnsr.ket_rank() == 2 // + && tnsr.aux_rank() == 0); + + auto t1 = ex(L"g", bra({ranges::front(tnsr.bra())}), + ket({ranges::front(tnsr.ket())}), aux({aux_idx})); + + auto t2 = ex(L"g", bra({ranges::back(tnsr.bra())}), + ket({ranges::back(tnsr.ket())}), aux({aux_idx})); + + return ex(1, ExprPtrList{t1, t2}); +} + +ExprPtr density_fit(ExprPtr const& expr, std::wstring const& aux_label) { + using ranges::views::transform; + if (expr->is()) + return ex(*expr | transform([&aux_label](auto&& x) { + return density_fit(x, aux_label); + })); + + else if (expr->is()) { + auto const& g = expr->as(); + if (g.label() == L"g" // + && g.bra_rank() == 2 // + && g.ket_rank() == 2 // + && ranges::none_of(g.indices(), &Index::has_proto_indices)) + return density_fit_impl(expr->as(), Index(aux_label + L"_1")); + else + return expr; + } else if (expr->is()) { + auto const& prod = expr->as(); + + Product result; + result.scale(prod.scalar()); + size_t aux_ix = 0; + for (auto&& f : prod.factors()) + if (f.is() && f.as().label() == L"g") { + auto const& g = f->as(); + auto g_df = density_fit_impl( + g, Index(aux_label + L"_" + std::to_wstring(++aux_ix))); + result.append(1, std::move(g_df), Product::Flatten::Yes); + } else { + result.append(1, f, Product::Flatten::No); + } + return ex(std::move(result)); + } else + return expr; +} + +ExprPtr csv_transform_impl(Tensor const& tnsr, + std::wstring_view coeff_tensor_label) { + using ranges::views::transform; + + if (ranges::none_of(tnsr.const_braket(), &Index::has_proto_indices)) + return nullptr; + + //// + auto drop_protos = [](auto&& ixs) { + return ixs | transform(&Index::drop_proto_indices); + }; + //// + + if (tnsr.label() == overlap_label()) { + assert(tnsr.bra_rank() == 1 // + && tnsr.ket_rank() == 1 // + && tnsr.aux_rank() == 0); + + auto&& bra_idx = tnsr.bra().at(0); + auto&& ket_idx = tnsr.ket().at(0); + + auto dummy_idx = suffix_compare(bra_idx, ket_idx) // + ? bra_idx.drop_proto_indices() // + : ket_idx.drop_proto_indices(); + + return ex( + 1, + ExprPtrList{ex(coeff_tensor_label, // + bra({bra_idx}), ket({dummy_idx})), // + ex(coeff_tensor_label, // + bra({dummy_idx}), ket({ket_idx}))}); + } + + Product result; + result.append(1, ex(tnsr.label(), bra(drop_protos(tnsr.bra())), + ket(drop_protos(tnsr.ket())), tnsr.aux())); + + for (auto&& idx : tnsr.bra()) + if (idx.has_proto_indices()) + result.append(1, ex(coeff_tensor_label, bra({idx}), + ket({idx.drop_proto_indices()}), aux({}))); + for (auto&& idx : tnsr.ket()) + if (idx.has_proto_indices()) + result.append( + 1, ex(coeff_tensor_label, bra({idx.drop_proto_indices()}), + ket({idx}), aux({}))); + + return ex(std::move(result)); +} + +ExprPtr csv_transform(ExprPtr const& expr, + std::wstring const& coeff_tensor_label, + container::svector const& csv_tensors) { + using ranges::views::transform; + if (expr->is()) + return ex(*expr // + | transform([&coeff_tensor_label, // + &csv_tensors](auto&& x) { + return csv_transform(x, coeff_tensor_label, csv_tensors); + })); + else if (expr->is()) { + auto const& tnsr = expr->as(); + if (!ranges::contains(csv_tensors, tnsr.label())) return expr; + return csv_transform_impl(tnsr, coeff_tensor_label); + } else if (expr->is()) { + auto const& prod = expr->as(); + + Product result; + result.scale(prod.scalar()); + + for (auto&& f : prod.factors()) { + auto trans = csv_transform(f, coeff_tensor_label, csv_tensors); + result.append(1, trans ? trans : f, + (f->is() || f->is()) ? Product::Flatten::No + : Product::Flatten::Yes); + } + + return ex(std::move(result)); + + } else + return expr; } } // namespace sequant diff --git a/SeQuant/core/parse.hpp b/SeQuant/core/parse.hpp index df9dc9f79..8804abfb7 100644 --- a/SeQuant/core/parse.hpp +++ b/SeQuant/core/parse.hpp @@ -44,7 +44,25 @@ struct ParseError : std::runtime_error { /// '1.0/2.0 * t{i1;a1} * f{i1; a1}' same as above /// 't{i1,i2; a1, a2}' a tensor having indices with proto indices. /// a1 is an index with i1 and i2 as proto-indices. -/// \param tensor_sym The symmetry of all atomic tensors in the +/// Every tensor may optionally be annoted with index symmetry specifications. The general syntax is +/// [: [- [-]]] +/// (no whitespace is allowed at this place). Examples are +/// 't{i1;i2}:A', 't{i1;i2}:A-S', 't{i1;i2}:N-C-S' +/// Possible values for are +/// - 'A' for antisymmetry (sequant::Symmetry::antisymm) +/// - 'S' for symmetric (sequant::Symmetry::symm) +/// - 'N' for non-symmetric (sequant::Symmetry::nonsymm) +/// Possible values for are +/// - 'C' for antisymmetry (sequant::BraKetSymmetry::conjugate) +/// - 'S' for symmetric (sequant::BraKetSymmetry::symm) +/// - 'N' for non-symmetric (sequant::BraKetSymmetry::nonsymm) +/// Possible values for are +/// - 'S' for symmetric (sequant::ParticleSymmetry::symm) +/// - 'N' for non-symmetric (sequant::ParticleSymmetry::nonsymm) +/// \param perm_symm Default index permutation symmetry to be used if tensors don't specify a permutation +/// symmetry explicitly. +/// \param braket_symm Default BraKet symmetry to be used if tensors don't specify a BraKet symmetry explicitly. +/// \param particle_symm Default particle symmetry to be used if tensors don't specify a particle symmetry explicitly. /// @c raw expression. Explicit tensor symmetry can /// be annotated in the expression itself. In that case, the /// annotated symmetry will be used. @@ -54,7 +72,9 @@ struct ParseError : std::runtime_error { /// \return SeQuant expression. // clang-format on ExprPtr parse_expr(std::wstring_view raw, - Symmetry tensor_sym = Symmetry::nonsymm); + Symmetry perm_symm = Symmetry::nonsymm, + BraKetSymmetry braket_symm = BraKetSymmetry::nonsymm, + ParticleSymmetry particle_symm = ParticleSymmetry::symm); /// /// Get a parsable string from an expression. diff --git a/SeQuant/core/parse/ast.hpp b/SeQuant/core/parse/ast.hpp index ad148b8d6..521289652 100644 --- a/SeQuant/core/parse/ast.hpp +++ b/SeQuant/core/parse/ast.hpp @@ -7,6 +7,7 @@ #define BOOST_SPIRIT_X3_UNICODE #include +#include #include #include #include @@ -55,26 +56,34 @@ struct Variable : boost::spirit::x3::position_tagged { struct IndexGroups : boost::spirit::x3::position_tagged { std::vector bra; std::vector ket; + std::vector auxiliaries; bool reverse_bra_ket; IndexGroups(std::vector bra = {}, std::vector ket = {}, - bool reverse_bra_ket = {}) + std::vector auxiliaries = {}, bool reverse_bra_ket = {}) : bra(std::move(bra)), ket(std::move(ket)), + auxiliaries(std::move(auxiliaries)), reverse_bra_ket(reverse_bra_ket) {} }; +struct SymmetrySpec : boost::spirit::x3::position_tagged { + static constexpr char unspecified = '\0'; + char perm_symm = unspecified; + char braket_symm = unspecified; + char particle_symm = unspecified; +}; + struct Tensor : boost::spirit::x3::position_tagged { - static constexpr char unspecified_symmetry = '\0'; std::wstring name; IndexGroups indices; - char symmetry; + boost::optional symmetry; Tensor(std::wstring name = {}, IndexGroups indices = {}, - char symmetry = unspecified_symmetry) + boost::optional symmetry = {}) : name(std::move(name)), indices(std::move(indices)), - symmetry(symmetry) {} + symmetry(std::move(symmetry)) {} }; struct Product; @@ -122,7 +131,9 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Index, label, protoLabels); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Number, numerator, denominator); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Variable, name, conjugated); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::IndexGroups, bra, ket, - reverse_bra_ket); + auxiliaries, reverse_bra_ket); +BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::SymmetrySpec, perm_symm, + braket_symm, particle_symm); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Tensor, name, indices, symmetry); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Product, factors); diff --git a/SeQuant/core/parse/ast_conversions.hpp b/SeQuant/core/parse/ast_conversions.hpp index e41d0c574..f47580d7b 100644 --- a/SeQuant/core/parse/ast_conversions.hpp +++ b/SeQuant/core/parse/ast_conversions.hpp @@ -22,6 +22,9 @@ namespace sequant::parse::transform { +using DefaultSymmetries = + std::tuple; + template std::tuple get_pos(const AST &ast, const PositionCache &cache, @@ -77,11 +80,14 @@ Index to_index(const parse::ast::Index &index, } template -std::tuple, container::vector> make_indices( - const parse::ast::IndexGroups &groups, const PositionCache &position_cache, - const Iterator &begin) { +std::tuple, container::vector, + container::vector> +make_indices(const parse::ast::IndexGroups &groups, + const PositionCache &position_cache, const Iterator &begin) { container::vector braIndices; container::vector ketIndices; + container::vector auxiliaries; + container::vector auxIndices; static_assert(std::is_same_v, "Types for bra and ket indices must be equal for pointer " @@ -103,14 +109,17 @@ std::tuple, container::vector> make_indices( for (const parse::ast::Index ¤t : *ket) { ketIndices.push_back(to_index(current, position_cache, begin)); } + for (const parse::ast::Index ¤t : groups.auxiliaries) { + auxiliaries.push_back(to_index(current, position_cache, begin)); + } - return {std::move(braIndices), std::move(ketIndices)}; + return {std::move(braIndices), std::move(ketIndices), std::move(auxiliaries)}; } template -Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin, - Symmetry default_symmetry) { - if (c == parse::ast::Tensor::unspecified_symmetry) { +Symmetry to_perm_symmetry(char c, std::size_t offset, const Iterator &begin, + Symmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { return default_symmetry; } @@ -130,6 +139,52 @@ Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin, std::string("Invalid symmetry specifier '") + c + "'"); } +template +BraKetSymmetry to_braket_symmetry(char c, std::size_t offset, + const Iterator &begin, + BraKetSymmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { + return default_symmetry; + } + + switch (c) { + case 'C': + case 'c': + return BraKetSymmetry::conjugate; + case 'S': + case 's': + return BraKetSymmetry::symm; + case 'N': + case 'n': + return BraKetSymmetry::nonsymm; + } + + throw ParseError( + offset, 1, std::string("Invalid BraKet symmetry specifier '") + c + "'"); +} + +template +ParticleSymmetry to_particle_symmetry(char c, std::size_t offset, + const Iterator &begin, + ParticleSymmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { + return default_symmetry; + } + + switch (c) { + case 'S': + case 's': + return ParticleSymmetry::symm; + case 'N': + case 'n': + return ParticleSymmetry::nonsymm; + } + + throw ParseError( + offset, 1, + std::string("Invalid particle symmetry specifier '") + c + "'"); +} + template Constant to_constant(const parse::ast::Number &number, const PositionCache &position_cache, @@ -146,53 +201,81 @@ Constant to_constant(const parse::ast::Number &number, } } +template +std::tuple to_symmetries( + const boost::optional &symm_spec, + const DefaultSymmetries &default_symms, const PositionCache &cache, + const Iterator &begin) { + if (!symm_spec.has_value()) { + return {std::get<0>(default_symms), std::get<1>(default_symms), + std::get<2>(default_symms)}; + } + + const ast::SymmetrySpec &spec = symm_spec.get(); + + auto [offset, length] = get_pos(spec, cache, begin); + + // Note: symmetry specifications are a separator (colon or dash) followed by + // an uppercase letter each (no whitespace allowed in-between) + Symmetry perm_symm = to_perm_symmetry(spec.perm_symm, offset + 1, begin, + std::get<0>(default_symms)); + BraKetSymmetry braket_symm = to_braket_symmetry( + spec.braket_symm, offset + 3, begin, std::get<1>(default_symms)); + ParticleSymmetry particle_symm = to_particle_symmetry( + spec.particle_symm, offset + 5, begin, std::get<2>(default_symms)); + + return {perm_symm, braket_symm, particle_symm}; +} + template ExprPtr ast_to_expr(const parse::ast::Product &product, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry); + const DefaultSymmetries &default_symms); template ExprPtr ast_to_expr(const parse::ast::Sum &sum, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry); + const DefaultSymmetries &default_symms); template ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { struct Transformer { std::reference_wrapper position_cache; std::reference_wrapper begin; - std::reference_wrapper default_symmetry; + std::reference_wrapper default_symms; ExprPtr operator()(const parse::ast::Product &product) const { return ast_to_expr(product, position_cache.get(), - begin.get(), default_symmetry); + begin.get(), default_symms.get()); } ExprPtr operator()(const parse::ast::Sum &sum) const { return ast_to_expr(sum, position_cache.get(), begin.get(), - default_symmetry); + default_symms.get()); } ExprPtr operator()(const parse::ast::Tensor &tensor) const { - auto [braIndices, ketIndices] = + auto [braIndices, ketIndices, auxiliaries] = make_indices(tensor.indices, position_cache.get(), begin.get()); - auto [offset, length] = - get_pos(tensor, position_cache.get(), begin.get()); + auto [perm_symm, braket_symm, particle_symm] = + to_symmetries(tensor.symmetry, default_symms.get(), + position_cache.get(), begin.get()); return ex(tensor.name, bra(std::move(braIndices)), - ket(std::move(ketIndices)), - to_symmetry(tensor.symmetry, offset + length - 1, - begin.get(), default_symmetry)); + ket(std::move(ketIndices)), aux(std::move(auxiliaries)), + perm_symm, braket_symm, particle_symm); } ExprPtr operator()(const parse::ast::Variable &variable) const { + ExprPtr var = ex(variable.name); + if (variable.conjugated) { - return ex(variable.name + L"^*"); - } else { - return ex(variable.name); + var->as().conjugate(); } + + return var; } ExprPtr operator()(const parse::ast::Number &number) const { @@ -203,7 +286,7 @@ ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, return boost::apply_visitor( Transformer{std::ref(position_cache), std::ref(begin), - std::ref(default_symmetry)}, + std::ref(default_symms)}, value); } @@ -215,7 +298,7 @@ bool holds_alternative(const boost::variant &v) noexcept { template ExprPtr ast_to_expr(const parse::ast::Product &product, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { if (product.factors.empty()) { // This shouldn't happen assert(false); @@ -225,7 +308,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, if (product.factors.size() == 1) { return ast_to_expr(product.factors.front(), position_cache, begin, - default_symmetry); + default_symms); } std::vector factors; @@ -239,7 +322,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, position_cache, begin); } else { factors.push_back( - ast_to_expr(value, position_cache, begin, default_symmetry)); + ast_to_expr(value, position_cache, begin, default_symms)); } } @@ -259,13 +342,13 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, template ExprPtr ast_to_expr(const parse::ast::Sum &sum, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { if (sum.summands.empty()) { return {}; } if (sum.summands.size() == 1) { return ast_to_expr(sum.summands.front(), position_cache, begin, - default_symmetry); + default_symms); } std::vector summands; @@ -273,7 +356,7 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum, std::transform( sum.summands.begin(), sum.summands.end(), std::back_inserter(summands), [&](const parse::ast::Product &product) { - return ast_to_expr(product, position_cache, begin, default_symmetry); + return ast_to_expr(product, position_cache, begin, default_symms); }); return ex(std::move(summands)); diff --git a/SeQuant/core/parse/deparse.cpp b/SeQuant/core/parse/deparse.cpp index ea4087030..24538a8b3 100644 --- a/SeQuant/core/parse/deparse.cpp +++ b/SeQuant/core/parse/deparse.cpp @@ -37,8 +37,8 @@ std::wstring deparse_indices(const Range& indices) { return deparsed; } -std::wstring deparse_sym(Symmetry sym) { - switch (sym) { +std::wstring deparse_symm(Symmetry symm) { + switch (symm) { case Symmetry::symm: return L"S"; case Symmetry::antisymm: @@ -53,6 +53,36 @@ std::wstring deparse_sym(Symmetry sym) { return L"INVALIDANDUNREACHABLE"; } +std::wstring deparse_symm(BraKetSymmetry symm) { + switch (symm) { + case BraKetSymmetry::conjugate: + return L"C"; + case BraKetSymmetry::symm: + return L"S"; + case BraKetSymmetry::nonsymm: + return L"N"; + case BraKetSymmetry::invalid: + return L"INVALID"; + } + + assert(false); + return L"INVALIDANDUNREACHABLE"; +} + +std::wstring deparse_symm(ParticleSymmetry symm) { + switch (symm) { + case ParticleSymmetry::symm: + return L"S"; + case ParticleSymmetry::nonsymm: + return L"N"; + case ParticleSymmetry::invalid: + return L"INVALID"; + } + + assert(false); + return L"INVALIDANDUNREACHABLE"; +} + std::wstring deparse_scalar(const Constant::scalar_type& scalar) { const auto& real = scalar.real(); const auto& realNumerator = boost::multiprecision::numerator(real); @@ -96,18 +126,23 @@ std::wstring deparse_scalar(const Constant::scalar_type& scalar) { } // namespace details std::wstring deparse(const ExprPtr& expr, bool annot_sym) { - using namespace details; if (!expr) return {}; - if (expr->is()) - return deparse(expr->as(), annot_sym); - else if (expr->is()) - return deparse(expr->as(), annot_sym); - else if (expr->is()) - return deparse(expr->as(), annot_sym); - else if (expr->is()) - return deparse(expr->as()); - else if (expr->is()) - return deparse(expr->as()); + + return deparse(*expr, annot_sym); +} + +std::wstring deparse(const Expr& expr, bool annot_sym) { + using namespace details; + if (expr.is()) + return deparse(expr.as(), annot_sym); + else if (expr.is()) + return deparse(expr.as(), annot_sym); + else if (expr.is()) + return deparse(expr.as(), annot_sym); + else if (expr.is()) + return deparse(expr.as()); + else if (expr.is()) + return deparse(expr.as()); else throw std::runtime_error("Unsupported expr type for deparse!"); } @@ -137,10 +172,18 @@ std::wstring deparse(Tensor const& tensor, bool annot_sym) { if (tensor.ket_rank() > 0) { deparsed += L";" + details::deparse_indices(tensor.ket()); } + if (tensor.aux_rank() > 0) { + if (tensor.ket_rank() == 0) { + deparsed += L";"; + } + deparsed += L";" + details::deparse_indices(tensor.aux()); + } deparsed += L"}"; if (annot_sym) { - deparsed += L":" + details::deparse_sym(tensor.symmetry()); + deparsed += L":" + details::deparse_symm(tensor.symmetry()); + deparsed += L"-" + details::deparse_symm(tensor.braket_symmetry()); + deparsed += L"-" + details::deparse_symm(tensor.particle_symmetry()); } return deparsed; @@ -151,7 +194,7 @@ std::wstring deparse(const Constant& constant) { } std::wstring deparse(const Variable& variable) { - return std::wstring(variable.label()); + return std::wstring(variable.label()) + (variable.conjugated() ? L"^*" : L""); } std::wstring deparse(Product const& prod, bool annot_sym) { diff --git a/SeQuant/core/parse/parse.cpp b/SeQuant/core/parse/parse.cpp index 5aa528fdb..0a4c4e23e 100644 --- a/SeQuant/core/parse/parse.cpp +++ b/SeQuant/core/parse/parse.cpp @@ -47,6 +47,7 @@ struct ExprRule; struct IndexLabelRule; struct IndexRule; struct IndexGroupRule; +struct SymmetrySpecRule; // Types x3::rule number{"Number"}; @@ -63,6 +64,7 @@ x3::rule name{"Name"}; x3::rule index_label{"IndexLabel"}; x3::rule index{"Index"}; x3::rule index_groups{"IndexGroups"}; +x3::rule symmetry_spec{"SymmetrySpec"}; auto to_char_type = [](auto c) { return static_cast(c); @@ -97,15 +99,20 @@ auto index_label_def = x3::lexeme[ ]; auto index_def = x3::lexeme[ - index_label >> -('<' >> index_label % ',' >> ">") + index_label >> -x3::skip['<' >> index_label % ',' >> ">"] ]; -auto index_groups_def = L"_{" > -(index % ',') > L"}^{" > -(index % ',') > L"}" >> x3::attr(false) - | L"^{" > -(index % ',') > L"}_{" > -(index % ',') > L"}" >> x3::attr(true) - | '{' > -(index % ',') > ';' > -(index % ',') > '}' >> x3::attr(false); +const std::vector noIndices; +auto index_groups_def = L"_{" > -(index % ',') > L"}^{" > -(index % ',') > L"}" >> x3::attr(noIndices) >> x3::attr(false) + | L"^{" > -(index % ',') > L"}_{" > -(index % ',') > L"}" >> x3::attr(noIndices) >> x3::attr(true) + | '{' > -(index % ',') > -( ';' > -(index % ',')) > -(';' > -(index % ',')) > '}' >> x3::attr(false); + +auto symmetry_spec_def= x3::lexeme[ + ':' >> x3::upper >> -('-' >> x3::upper) >> -('-' >> x3::upper) + ]; auto tensor_def = x3::lexeme[ - name >> x3::skip[index_groups] >> -(':' >> x3::upper) + name >> x3::skip[index_groups] >> -(symmetry_spec) ]; auto nullary = number | tensor | variable; @@ -124,7 +131,7 @@ auto expr_def = -sum > x3::eoi; // clang-format on BOOST_SPIRIT_DEFINE(name, number, variable, index_label, index, index_groups, - tensor, product, sum, expr); + tensor, product, sum, expr, symmetry_spec); struct position_cache_tag; struct error_handler_tag; @@ -162,6 +169,7 @@ struct ExprRule : helpers::annotate_position, helpers::error_handler {}; struct IndexLabelRule : helpers::annotate_position, helpers::error_handler {}; struct IndexRule : helpers::annotate_position, helpers::error_handler {}; struct IndexGroupRule : helpers::annotate_position, helpers::error_handler {}; +struct SymmetrySpecRule : helpers::annotate_position, helpers::error_handler {}; } // namespace parse @@ -179,7 +187,8 @@ struct ErrorHandler { } }; -ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { +ExprPtr parse_expr(std::wstring_view input, Symmetry perm_symm, + BraKetSymmetry braket_symm, ParticleSymmetry particle_symm) { using iterator_type = decltype(input)::iterator; x3::position_cache> positions(input.begin(), input.end()); @@ -216,8 +225,10 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { throw; } - return parse::transform::ast_to_expr(ast, positions, input.begin(), - default_symmetry); + return parse::transform::ast_to_expr( + ast, positions, input.begin(), + parse::transform::DefaultSymmetries{perm_symm, braket_symm, + particle_symm}); } } // namespace sequant diff --git a/SeQuant/core/tensor.cpp b/SeQuant/core/tensor.cpp index 045a7c818..6974b2fc0 100644 --- a/SeQuant/core/tensor.cpp +++ b/SeQuant/core/tensor.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace sequant { diff --git a/SeQuant/core/tensor.hpp b/SeQuant/core/tensor.hpp index 45fea7740..b62cb0ea0 100644 --- a/SeQuant/core/tensor.hpp +++ b/SeQuant/core/tensor.hpp @@ -38,8 +38,11 @@ namespace sequant { DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE(bra); // strong type wrapper for objects associated with ket DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE(ket); +// strong type wrapper for objects associated with aux +DEFINE_STRONG_TYPE_FOR_RANGE_AND_RANGESIZE(aux); -/// @brief particle-symmetric Tensor, i.e. permuting +/// @brief a Tensor is an instance of AbstractTensor over a scalar field, i.e. +/// Tensors have commutative addition and product operations class Tensor : public Expr, public AbstractTensor, public Labeled { private: using index_container_type = container::svector; @@ -73,9 +76,12 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { } } - /// @return view of the bra+ket index ranges + /// @return concatenated view of the bra and ket index ranges auto braket() { return ranges::views::concat(bra_, ket_); } + /// @return concatenated view of bra, ket, and aux index ranges + auto indices() { return ranges::views::concat(bra_, ket_, aux_); } + /// asserts that @p label is not reserved /// @note Tensor with reserved labels are constructed using friends of Tensor /// @param label a Tensor label candidate @@ -87,22 +93,28 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { // list of friends who can make Tensor objects with reserved labels friend ExprPtr make_overlap(const Index &bra_index, const Index &ket_index); - template < - typename IndexRange1, typename IndexRange2, - typename = std::enable_if_t< - (meta::is_statically_castable_v< - meta::range_value_t, - Index>)&&(meta:: - is_statically_castable_v< - meta::range_value_t, Index>)>> + template , + Index>)&&(meta:: + is_statically_castable_v< + meta::range_value_t, + Index>)&&(meta:: + is_statically_castable_v< + meta::range_value_t< + IndexRange3>, + Index>)>> Tensor(std::wstring_view label, const bra &bra_indices, - const ket &ket_indices, reserved_tag, + const ket &ket_indices, + const aux &aux_indices, reserved_tag, Symmetry s = Symmetry::nonsymm, BraKetSymmetry bks = get_default_context().braket_symmetry(), ParticleSymmetry ps = ParticleSymmetry::symm) : label_(label), bra_(make_indices(bra_indices)), ket_(make_indices(ket_indices)), + aux_(make_indices(aux_indices)), symmetry_(s), braket_symmetry_(bks), particle_symmetry_(ps) { @@ -110,13 +122,15 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { } Tensor(std::wstring_view label, bra &&bra_indices, - ket &&ket_indices, reserved_tag, + ket &&ket_indices, + aux &&aux_indices, reserved_tag, Symmetry s = Symmetry::nonsymm, BraKetSymmetry bks = get_default_context().braket_symmetry(), ParticleSymmetry ps = ParticleSymmetry::symm) : label_(label), bra_(std::move(bra_indices)), ket_(std::move(ket_indices)), + aux_(std::move(aux_indices)), symmetry_(s), braket_symmetry_(bks), particle_symmetry_(ps) { @@ -149,7 +163,40 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { const ket &ket_indices, Symmetry s = Symmetry::nonsymm, BraKetSymmetry bks = get_default_context().braket_symmetry(), ParticleSymmetry ps = ParticleSymmetry::symm) - : Tensor(label, bra_indices, ket_indices, reserved_tag{}, s, bks, ps) { + : Tensor(label, bra_indices, ket_indices, sequant::aux{}, reserved_tag{}, + s, bks, ps) { + assert_nonreserved_label(label_); + } + + /// @param label the tensor label + /// @param bra_indices list of bra indices (or objects that can be converted + /// to indices) + /// @param ket_indices list of ket indices (or objects that can be converted + /// to indices) + /// @param aux_indices list of aux indices (or objects that can be + /// converted to indices) + /// @param s the symmetry of bra or ket + /// @param bks the symmetry with respect to bra-ket exchange + /// @param ps the symmetry under exchange of particles + template , + Index>)&&(meta:: + is_statically_castable_v< + meta::range_value_t, + Index>)&&(meta:: + is_statically_castable_v< + meta::range_value_t< + IndexRange3>, + Index>)>> + Tensor(std::wstring_view label, const bra &bra_indices, + const ket &ket_indices, + const aux &aux_indices, Symmetry s = Symmetry::nonsymm, + BraKetSymmetry bks = get_default_context().braket_symmetry(), + ParticleSymmetry ps = ParticleSymmetry::symm) + : Tensor(label, bra_indices, ket_indices, aux_indices, reserved_tag{}, s, + bks, ps) { assert_nonreserved_label(label_); } @@ -167,7 +214,28 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { BraKetSymmetry bks = get_default_context().braket_symmetry(), ParticleSymmetry ps = ParticleSymmetry::symm) : Tensor(label, std::move(bra_indices), std::move(ket_indices), - reserved_tag{}, s, bks, ps) { + sequant::aux{}, reserved_tag{}, s, bks, ps) { + assert_nonreserved_label(label_); + } + + /// @param label the tensor label + /// @param bra_indices list of bra indices (or objects that can be converted + /// to indices) + /// @param ket_indices list of ket indices (or objects that can be converted + /// to indices) + /// @param aux_indices list of aux indices (or objects that can be + /// converted to indices) + /// @param s the symmetry of bra or ket + /// @param bks the symmetry with respect to bra-ket exchange + /// @param ps the symmetry under exchange of particles + Tensor(std::wstring_view label, bra &&bra_indices, + ket &&ket_indices, + aux &&aux_indices, + Symmetry s = Symmetry::nonsymm, + BraKetSymmetry bks = get_default_context().braket_symmetry(), + ParticleSymmetry ps = ParticleSymmetry::symm) + : Tensor(label, std::move(bra_indices), std::move(ket_indices), + std::move(aux_indices), reserved_tag{}, s, bks, ps) { assert_nonreserved_label(label_); } @@ -180,13 +248,23 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { /// @return "core" label of the tensor std::wstring_view label() const override { return label_; } + /// @return the bra index range const auto &bra() const { return bra_; } + /// @return the ket index range const auto &ket() const { return ket_; } - /// @return joined view of the bra and ket index ranges + /// @return the aux index range + const auto &aux() const { return aux_; } + /// @return concatenated view of the bra and ket index ranges auto braket() const { return ranges::views::concat(bra_, ket_); } + /// @return concatenated view of all indices of this tensor (bra, ket and + /// aux) + auto indices() const { return ranges::views::concat(bra_, ket_, aux_); } /// @return view of the bra+ket index ranges /// @note this is to work around broken lookup rules auto const_braket() const { return this->braket(); } + /// @return view of all indices + /// @note this is to work around broken lookup rules + auto const_indices() const { return this->indices(); } /// Returns the Symmetry object describing the symmetry of the bra and ket of /// the Tensor, i.e. what effect swapping indices in positions @c i and @c j /// in either bra or ket has on the elements of the Tensor; @@ -208,6 +286,8 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { std::size_t bra_rank() const { return bra_.size(); } /// @return number of ket indices std::size_t ket_rank() const { return ket_.size(); } + /// @return number of aux indices + std::size_t aux_rank() const { return aux_.size(); } /// @return number of indices in bra/ket /// @throw std::logic_error if bra and ket ranks do not match std::size_t rank() const { @@ -231,7 +311,20 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { for (const auto &i : this->ket()) result += sequant::to_latex(i); result += L"}_{"; for (const auto &i : this->bra()) result += sequant::to_latex(i); - result += L"}}"; + result += L"}"; + if (!this->aux_.empty()) { + result += L"["; + const index_container_type &__aux = this->aux(); + for (std::size_t i = 0; i < aux_rank(); ++i) { + result += sequant::to_latex(__aux[i]); + + if (i + 1 < aux_rank()) { + result += L","; + } + } + result += L"]"; + } + result += L"}"; return result; } @@ -247,7 +340,7 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { typename... Args> bool transform_indices(const Map &index_map) { bool mutated = false; - ranges::for_each(braket(), [&](auto &idx) { + ranges::for_each(indices(), [&](auto &idx) { if (idx.transform(index_map)) mutated = true; }); if (mutated) this->reset_hash_value(); @@ -259,7 +352,7 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { ExprPtr clone() const override { return ex(*this); } void reset_tags() const { - ranges::for_each(braket(), [](const auto &idx) { idx.reset_tag(); }); + ranges::for_each(indices(), [](const auto &idx) { idx.reset_tag(); }); } hash_type bra_hash_value() const { @@ -272,6 +365,7 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { std::wstring label_{}; sequant::bra bra_{}; sequant::ket ket_{}; + sequant::aux aux_{}; Symmetry symmetry_ = Symmetry::invalid; BraKetSymmetry braket_symmetry_ = BraKetSymmetry::invalid; ParticleSymmetry particle_symmetry_ = ParticleSymmetry::invalid; @@ -292,6 +386,7 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { auto val = hash::range(begin(bra()), end(bra())); bra_hash_value_ = val; hash::range(val, begin(ket()), end(ket())); + hash::range(val, begin(aux()), end(aux())); hash::combine(val, label_); hash::combine(val, symmetry_); hash_value_ = val; @@ -307,7 +402,8 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { if (this->label() == that_cast.label() && this->symmetry() == that_cast.symmetry() && this->bra_rank() == that_cast.bra_rank() && - this->ket_rank() == that_cast.ket_rank()) { + this->ket_rank() == that_cast.ket_rank() && + this->aux_rank() == that_cast.aux_rank()) { // compare hash values first if (this->hash_value() == that.hash_value()) // hash values agree -> do full comparison @@ -319,32 +415,43 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { } bool static_less_than(const Expr &that) const override { - const auto &that_cast = static_cast(that); if (this == &that) return false; - if (this->label() == that_cast.label()) { - if (this->bra_rank() == that_cast.bra_rank()) { - if (this->ket_rank() == that_cast.ket_rank()) { - // v1: compare hashes only - // return Expr::static_less_than(that); - // v2: compare fully - if (this->bra_hash_value() == that_cast.bra_hash_value()) { - return std::lexicographical_compare( - this->ket().begin(), this->ket().end(), that_cast.ket().begin(), - that_cast.ket().end()); - } else { - return std::lexicographical_compare( - this->bra().begin(), this->bra().end(), that_cast.bra().begin(), - that_cast.bra().end()); - } - } else { - return this->ket_rank() < that_cast.ket_rank(); - } - } else { - return this->bra_rank() < that_cast.bra_rank(); - } - } else { + + const auto &that_cast = static_cast(that); + if (this->label() != that_cast.label()) { return this->label() < that_cast.label(); } + + if (this->bra_rank() != that_cast.bra_rank()) { + return this->bra_rank() < that_cast.bra_rank(); + } + + if (this->ket_rank() != that_cast.ket_rank()) { + return this->ket_rank() < that_cast.ket_rank(); + } + + if (this->aux_rank() != that_cast.aux_rank()) { + return this->aux_rank() < that_cast.aux_rank(); + } + + // v1: compare hashes only + // return Expr::static_less_than(that); + // v2: compare fully + if (this->bra_hash_value() != that_cast.bra_hash_value()) { + return std::lexicographical_compare( + this->bra().begin(), this->bra().end(), that_cast.bra().begin(), + that_cast.bra().end()); + } + + if (this->ket() != that_cast.ket()) { + return std::lexicographical_compare( + this->ket().begin(), this->ket().end(), that_cast.ket().begin(), + that_cast.ket().end()); + } + + return std::lexicographical_compare(this->aux().begin(), this->aux().end(), + that_cast.aux().begin(), + that_cast.aux().end()); } // these implement the AbstractTensor interface @@ -356,11 +463,19 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { return ranges::counted_view( ket_.empty() ? nullptr : &(ket_[0]), ket_.size()); } + AbstractTensor::const_any_view_randsz _aux() const override final { + return ranges::counted_view( + aux_.empty() ? nullptr : &(aux_[0]), aux_.size()); + } AbstractTensor::const_any_view_rand _braket() const override final { return braket(); } + AbstractTensor::const_any_view_rand _indices() const override final { + return indices(); + } std::size_t _bra_rank() const override final { return bra_rank(); } std::size_t _ket_rank() const override final { return ket_rank(); } + std::size_t _aux_rank() const override final { return aux_rank(); } Symmetry _symmetry() const override final { return symmetry_; } BraKetSymmetry _braket_symmetry() const override final { return braket_symmetry_; @@ -396,9 +511,18 @@ class Tensor : public Expr, public AbstractTensor, public Labeled { return ranges::counted_view(ket_.empty() ? nullptr : &(ket_[0]), ket_.size()); } + AbstractTensor::any_view_randsz _aux_mutable() override final { + this->reset_hash_value(); + return ranges::counted_view(aux_.empty() ? nullptr : &(aux_[0]), + aux_.size()); + } }; // class Tensor +static_assert(is_tensor_v, + "The Tensor class does not fulfill the requirements of the " + "Tensor interface"); + using TensorPtr = std::shared_ptr; /// make_overlap tensor label is reserved since it is used by low-level SeQuant @@ -407,7 +531,7 @@ inline std::wstring overlap_label() { return L"s"; } inline ExprPtr make_overlap(const Index &bra_index, const Index &ket_index) { return ex(Tensor(overlap_label(), bra{bra_index}, ket{ket_index}, - Tensor::reserved_tag{})); + aux{}, Tensor::reserved_tag{})); } } // namespace sequant diff --git a/SeQuant/core/tensor_canonicalizer.cpp b/SeQuant/core/tensor_canonicalizer.cpp new file mode 100644 index 000000000..de7961e06 --- /dev/null +++ b/SeQuant/core/tensor_canonicalizer.cpp @@ -0,0 +1,234 @@ +// +// Created by Eduard Valeyev on 2019-03-24. +// + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace sequant { + +template +using get_support = decltype(std::get<0>(std::declval())); + +template +constexpr bool is_tuple_like_v = meta::is_detected_v; + +struct TensorBlockIndexComparer { + template + bool operator()(const T& lhs, const T& rhs) const { + return compare(lhs, rhs) < 0; + } + + template + int compare(const T& lhs, const T& rhs) const { + if constexpr (is_tuple_like_v) { + static_assert( + std::tuple_size_v == 2, + "TensorBlockIndexComparer can only deal with tuple-like objects " + "of size 2"); + const auto& lhs_first = std::get<0>(lhs); + const auto& lhs_second = std::get<1>(lhs); + const auto& rhs_first = std::get<0>(rhs); + const auto& rhs_second = std::get<1>(rhs); + + static_assert(std::is_same_v, Index>, + "TensorBlockIndexComparer can only work with indices"); + static_assert(std::is_same_v, Index>, + "TensorBlockIndexComparer can only work with indices"); + static_assert(std::is_same_v, Index>, + "TensorBlockIndexComparer can only work with indices"); + static_assert(std::is_same_v, Index>, + "TensorBlockIndexComparer can only work with indices"); + + int res = lhs_first < rhs_first ? -1 : (rhs_first < lhs_first ? 1 : 0); + if (res != 0) { + return res; + } + + res = lhs_second < rhs_second ? -1 : (rhs_second < lhs_second ? 1 : 0); + return res; + } else { + static_assert(std::is_same_v, Index>, + "TensorBlockIndexComparer can only work with indices"); + + int res = lhs < rhs ? -1 : (rhs < lhs ? 1 : 0); + return res; + } + } +}; + +struct TensorIndexComparer { + template + bool operator()(const T& lhs, const T& rhs) const { + TensorBlockIndexComparer block_comp; + + int res = block_comp.compare(lhs, rhs); + + if (res != 0) { + return res < 0; + } + + // Fall back to regular index compare to break the tie + if constexpr (is_tuple_like_v) { + static_assert(std::tuple_size_v == 2, + "TensorIndexComparer can only deal with tuple-like objects " + "of size 2"); + + const Index& lhs_first = std::get<0>(lhs); + const Index& lhs_second = std::get<1>(lhs); + const Index& rhs_first = std::get<0>(rhs); + const Index& rhs_second = std::get<1>(rhs); + + if (lhs_first != rhs_first) { + return lhs_first < rhs_first; + } + + return lhs_second < rhs_second; + } else { + return lhs < rhs; + } + } +}; + +TensorCanonicalizer::~TensorCanonicalizer() = default; + +std::pair>*, + std::unique_lock> +TensorCanonicalizer::instance_map_accessor() { + static container::map> + map_; + static std::recursive_mutex mtx_; + return std::make_pair(&map_, std::unique_lock{mtx_}); +} + +container::vector& +TensorCanonicalizer::cardinal_tensor_labels_accessor() { + static container::vector ctlabels_; + return ctlabels_; +} + +std::shared_ptr +TensorCanonicalizer::nondefault_instance_ptr(std::wstring_view label) { + auto&& [map_ptr, lock] = instance_map_accessor(); + // look for label-specific canonicalizer + auto it = map_ptr->find(std::wstring{label}); + if (it != map_ptr->end()) { + return it->second; + } else + return {}; +} + +std::shared_ptr TensorCanonicalizer::instance_ptr( + std::wstring_view label) { + auto result = nondefault_instance_ptr(label); + if (!result) // not found? look for default + result = nondefault_instance_ptr(L""); + return result; +} + +std::shared_ptr TensorCanonicalizer::instance( + std::wstring_view label) { + auto inst_ptr = instance_ptr(label); + if (!inst_ptr) + throw std::runtime_error( + "must first register canonicalizer via " + "TensorCanonicalizer::register_instance(...)"); + return inst_ptr; +} + +void TensorCanonicalizer::register_instance( + std::shared_ptr can, std::wstring_view label) { + auto&& [map_ptr, lock] = instance_map_accessor(); + (*map_ptr)[std::wstring{label}] = can; +} + +bool TensorCanonicalizer::try_register_instance( + std::shared_ptr can, std::wstring_view label) { + auto&& [map_ptr, lock] = instance_map_accessor(); + if (!map_ptr->contains(std::wstring{label})) { + (*map_ptr)[std::wstring{label}] = can; + return true; + } else + return false; +} + +void TensorCanonicalizer::deregister_instance(std::wstring_view label) { + auto&& [map_ptr, lock] = instance_map_accessor(); + auto it = map_ptr->find(std::wstring{label}); + if (it != map_ptr->end()) { + map_ptr->erase(it); + } +} + +TensorCanonicalizer::index_comparer_t TensorCanonicalizer::index_comparer_ = + TensorIndexComparer{}; + +TensorCanonicalizer::index_pair_comparer_t + TensorCanonicalizer::index_pair_comparer_ = TensorIndexComparer{}; + +const TensorCanonicalizer::index_comparer_t& +TensorCanonicalizer::index_comparer() { + return index_comparer_; +} + +void TensorCanonicalizer::index_comparer(index_comparer_t comparer) { + index_comparer_ = std::move(comparer); +} + +const TensorCanonicalizer::index_pair_comparer_t& +TensorCanonicalizer::index_pair_comparer() { + return index_pair_comparer_; +} + +void TensorCanonicalizer::index_pair_comparer(index_pair_comparer_t comparer) { + index_pair_comparer_ = std::move(comparer); +} + +ExprPtr NullTensorCanonicalizer::apply(AbstractTensor&) const { return {}; } + +void DefaultTensorCanonicalizer::tag_indices(AbstractTensor& t) const { + // tag all indices as ext->true/ind->false + ranges::for_each(indices(t), [this](auto& idx) { + auto it = external_indices_.find(std::wstring(idx.label())); + auto is_ext = it != external_indices_.end(); + idx.tag().assign( + is_ext ? 0 : 1); // ext -> 0, int -> 1, so ext will come before + }); +} + +ExprPtr DefaultTensorCanonicalizer::apply(AbstractTensor& t) const { + tag_indices(t); + + auto result = + this->apply(t, this->index_comparer_, this->index_pair_comparer_); + + reset_tags(t); + + return result; +} + +template +using suitable_call_operator = + decltype(std::declval()(std::declval()...)); + +ExprPtr TensorBlockCanonicalizer::apply(AbstractTensor& t) const { + tag_indices(t); + + auto result = DefaultTensorCanonicalizer::apply(t, TensorBlockIndexComparer{}, + TensorBlockIndexComparer{}); + + reset_tags(t); + + return result; +} + +} // namespace sequant diff --git a/SeQuant/core/tensor_canonicalizer.hpp b/SeQuant/core/tensor_canonicalizer.hpp new file mode 100644 index 000000000..a8214476b --- /dev/null +++ b/SeQuant/core/tensor_canonicalizer.hpp @@ -0,0 +1,247 @@ +// +// Created by Robert Adam on 2023-09-08 +// + +#ifndef SEQUANT_CORE_TENSOR_CANONICALIZER_HPP +#define SEQUANT_CORE_TENSOR_CANONICALIZER_HPP + +#include "abstract_tensor.hpp" +#include "expr.hpp" + +#include +#include + +namespace sequant { + +/// @brief Base class for Tensor canonicalizers +/// To make custom canonicalizer make a derived class and register an instance +/// of that class with TensorCanonicalizer::register_instance +class TensorCanonicalizer { + public: + using index_comparer_t = std::function; + using index_pair_t = std::pair; + using index_pair_comparer_t = + std::function; + + virtual ~TensorCanonicalizer(); + + /// @return ptr to the TensorCanonicalizer object, if any, that had been + /// previously registered via TensorCanonicalizer::register_instance() + /// with @c label , or to the default canonicalizer, if any + static std::shared_ptr instance_ptr( + std::wstring_view label = L""); + + /// @return ptr to the TensorCanonicalizer object, if any, that had been + /// previously registered via TensorCanonicalizer::register_instance() + /// with @c label + /// @sa instance_ptr + static std::shared_ptr nondefault_instance_ptr( + std::wstring_view label); + + /// @return a TensorCanonicalizer previously registered via + /// TensorCanonicalizer::register_instance() with @c label or to the default + /// canonicalizer + /// @throw std::runtime_error if no canonicalizer has been registered + static std::shared_ptr instance( + std::wstring_view label = L""); + + /// registers @c canonicalizer to be applied to Tensor objects with label + /// @c label ; leave the label empty if @c canonicalizer is to apply to Tensor + /// objects with any label + /// @note if a canonicalizer registered with label @c label exists, it is + /// replaced + static void register_instance( + std::shared_ptr canonicalizer, + std::wstring_view label = L""); + + /// tries to register @c canonicalizer to be applied to Tensor objects + /// with label @c label ; leave the label empty if @c canonicalizer is to + /// apply to Tensor objects with any label + /// @return false if there is already a canonicalizer registered with @c label + /// @sa regiter_instance + static bool try_register_instance( + std::shared_ptr canonicalizer, + std::wstring_view label = L""); + + /// deregisters canonicalizer (if any) registered previously + /// to be applied to tensors with label @c label + static void deregister_instance(std::wstring_view label = L""); + + /// @return a list of Tensor labels with lexicographic preference (in order) + static const auto& cardinal_tensor_labels() { + return cardinal_tensor_labels_accessor(); + } + + /// @param cardinal_tensor_labels a list of Tensor labels with lexicographic + /// preference (in order) + static void set_cardinal_tensor_labels( + const container::vector& labels) { + cardinal_tensor_labels_accessor() = labels; + } + + /// @return a side effect of canonicalization (e.g. phase), or nullptr if none + /// @internal what should be returned if canonicalization requires + /// complex conjugation? Special ExprPtr type (e.g. ConjOp)? Or the actual + /// return of the canonicalization? + /// @note canonicalization compared indices returned by index_comparer + // TODO generalize for complex tensors + virtual ExprPtr apply(AbstractTensor&) const = 0; + + /// @return reference to the object used to compare Index objects + static const index_comparer_t& index_comparer(); + + /// @param comparer the compare object to be used by this + static void index_comparer(index_comparer_t comparer); + + /// @return reference to the object used to compare Index objects + static const index_pair_comparer_t& index_pair_comparer(); + + /// @param comparer the compare object to be used by this + static void index_pair_comparer(index_pair_comparer_t comparer); + + protected: + inline auto bra_range(AbstractTensor& t) const { return t._bra_mutable(); } + inline auto ket_range(AbstractTensor& t) const { return t._ket_mutable(); } + inline auto aux_range(AbstractTensor& t) const { return t._aux_mutable(); } + + /// the object used to compare indices + static index_comparer_t index_comparer_; + /// the object used to compare pairs of indices + static index_pair_comparer_t index_pair_comparer_; + + private: + static std::pair< + container::map>*, + std::unique_lock> + instance_map_accessor(); // map* + locked recursive mutex + static container::vector& cardinal_tensor_labels_accessor(); +}; + +/// @brief null Tensor canonicalizer does nothing +class NullTensorCanonicalizer : public TensorCanonicalizer { + public: + virtual ~NullTensorCanonicalizer() = default; + + ExprPtr apply(AbstractTensor&) const override; +}; + +class DefaultTensorCanonicalizer : public TensorCanonicalizer { + public: + DefaultTensorCanonicalizer() = default; + + /// @tparam IndexContainer a Container of Index objects such that @c + /// IndexContainer::value_type is convertible to Index (e.g. this can be + /// std::vector or std::set , but not std::map) + /// @param external_indices container of external Index objects + /// @warning @c external_indices is assumed to be immutable during the + /// lifetime of this object + template + DefaultTensorCanonicalizer(IndexContainer&& external_indices) { + ranges::for_each(external_indices, [this](const Index& idx) { + this->external_indices_.emplace(idx.label(), idx); + }); + } + virtual ~DefaultTensorCanonicalizer() = default; + + /// Implements TensorCanonicalizer::apply + /// @note Canonicalizes @c t by sorting its bra (if @c + /// t.symmetry()==Symmetry::nonsymm ) or its bra and ket (if @c + /// t.symmetry()!=Symmetry::nonsymm ), + /// with the external indices appearing "before" (smaller particle + /// indices) than the internal indices + ExprPtr apply(AbstractTensor& t) const override; + + /// Core of DefaultTensorCanonicalizer::apply, only does the canonicalization, + /// i.e. no tagging/untagging + template + ExprPtr apply(AbstractTensor& t, const IndexComp& idxcmp, + const IndexPairComp& paircmp) const { + // std::wcout << "abstract tensor: " << to_latex(t) << "\n"; + auto s = symmetry(t); + auto is_antisymm = (s == Symmetry::antisymm); + const auto _bra_rank = bra_rank(t); + const auto _ket_rank = ket_rank(t); + const auto _aux_rank = aux_rank(t); + const auto _rank = std::min(_bra_rank, _ket_rank); + + // nothing to do for rank-1 tensors + if (_bra_rank == 1 && _ket_rank == 1 && _aux_rank == 0) return nullptr; + + using ranges::begin; + using ranges::end; + using ranges::views::counted; + using ranges::views::take; + using ranges::views::zip; + + bool even = true; + switch (s) { + case Symmetry::antisymm: + case Symmetry::symm: { + auto _bra = bra_range(t); + auto _ket = ket_range(t); + // std::wcout << "canonicalizing " << to_latex(t); + IndexSwapper::thread_instance().reset(); + // std::{stable_}sort does not necessarily use swap! so must implement + // sort outselves .. thankfully ranks will be low so can stick with + // bubble + bubble_sort(begin(_bra), end(_bra), idxcmp); + bubble_sort(begin(_ket), end(_ket), idxcmp); + if (is_antisymm) + even = IndexSwapper::thread_instance().even_num_of_swaps(); + // std::wcout << " is " << (even ? "even" : "odd") << " and + // produces " << to_latex(t) << std::endl; + } break; + + case Symmetry::nonsymm: { + // sort particles with bra and ket functions first, + // then the particles with either bra or ket index + auto _bra = bra_range(t); + auto _ket = ket_range(t); + auto _zip_braket = zip(take(_bra, _rank), take(_ket, _rank)); + bubble_sort(begin(_zip_braket), end(_zip_braket), paircmp); + if (_bra_rank > _rank) { + auto size_of_rest = _bra_rank - _rank; + auto rest_of = counted(begin(_bra) + _rank, size_of_rest); + bubble_sort(begin(rest_of), end(rest_of), idxcmp); + } else if (_ket_rank > _rank) { + auto size_of_rest = _ket_rank - _rank; + auto rest_of = counted(begin(_ket) + _rank, size_of_rest); + bubble_sort(begin(rest_of), end(rest_of), idxcmp); + } + } break; + + default: + abort(); + } + + // TODO: Handle auxiliary index symmetries once they are introduced + // auto _aux = aux_range(t); + // ranges::sort(_aux, comp); + + ExprPtr result = + is_antisymm ? (even == false ? ex(-1) : nullptr) : nullptr; + return result; + } + + private: + container::map external_indices_; + + protected: + void tag_indices(AbstractTensor& t) const; +}; + +class TensorBlockCanonicalizer : public DefaultTensorCanonicalizer { + public: + TensorBlockCanonicalizer() = default; + ~TensorBlockCanonicalizer() = default; + + template + TensorBlockCanonicalizer(const IndexContainer& external_indices) + : DefaultTensorCanonicalizer(external_indices) {} + + ExprPtr apply(AbstractTensor& t) const override; +}; + +} // namespace sequant + +#endif // SEQUANT_CORE_TENSOR_CANONICALIZER_HPP diff --git a/SeQuant/core/tensor_network.cpp b/SeQuant/core/tensor_network.cpp index f7529969a..d30cf6543 100644 --- a/SeQuant/core/tensor_network.cpp +++ b/SeQuant/core/tensor_network.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -35,7 +36,7 @@ namespace sequant { ExprPtr TensorNetwork::canonicalize( const container::vector &cardinal_tensor_labels, bool fast, const named_indices_t *named_indices_ptr) { - ExprPtr canon_biproduct = ex(1); + ExprPtr canon_byproduct = ex(1); container::svector idx_terminals_sorted; // to avoid memory allocs if (Logger::instance().canonicalize) { @@ -399,15 +400,15 @@ ExprPtr TensorNetwork::canonicalize( nondefault_canonizer_ptr ? nondefault_canonizer_ptr.get() : &default_tensor_canonizer; auto bp = tensor_canonizer->apply(*tensor); - if (bp) *canon_biproduct *= *bp; + if (bp) *canon_byproduct *= *bp; } } edges_.clear(); ext_indices_.clear(); - assert(canon_biproduct->is()); - return (canon_biproduct->as().value() == 1) ? nullptr - : canon_biproduct; + assert(canon_byproduct->is()); + return (canon_byproduct->as().value() == 1) ? nullptr + : canon_byproduct; } std::tuple, std::vector, @@ -425,6 +426,8 @@ TensorNetwork::make_bliss_graph( const auto &named_indices = named_indices_ptr == nullptr ? this->ext_indices() : *named_indices_ptr; + VertexPainter colorizer(named_indices); + // results std::shared_ptr graph; std::vector vertex_labels( @@ -434,23 +437,6 @@ TensorNetwork::make_bliss_graph( std::vector vertex_type( edges_.size()); // the size will be updated - // N.B. Colors [0, 2 max rank + named_indices.size()) are reserved: - // 0 - the bra vertex (for particle 0, if bra is nonsymm, or for the entire - // bra, if (anti)symm) 1 - the bra vertex for particle 1, if bra is nonsymm - // ... - // max_rank - the ket vertex (for particle 0, if particle-asymmetric, or for - // the entire ket, if particle-symmetric) max_rank+1 - the ket vertex for - // particle 1, if particle-asymmetric - // ... - // 2 max_rank - first named index - // 2 max_rank + 1 - second named index - // ... - // N.B. For braket-symmetric tensors the ket vertices use the same indices as - // the bra vertices - auto nonreserved_color = [&named_indices](size_t color) -> bool { - return color >= 2 * max_rank + named_indices.size(); - }; - // compute # of vertices size_t nv = 0; size_t index_cnt = 0; @@ -468,17 +454,8 @@ TensorNetwork::make_bliss_graph( ++nv; // each index is a vertex vertex_labels.at(index_cnt) = idx.to_latex(); vertex_type.at(index_cnt) = VertexType::Index; - // assign color: named indices use reserved colors - const auto named_index_it = named_indices.find(idx); - if (named_index_it == - named_indices.end()) { // anonymous index? use Index::color - const auto idx_color = idx.color(); - assert(nonreserved_color(idx_color)); - vertex_color.at(index_cnt) = idx_color; - } else { - const auto named_index_rank = named_index_it - named_indices.begin(); - vertex_color.at(index_cnt) = 2 * max_rank + named_index_rank; - } + vertex_color.at(index_cnt) = colorizer(idx); + // each symmetric proto index bundle will have a vertex ... // for now only store the unique protoindex bundles in // symmetric_protoindex_bundles, then commit their data to @@ -505,11 +482,10 @@ TensorNetwork::make_bliss_graph( spbundle_label += L"}"; vertex_labels.push_back(spbundle_label); vertex_type.push_back(VertexType::SPBundle); - const auto idx_proto_indices_color = Index::proto_indices_color(bundle); - assert(nonreserved_color(idx_proto_indices_color)); - vertex_color.push_back(idx_proto_indices_color); + vertex_color.push_back(colorizer(bundle)); spbundle_cnt++; }); + // now account for vertex representation of tensors size_t tensor_cnt = 0; // this will map to tensor index to the first (core) vertex in its @@ -522,10 +498,8 @@ TensorNetwork::make_bliss_graph( const auto tlabel = label(*t); vertex_labels.emplace_back(tlabel); vertex_type.emplace_back(VertexType::TensorCore); - const auto t_color = hash::value(tlabel); - static_assert(sizeof(t_color) == sizeof(unsigned long int)); - assert(nonreserved_color(t_color)); - vertex_color.push_back(t_color); + vertex_color.push_back(colorizer(*t)); + // symmetric/antisymmetric tensors are represented by 3 more vertices: // - bra // - ket @@ -537,18 +511,24 @@ TensorNetwork::make_bliss_graph( std::wstring(L"bra") + to_wstring(bra_rank(tref)) + ((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s")); vertex_type.push_back(VertexType::TensorBra); - vertex_color.push_back(0); + vertex_color.push_back(colorizer(BraGroup{0})); vertex_labels.push_back( std::wstring(L"ket") + to_wstring(ket_rank(tref)) + ((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s")); vertex_type.push_back(VertexType::TensorKet); - vertex_color.push_back( - braket_symmetry(tref) == BraKetSymmetry::symm ? 0 : max_rank); + if (braket_symmetry(tref) == BraKetSymmetry::symm) { + // Use BraGroup for kets as well as they are supposed to be + // indistinguishable + vertex_color.push_back(colorizer(BraGroup{0})); + } else { + vertex_color.push_back(colorizer(KetGroup{0})); + } vertex_labels.push_back( std::wstring(L"bk") + ((symmetry(tref) == Symmetry::antisymm) ? L"a" : L"s")); - vertex_type.push_back(VertexType::TensorBraKet); - vertex_color.push_back(t_color); + vertex_type.push_back(VertexType::Particle); + // Color bk node in same color as tensor core + vertex_color.push_back(colorizer(tref)); } // nonsymmetric tensors are represented by 3*rank more vertices (with rank = // max(bra_rank(),ket_rank()) @@ -560,20 +540,36 @@ TensorNetwork::make_bliss_graph( auto pstr = to_wstring(p + 1); vertex_labels.push_back(std::wstring(L"bra") + pstr); vertex_type.push_back(VertexType::TensorBra); - const bool t_is_particle_symmetric = + const bool distinguishable_particles = particle_symmetry(tref) == ParticleSymmetry::nonsymm; - const auto bra_color = t_is_particle_symmetric ? p : 0; - vertex_color.push_back(bra_color); + vertex_color.push_back( + colorizer(BraGroup{distinguishable_particles ? p : 0})); vertex_labels.push_back(std::wstring(L"ket") + pstr); vertex_type.push_back(VertexType::TensorKet); - vertex_color.push_back(braket_symmetry(tref) == BraKetSymmetry::symm - ? bra_color - : bra_color + max_rank); + if (braket_symmetry(tref) == BraKetSymmetry::symm) { + // Use BraGroup for kets as well as they are supposed to be + // indistinguishable + vertex_color.push_back( + colorizer(BraGroup{distinguishable_particles ? p : 0})); + } else { + vertex_color.push_back( + colorizer(KetGroup{distinguishable_particles ? p : 0})); + } vertex_labels.push_back(std::wstring(L"bk") + pstr); - vertex_type.push_back(VertexType::TensorBraKet); - vertex_color.push_back(t_color); + vertex_type.push_back(VertexType::Particle); + vertex_color.push_back(colorizer(tref)); } } + // aux indices currently do not support any symmetry + assert(aux_rank(tref) <= max_rank); + for (size_t p = 0; p != aux_rank(tref); ++p) { + nv += 1; + auto pstr = to_wstring(p + 1); + vertex_labels.push_back(std::wstring(L"aux") + pstr); + vertex_type.push_back(VertexType::TensorAux); + vertex_color.push_back(colorizer(AuxGroup{p})); + } + ++tensor_cnt; }); @@ -646,21 +642,8 @@ TensorNetwork::make_bliss_graph( ++tensor_cnt; }); - // compress vertex colors to 32 bits, as required by Bliss, by hashing - size_t v_cnt = 0; - for (auto &&color : vertex_color) { - auto hash6432shift = [](size_t key) { - static_assert(sizeof(key) == 8); - key = (~key) + (key << 18); // key = (key << 18) - key - 1; - key = key ^ (key >> 31); - key = key * 21; // key = (key + (key << 2)) + (key << 4); - key = key ^ (key >> 11); - key = key + (key << 6); - key = key ^ (key >> 22); - return static_cast(key); - }; - graph->change_color(v_cnt, hash6432shift(color)); - ++v_cnt; + for (const auto [vertex, color] : ranges::views::enumerate(vertex_color)) { + graph->change_color(vertex, color); } return {graph, vertex_labels, vertex_color, vertex_type}; diff --git a/SeQuant/core/tensor_network.hpp b/SeQuant/core/tensor_network.hpp index a9043d150..cdc55afd5 100644 --- a/SeQuant/core/tensor_network.hpp +++ b/SeQuant/core/tensor_network.hpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include #include @@ -132,14 +134,7 @@ class TensorNetwork { int second_position_ = 0; }; - enum class VertexType { - Index, - SPBundle, - TensorBra, - TensorKet, - TensorBraKet, - TensorCore - }; + using VertexType = sequant::VertexType; public: /// @throw std::logic_error if exprptr_range contains a non-tensor @@ -172,7 +167,7 @@ class TensorNetwork { /// @param named_indices specifies the indices that cannot be renamed, i.e. /// their labels are meaningful; default is nullptr, which results in external /// indices treated as named indices - /// @return biproduct of canonicalization (e.g. phase); if none, returns + /// @return byproduct of canonicalization (e.g. phase); if none, returns /// nullptr ExprPtr canonicalize( const container::vector &cardinal_tensor_labels = {}, diff --git a/SeQuant/core/tensor_network_v2.cpp b/SeQuant/core/tensor_network_v2.cpp new file mode 100644 index 000000000..047f78cf0 --- /dev/null +++ b/SeQuant/core/tensor_network_v2.cpp @@ -0,0 +1,1016 @@ +// +// Created by Eduard Valeyev on 2019-02-26. +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace sequant { + +struct FullLabelIndexLocator { + std::wstring_view label; + FullLabelIndexLocator(std::wstring_view label) : label(std::move(label)) {} + + bool operator()(const TensorNetworkV2::Edge &edge) const { + return edge.idx().full_label() == label; + } + + bool operator()(const Index &idx) const { return idx.full_label() == label; } +}; + +bool tensors_commute(const AbstractTensor &lhs, const AbstractTensor &rhs) { + // tensors commute if their colors are different or either one of them + // is a c-number + return !(color(lhs) == color(rhs) && !is_cnumber(lhs) && !is_cnumber(rhs)); +} + +struct TensorBlockCompare { + bool operator()(const AbstractTensor &lhs, const AbstractTensor &rhs) const { + if (label(lhs) != label(rhs)) { + return label(lhs) < label(rhs); + } + + if (bra_rank(lhs) != bra_rank(rhs)) { + return bra_rank(lhs) < bra_rank(rhs); + } + if (ket_rank(lhs) != ket_rank(rhs)) { + return ket_rank(lhs) < ket_rank(rhs); + } + if (aux_rank(lhs) != aux_rank(rhs)) { + return aux_rank(lhs) < aux_rank(rhs); + } + + auto lhs_indices = indices(lhs); + auto rhs_indices = indices(rhs); + + for (auto lhs_it = lhs_indices.begin(), rhs_it = rhs_indices.begin(); + lhs_it != lhs_indices.end() && rhs_it != rhs_indices.end(); + ++lhs_it, ++rhs_it) { + if (lhs_it->space() != rhs_it->space()) { + return lhs_it->space() < rhs_it->space(); + } + } + + // Tensors are identical + return false; + } +}; + +/// Compares tensors based on their label and orders them according to the order +/// of the given cardinal tensor labels. If two tensors can't be discriminated +/// via their label, they are compared based on regular +/// AbstractTensor::operator< or based on their tensor block (the spaces of +/// their indices) - depending on the configuration. If this doesn't +/// discriminate the tensors, they are considered equal +template +struct CanonicalTensorCompare { + const CardinalLabels &labels; + bool blocks_only; + + CanonicalTensorCompare(const CardinalLabels &labels, bool blocks_only) + : labels(labels), blocks_only(blocks_only) {} + + void set_blocks_only(bool blocks_only) { this->blocks_only = blocks_only; } + + bool operator()(const AbstractTensorPtr &lhs_ptr, + const AbstractTensorPtr &rhs_ptr) const { + assert(lhs_ptr); + assert(rhs_ptr); + const AbstractTensor &lhs = *lhs_ptr; + const AbstractTensor &rhs = *rhs_ptr; + + if (!tensors_commute(lhs, rhs)) { + return false; + } + + const auto get_label = [](const auto &t) { + if (label(t).back() == adjoint_label) { + // grab base label if adjoint label is present + return label(t).substr(0, label(t).size() - 1); + } + return label(t); + }; + + const auto lhs_it = std::find(labels.begin(), labels.end(), get_label(lhs)); + const auto rhs_it = std::find(labels.begin(), labels.end(), get_label(rhs)); + + if (lhs_it != rhs_it) { + // At least one of the tensors is a cardinal one + // -> Order by the occurrence in the cardinal label list + return std::distance(labels.begin(), lhs_it) < + std::distance(labels.begin(), rhs_it); + } + + // Either both are the same cardinal tensor or none is a cardinal tensor + if (blocks_only) { + TensorBlockCompare cmp; + return cmp(lhs, rhs); + } else { + return lhs < rhs; + } + } +}; + +TensorNetworkV2::Vertex::Vertex(Origin origin, std::size_t terminal_idx, + std::size_t index_slot, Symmetry terminal_symm) + : origin(origin), + terminal_idx(terminal_idx), + index_slot(index_slot), + terminal_symm(terminal_symm) {} + +TensorNetworkV2::Origin TensorNetworkV2::Vertex::getOrigin() const { + return origin; +} + +std::size_t TensorNetworkV2::Vertex::getTerminalIndex() const { + return terminal_idx; +} + +std::size_t TensorNetworkV2::Vertex::getIndexSlot() const { return index_slot; } + +Symmetry TensorNetworkV2::Vertex::getTerminalSymmetry() const { + return terminal_symm; +} + +bool TensorNetworkV2::Vertex::operator<(const Vertex &rhs) const { + if (terminal_idx != rhs.terminal_idx) { + return terminal_idx < rhs.terminal_idx; + } + + // Both vertices belong to same tensor -> they must have same symmetry + assert(terminal_symm == rhs.terminal_symm); + + if (origin != rhs.origin) { + return origin < rhs.origin; + } + + // We only take the index slot into account for non-symmetric tensors + if (terminal_symm == Symmetry::nonsymm) { + return index_slot < rhs.index_slot; + } else { + return false; + } +} + +bool TensorNetworkV2::Vertex::operator==(const Vertex &rhs) const { + // Slot position is only taken into account for non_symmetric tensors + const std::size_t lhs_slot = + (terminal_symm == Symmetry::nonsymm) * index_slot; + const std::size_t rhs_slot = + (rhs.terminal_symm == Symmetry::nonsymm) * rhs.index_slot; + + assert(terminal_idx != rhs.terminal_idx || + terminal_symm == rhs.terminal_symm); + + return terminal_idx == rhs.terminal_idx && lhs_slot == rhs_slot && + origin == rhs.origin; +} + +std::size_t TensorNetworkV2::Graph::vertex_to_index_idx( + std::size_t vertex) const { + assert(vertex_types.at(vertex) == VertexType::Index); + + std::size_t index_idx = 0; + for (std::size_t i = 0; i <= vertex; ++i) { + if (vertex_types[i] == VertexType::Index) { + ++index_idx; + } + } + + assert(index_idx > 0); + + return index_idx - 1; +} + +std::size_t TensorNetworkV2::Graph::vertex_to_tensor_idx( + std::size_t vertex) const { + assert(vertex_types.at(vertex) == VertexType::TensorCore); + + std::size_t tensor_idx = 0; + for (std::size_t i = 0; i <= vertex; ++i) { + if (vertex_types[i] == VertexType::TensorCore) { + ++tensor_idx; + } + } + + assert(tensor_idx > 0); + + return tensor_idx - 1; +} + +template +auto permute(const ArrayLike &vector, const Permutation &perm) { + using std::size; + auto sz = size(vector); + std::decay_t pvector(sz); + for (size_t i = 0; i != sz; ++i) pvector[perm[i]] = vector[i]; + return pvector; +} + +template +void apply_index_replacements(AbstractTensor &tensor, + const ReplacementMap &replacements) { +#ifndef NDEBUG + // assert that tensors' indices are not tagged since going to tag indices + assert(ranges::none_of( + indices(tensor), [](const Index &idx) { return idx.tag().has_value(); })); +#endif + + bool pass_mutated; + do { + pass_mutated = transform_indices(tensor, replacements); + } while (pass_mutated); // transform till stops changing + + reset_tags(tensor); +} + +template +void apply_index_replacements(ArrayLike &tensors, + const ReplacementMap &replacements) { + for (auto &tensor : tensors) { + apply_index_replacements(*tensor, replacements); + } +} + +template +void order_to_indices(Container &container) { + std::vector indices; + indices.resize(container.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), + [&container](std::size_t lhs, std::size_t rhs) { + return container[lhs] < container[rhs]; + }); + // Overwrite container contents with indices + std::copy(indices.begin(), indices.end(), container.begin()); +} + +template +void sort_via_indices(Container &container, const Comparator &cmp) { + std::vector indices; + indices.resize(container.size()); + std::iota(indices.begin(), indices.end(), 0); + + if constexpr (stable) { + std::stable_sort(indices.begin(), indices.end(), cmp); + } else { + std::sort(indices.begin(), indices.end(), cmp); + } + + // Bring elements in container into the order given by indices + // (the association is container[k] = container[indices[k]]) + // -> implementation from https://stackoverflow.com/a/838789 + + for (std::size_t i = 0; i < container.size(); ++i) { + if (indices[i] == i) { + // This element is already where it is supposed to be + continue; + } + + // Find the offset of the index pointing to i + // -> since we are going to change the content of the vector at position i, + // we have to update the index-mapping referencing i to point to the new + // location of the element that used to be at position i + std::size_t k; + for (k = i + 1; k < container.size(); ++k) { + if (indices[k] == i) { + break; + } + } + std::swap(container[i], container[indices[i]]); + std::swap(indices[i], indices[k]); + } +} + +void TensorNetworkV2::canonicalize_graph(const NamedIndexSet &named_indices) { + if (Logger::instance().canonicalize) { + std::wcout << "TensorNetworkV2::canonicalize_graph: input tensors\n"; + size_t cnt = 0; + ranges::for_each(tensors_, [&](const auto &t) { + std::wcout << "tensor " << cnt++ << ": " << to_latex(*t) << std::endl; + }); + std::wcout << std::endl; + } + + if (!have_edges_) { + init_edges(); + } + + const auto is_anonymous_index = [named_indices](const Index &idx) { + return named_indices.find(idx) == named_indices.end(); + }; + + // index factory to generate anonymous indices + IndexFactory idxfac(is_anonymous_index, 1); + + // make the graph + Graph graph = create_graph(&named_indices); + // graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + + if (Logger::instance().canonicalize_input_graph) { + std::wcout << "Input graph for canonicalization:\n"; + graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + } + + // canonize the graph + bliss::Stats stats; + graph.bliss_graph->set_splitting_heuristic(bliss::Graph::shs_fsm); + const unsigned int *canonize_perm = + graph.bliss_graph->canonical_form(stats, nullptr, nullptr); + + if (Logger::instance().canonicalize_dot) { + std::wcout << "Canonicalization permutation:\n"; + for (std::size_t i = 0; i < graph.vertex_labels.size(); ++i) { + std::wcout << i << " -> " << canonize_perm[i] << "\n"; + } + std::wcout << "Canonicalized graph:\n"; + bliss::Graph *cgraph = graph.bliss_graph->permute(canonize_perm); + cgraph->write_dot(std::wcout, {}, true); + auto cvlabels = permute(graph.vertex_labels, canonize_perm); + std::wcout << "with our labels:\n"; + cgraph->write_dot(std::wcout, cvlabels); + delete cgraph; + } + + container::map tensor_idx_to_vertex; + container::map> + tensor_idx_to_particle_order; + container::map index_idx_to_vertex; + std::size_t tensor_idx = 0; + std::size_t index_idx = 0; + + for (std::size_t vertex = 0; vertex < graph.vertex_types.size(); ++vertex) { + switch (graph.vertex_types[vertex]) { + case VertexType::Index: + index_idx_to_vertex[index_idx] = vertex; + index_idx++; + break; + case VertexType::Particle: { + assert(tensor_idx > 0); + const std::size_t base_tensor_idx = tensor_idx - 1; + assert(symmetry(*tensors_.at(base_tensor_idx)) == Symmetry::nonsymm); + tensor_idx_to_particle_order[base_tensor_idx].push_back( + canonize_perm[vertex]); + break; + } + case VertexType::TensorCore: + tensor_idx_to_vertex[tensor_idx] = vertex; + tensor_idx++; + break; + case VertexType::TensorBra: + case VertexType::TensorKet: + case VertexType::TensorAux: + case VertexType::SPBundle: + break; + } + } + + assert(index_idx_to_vertex.size() == edges_.size()); + assert(tensor_idx_to_vertex.size() == tensors_.size()); + assert(tensor_idx_to_particle_order.size() <= tensors_.size()); + + // order_to_indices(index_order); + for (auto ¤t : tensor_idx_to_particle_order) { + order_to_indices(current.second); + } + + container::map idxrepl; + // Sort edges so that their order corresponds to the order of indices in the + // canonical graph + // Use this ordering to relabel anonymous indices + const auto index_sorter = [&index_idx_to_vertex, &canonize_perm]( + std::size_t lhs_idx, std::size_t rhs_idx) { + const std::size_t lhs_vertex = index_idx_to_vertex.at(lhs_idx); + const std::size_t rhs_vertex = index_idx_to_vertex.at(rhs_idx); + + return canonize_perm[lhs_vertex] < canonize_perm[rhs_vertex]; + }; + + sort_via_indices(edges_, index_sorter); + + for (const Edge ¤t : edges_) { + const Index &idx = current.idx(); + + if (!is_anonymous_index(idx)) { + continue; + } + + idxrepl.insert(std::make_pair(idx, idxfac.make(idx))); + } + + if (Logger::instance().canonicalize) { + for (const auto &idxpair : idxrepl) { + std::wcout << "TensorNetworkV2::canonicalize_graph: replacing " + << to_latex(idxpair.first) << " with " + << to_latex(idxpair.second) << std::endl; + } + } + + apply_index_replacements(tensors_, idxrepl); + + // Perform particle-1,2-swaps as indicated by the graph canonization + for (std::size_t i = 0; i < tensors_.size(); ++i) { + AbstractTensor &tensor = *tensors_[i]; + const std::size_t num_particles = + std::min(bra_rank(tensor), ket_rank(tensor)); + + auto it = tensor_idx_to_particle_order.find(i); + if (it == tensor_idx_to_particle_order.end()) { + assert(num_particles == 0 || symmetry(*tensors_[i]) != Symmetry::nonsymm); + continue; + } + + const auto &particle_order = it->second; + auto bra_indices = tensor._bra(); + auto ket_indices = tensor._ket(); + + assert(num_particles == particle_order.size()); + + // Swap indices column-wise + idxrepl.clear(); + for (std::size_t col = 0; col < num_particles; ++col) { + if (particle_order[col] == col) { + continue; + } + + idxrepl.insert( + std::make_pair(bra_indices[col], bra_indices[particle_order[col]])); + idxrepl.insert( + std::make_pair(ket_indices[col], ket_indices[particle_order[col]])); + } + + if (!idxrepl.empty()) { + if (Logger::instance().canonicalize) { + for (const auto &idxpair : idxrepl) { + std::wcout + << "TensorNetworkV2::canonicalize_graph: permuting particles in " + << to_latex(tensor) << " by replacing " << to_latex(idxpair.first) + << " with " << to_latex(idxpair.second) << std::endl; + } + } + apply_index_replacements(tensor, idxrepl); + } + } + + // Bring tensors into canonical order (analogously to how we reordered + // indices), but ensure to respect commutativity! + const auto tensor_sorter = [this, &canonize_perm, &tensor_idx_to_vertex]( + std::size_t lhs_idx, std::size_t rhs_idx) { + const AbstractTensor &lhs = *tensors_[lhs_idx]; + const AbstractTensor &rhs = *tensors_[rhs_idx]; + + if (!tensors_commute(lhs, rhs)) { + return false; + } + + const std::size_t lhs_vertex = tensor_idx_to_vertex.at(lhs_idx); + const std::size_t rhs_vertex = tensor_idx_to_vertex.at(rhs_idx); + + // Commuting tensors are sorted based on their canonical order which is + // given by the order of the corresponding vertices in the canonical graph + // representation + return canonize_perm[lhs_vertex] < canonize_perm[rhs_vertex]; + }; + + sort_via_indices(tensors_, tensor_sorter); + + // The tensor reordering and index relabelling made the current set of edges + // invalid + edges_.clear(); + have_edges_ = false; + + if (Logger::instance().canonicalize) { + std::wcout << "TensorNetworkV2::canonicalize_graph: tensors after " + "canonicalization\n"; + size_t cnt = 0; + ranges::for_each(tensors_, [&](const auto &t) { + std::wcout << "tensor " << cnt++ << ": " << to_latex(*t) << std::endl; + }); + } +} + +ExprPtr TensorNetworkV2::canonicalize( + const container::vector &cardinal_tensor_labels, bool fast, + const NamedIndexSet *named_indices_ptr) { + if (Logger::instance().canonicalize) { + std::wcout << "TensorNetworkV2::canonicalize(" << (fast ? "fast" : "slow") + << "): input tensors\n"; + size_t cnt = 0; + ranges::for_each(tensors_, [&](const auto &t) { + std::wcout << "tensor " << cnt++ << ": " << to_latex(*t) << std::endl; + }); + std::wcout << "cardinal_tensor_labels = "; + ranges::for_each(cardinal_tensor_labels, + [](auto &&i) { std::wcout << i << L" "; }); + std::wcout << std::endl; + } + + if (!have_edges_) { + init_edges(); + } + + // initialize named_indices by default to all external indices + const auto &named_indices = + named_indices_ptr == nullptr ? this->ext_indices() : *named_indices_ptr; + + if (!fast) { + // The graph-based canonization is required in call cases in which there are + // indistinguishable tensors present in the expression. Their order and + // indexing can only be determined via this rigorous canonization. + canonicalize_graph(named_indices); + } + + // Ensure each individual tensor is written in the way that its tensor + // block (== order of index spaces) is canonical + ExprPtr byproduct = canonicalize_individual_tensor_blocks(named_indices); + + CanonicalTensorCompare tensor_sorter( + cardinal_tensor_labels, true); + + std::stable_sort(tensors_.begin(), tensors_.end(), tensor_sorter); + + init_edges(); + + if (Logger::instance().canonicalize) { + std::wcout << "TensorNetworkV2::canonicalize(" << (fast ? "fast" : "slow") + << "): tensors after initial sort\n"; + size_t cnt = 0; + ranges::for_each(tensors_, [&](const auto &t) { + std::wcout << "tensor " << cnt++ << ": " << to_latex(*t) << std::endl; + }); + } + + // helpers to filter named ("external" in traditional use case) / anonymous + // ("internal" in traditional use case) + auto is_named_index = [&](const Index &idx) { + return named_indices.find(idx) != named_indices.end(); + }; + auto is_anonymous_index = [&](const Index &idx) { + return named_indices.find(idx) == named_indices.end(); + }; + + // Sort edges based on the order of the tensors they connect + std::stable_sort(edges_.begin(), edges_.end(), + [&is_named_index](const Edge &lhs, const Edge &rhs) { + // Sort first by index's character (named < anonymous), + // then by Edge (not by Index's full label) ... this + // automatically puts named indices first + const bool lhs_is_named = is_named_index(lhs.idx()); + const bool rhs_is_named = is_named_index(rhs.idx()); + + if (lhs_is_named == rhs_is_named) { + return lhs < rhs; + } else { + return lhs_is_named; + } + }); + + // index factory to generate anonymous indices + // -> start reindexing anonymous indices from 1 + IndexFactory idxfac(is_anonymous_index, 1); + + container::map idxrepl; + + // Use the new order of edges as the canonical order of indices and relabel + // accordingly (but only anonymous indices, of course) + for (std::size_t i = named_indices.size(); i < edges_.size(); ++i) { + const Index &index = edges_[i].idx(); + assert(is_anonymous_index(index)); + Index replacement = idxfac.make(index); + idxrepl.emplace(std::make_pair(index, replacement)); + } + + // Done computing canonical index replacement list + + if (Logger::instance().canonicalize) { + for (const auto &idxpair : idxrepl) { + std::wcout << "TensorNetworkV2::canonicalize(" << (fast ? "fast" : "slow") + << "): replacing " << to_latex(idxpair.first) << " with " + << to_latex(idxpair.second) << std::endl; + } + } + + apply_index_replacements(tensors_, idxrepl); + + byproduct *= canonicalize_individual_tensors(named_indices); + + // We assume that re-indexing did not change the canonical order of tensors + assert(std::is_sorted(tensors_.begin(), tensors_.end(), tensor_sorter)); + // However, in order to produce the most aesthetically pleasing result, we now + // reorder tensors based on the regular AbstractTensor::operator<, which takes + // the explicit index labelling of tensors into account. + tensor_sorter.set_blocks_only(false); + std::stable_sort(tensors_.begin(), tensors_.end(), tensor_sorter); + + have_edges_ = false; + + assert(byproduct->is()); + return (byproduct->as().value() == 1) ? nullptr : byproduct; +} + +TensorNetworkV2::Graph TensorNetworkV2::create_graph( + const NamedIndexSet *named_indices_ptr) const { + assert(have_edges_); + + // initialize named_indices by default to all external indices + const NamedIndexSet &named_indices = + named_indices_ptr == nullptr ? this->ext_indices() : *named_indices_ptr; + + VertexPainter colorizer(named_indices); + + // core, bra, ket, auxiliary and optionally (for non-symmetric tensors) a + // particle vertex + constexpr std::size_t num_tensor_components = 5; + + // results + Graph graph; + // We know that at the very least all indices and all tensors will yield + // vertex representations + std::size_t vertex_count_estimate = + edges_.size() + num_tensor_components * tensors_.size(); + graph.vertex_labels.reserve(vertex_count_estimate); + graph.vertex_colors.reserve(vertex_count_estimate); + graph.vertex_types.reserve(vertex_count_estimate); + + container::map proto_bundles; + + container::map tensor_vertices; + tensor_vertices.reserve(tensors_.size()); + + container::vector> edges; + edges.reserve(edges_.size() + tensors_.size()); + + // Add vertices for tensors + for (std::size_t tensor_idx = 0; tensor_idx < tensors_.size(); ++tensor_idx) { + assert(tensor_vertices.find(tensor_idx) == tensor_vertices.end()); + assert(tensors_.at(tensor_idx)); + const AbstractTensor &tensor = *tensors_.at(tensor_idx); + + // Tensor core + graph.vertex_labels.emplace_back(label(tensor)); + graph.vertex_types.emplace_back(VertexType::TensorCore); + graph.vertex_colors.push_back(colorizer(tensor)); + + const std::size_t tensor_vertex = graph.vertex_labels.size() - 1; + tensor_vertices.insert(std::make_pair(tensor_idx, tensor_vertex)); + + // Create vertices to group indices + const Symmetry tensor_sym = symmetry(tensor); + if (tensor_sym == Symmetry::nonsymm) { + // Create separate vertices for every index + // Additionally, we need particle vertices to group indices that belong to + // the same particle (are in the same "column" in the usual tensor + // notation) + const std::size_t num_particle_vertices = + std::min(bra_rank(tensor), ket_rank(tensor)); + const bool is_part_symm = + particle_symmetry(tensor) == ParticleSymmetry::symm; + // TODO: How to handle BraKetSymmetry::conjugate? + const bool is_braket_symm = + braket_symmetry(tensor) == BraKetSymmetry::symm; + + for (std::size_t i = 0; i < num_particle_vertices; ++i) { + graph.vertex_labels.emplace_back(L"p_" + std::to_wstring(i + 1)); + graph.vertex_types.push_back(VertexType::Particle); + // Particles are indistinguishable -> always use same ID + graph.vertex_colors.push_back(colorizer(ParticleGroup{0})); + edges.push_back( + std::make_pair(tensor_vertex, graph.vertex_labels.size() - 1)); + } + + for (std::size_t i = 0; i < bra_rank(tensor); ++i) { + const bool is_unpaired_idx = i >= num_particle_vertices; + const bool color_idx = is_unpaired_idx || !is_part_symm; + + graph.vertex_labels.emplace_back(L"bra_" + std::to_wstring(i + 1)); + graph.vertex_types.push_back(VertexType::TensorBra); + graph.vertex_colors.push_back(colorizer(BraGroup{color_idx ? i : 0})); + + const std::size_t connect_vertex = + tensor_vertex + (is_unpaired_idx ? 0 : (i + 1)); + edges.push_back( + std::make_pair(connect_vertex, graph.vertex_labels.size() - 1)); + } + + for (std::size_t i = 0; i < ket_rank(tensor); ++i) { + const bool is_unpaired_idx = i >= num_particle_vertices; + const bool color_idx = is_unpaired_idx || !is_part_symm; + + graph.vertex_labels.emplace_back(L"ket_" + std::to_wstring(i + 1)); + graph.vertex_types.push_back(VertexType::TensorKet); + if (is_braket_symm) { + // Use BraGroup for kets as well as they are supposed to be + // indistinguishable + graph.vertex_colors.push_back(colorizer(BraGroup{color_idx ? i : 0})); + } else { + graph.vertex_colors.push_back(colorizer(KetGroup{color_idx ? i : 0})); + } + + const std::size_t connect_vertex = + tensor_vertex + (is_unpaired_idx ? 0 : (i + 1)); + edges.push_back( + std::make_pair(connect_vertex, graph.vertex_labels.size() - 1)); + } + } else { + // Shared set of bra/ket vertices for all indices + std::wstring suffix = tensor_sym == Symmetry::symm ? L"_s" : L"_a"; + + graph.vertex_labels.push_back(L"bra" + suffix); + graph.vertex_types.push_back(VertexType::TensorBra); + graph.vertex_colors.push_back(colorizer(BraGroup{0})); + edges.push_back( + std::make_pair(tensor_vertex, graph.vertex_labels.size() - 1)); + + graph.vertex_labels.push_back(L"ket" + suffix); + graph.vertex_types.push_back(VertexType::TensorKet); + // TODO: figure out how to handle BraKetSymmetry::conjugate + if (braket_symmetry(tensor) == BraKetSymmetry::symm) { + // Use BraGroup for kets as well as they should be indistinguishable + graph.vertex_colors.push_back(colorizer(BraGroup{0})); + } else { + graph.vertex_colors.push_back(colorizer(KetGroup{0})); + } + edges.push_back( + std::make_pair(tensor_vertex, graph.vertex_labels.size() - 1)); + } + + // TODO: handle aux indices permutation symmetries once they are supported + // for now, auxiliary indices are considered to always be asymmetric + for (std::size_t i = 0; i < aux_rank(tensor); ++i) { + graph.vertex_labels.emplace_back(L"aux_" + std::to_wstring(i + 1)); + graph.vertex_types.push_back(VertexType::TensorAux); + graph.vertex_colors.push_back(colorizer(AuxGroup{i})); + edges.push_back( + std::make_pair(tensor_vertex, graph.vertex_labels.size() - 1)); + } + } + + // Now add all indices (edges) to the graph + container::map index_vertices; + + for (const Edge ¤t_edge : edges_) { + const Index &index = current_edge.idx(); + graph.vertex_labels.push_back(std::wstring(index.full_label())); + graph.vertex_types.push_back(VertexType::Index); + graph.vertex_colors.push_back(colorizer(index)); + + const std::size_t index_vertex = graph.vertex_labels.size() - 1; + + index_vertices[index] = index_vertex; + + // Handle proto indices + if (index.has_proto_indices()) { + // For now we assume that all proto indices are symmetric + assert(index.symmetric_proto_indices()); + + std::size_t proto_vertex; + if (auto it = proto_bundles.find(index.proto_indices()); + it != proto_bundles.end()) { + proto_vertex = it->second; + } else { + // Create a new vertex for this bundle of proto indices + std::wstring spbundle_label = + L"{" + + (ranges::views::transform( + index.proto_indices(), + [](const Index &idx) { return idx.label(); }) | + ranges::views::join(L", ") | ranges::to()) + + L"}"; + + graph.vertex_labels.push_back(std::move(spbundle_label)); + graph.vertex_types.push_back(VertexType::SPBundle); + graph.vertex_colors.push_back(colorizer(index.proto_indices())); + + proto_vertex = graph.vertex_labels.size() - 1; + proto_bundles.insert( + std::make_pair(index.proto_indices(), proto_vertex)); + } + + edges.push_back(std::make_pair(index_vertex, proto_vertex)); + } + + // Connect index to the tensor(s) it is connected to + for (std::size_t i = 0; i < current_edge.vertex_count(); ++i) { + assert(i <= 1); + const Vertex &vertex = + i == 0 ? current_edge.first_vertex() : current_edge.second_vertex(); + + assert(tensor_vertices.find(vertex.getTerminalIndex()) != + tensor_vertices.end()); + const std::size_t tensor_vertex = + tensor_vertices.find(vertex.getTerminalIndex())->second; + + // Store an edge connecting the index vertex to the corresponding tensor + // vertex + const bool tensor_is_nonsymm = + vertex.getTerminalSymmetry() == Symmetry::nonsymm; + const AbstractTensor &tensor = *tensors_[vertex.getTerminalIndex()]; + std::size_t offset; + if (tensor_is_nonsymm) { + // We have to find the correct vertex to connect this index to (for + // non-symmetric tensors each index has its dedicated "group" vertex) + + // Move off the tensor core's vertex + offset = 1; + // Move past the explicit particle vertices + offset += std::min(bra_rank(tensor), ket_rank(tensor)); + + if (vertex.getOrigin() > Origin::Bra) { + offset += bra_rank(tensor); + } + + offset += vertex.getIndexSlot(); + } else { + static_assert(static_cast(Origin::Bra) == 1); + static_assert(static_cast(Origin::Ket) == 2); + static_assert(static_cast(Origin::Aux) == 3); + offset = static_cast(vertex.getOrigin()); + } + + if (vertex.getOrigin() > Origin::Ket) { + offset += ket_rank(tensor); + } + + const std::size_t tensor_component_vertex = tensor_vertex + offset; + + assert(tensor_component_vertex < graph.vertex_labels.size()); + edges.push_back(std::make_pair(index_vertex, tensor_component_vertex)); + } + } + + // Add edges between proto index bundle vertices and all vertices of the + // indices contained in that bundle i.e. if the bundle is {i_1,i_2}, the + // bundle would be connected with vertices for i_1 and i_2 + for (const auto &[bundle, vertex] : proto_bundles) { + for (const Index &idx : bundle) { + auto it = index_vertices.find(idx); + + assert(it != index_vertices.end()); + if (it == index_vertices.end()) { + std::abort(); + } + + edges.push_back(std::make_pair(it->second, vertex)); + } + } + + assert(graph.vertex_labels.size() == graph.vertex_colors.size()); + assert(graph.vertex_labels.size() == graph.vertex_types.size()); + + // Create the actual BLISS graph object + graph.bliss_graph = + std::make_unique(graph.vertex_labels.size()); + + for (const std::pair ¤t_edge : edges) { + graph.bliss_graph->add_edge(current_edge.first, current_edge.second); + } + + for (const auto [vertex, color] : + ranges::views::enumerate(graph.vertex_colors)) { + graph.bliss_graph->change_color(vertex, color); + } + + return graph; +} + +void TensorNetworkV2::init_edges() { + edges_.clear(); + ext_indices_.clear(); + + auto idx_insert = [this](const Index &idx, Vertex vertex) { + if (Logger::instance().tensor_network) { + std::wcout << "TensorNetworkV2::init_edges: idx=" << to_latex(idx) + << " attached to tensor " << vertex.getTerminalIndex() << " (" + << vertex.getOrigin() << ") at position " + << vertex.getIndexSlot() + << " (sym: " << to_wstring(vertex.getTerminalSymmetry()) << ")" + << std::endl; + } + + auto it = std::find_if(edges_.begin(), edges_.end(), + FullLabelIndexLocator(idx.full_label())); + if (it == edges_.end()) { + edges_.emplace_back(std::move(vertex), idx); + } else { + it->connect_to(std::move(vertex)); + } + }; + + for (std::size_t tensor_idx = 0; tensor_idx < tensors_.size(); ++tensor_idx) { + assert(tensors_[tensor_idx]); + const AbstractTensor &tensor = *tensors_[tensor_idx]; + const Symmetry tensor_symm = symmetry(tensor); + + auto bra_indices = tensor._bra(); + for (std::size_t index_idx = 0; index_idx < bra_indices.size(); + ++index_idx) { + idx_insert(bra_indices[index_idx], + Vertex(Origin::Bra, tensor_idx, index_idx, tensor_symm)); + } + + auto ket_indices = tensor._ket(); + for (std::size_t index_idx = 0; index_idx < ket_indices.size(); + ++index_idx) { + idx_insert(ket_indices[index_idx], + Vertex(Origin::Ket, tensor_idx, index_idx, tensor_symm)); + } + + auto aux_indices = tensor._aux(); + for (std::size_t index_idx = 0; index_idx < aux_indices.size(); + ++index_idx) { + // Note: for the time being we don't have a way of expressing + // permutational symmetry of auxiliary indices so we just assume there is + // no such symmetry + idx_insert(aux_indices[index_idx], + Vertex(Origin::Aux, tensor_idx, index_idx, Symmetry::nonsymm)); + } + } + + // extract external indices + for (const Edge ¤t : edges_) { + assert(current.vertex_count() > 0); + if (current.vertex_count() == 1) { + // External index (== Edge only connected to a single vertex in the + // network) + if (Logger::instance().tensor_network) { + std::wcout << "idx " << to_latex(current.idx()) << " is external" + << std::endl; + } + + bool inserted = ext_indices_.insert(current.idx()).second; + assert(inserted); + } + } + + have_edges_ = true; +} + +container::svector> TensorNetworkV2::factorize() { + abort(); // not yet implemented +} + +ExprPtr TensorNetworkV2::canonicalize_individual_tensor_blocks( + const NamedIndexSet &named_indices) { + return do_individual_canonicalization( + TensorBlockCanonicalizer(named_indices)); +} + +ExprPtr TensorNetworkV2::canonicalize_individual_tensors( + const NamedIndexSet &named_indices) { + return do_individual_canonicalization( + DefaultTensorCanonicalizer(named_indices)); +} + +ExprPtr TensorNetworkV2::do_individual_canonicalization( + const TensorCanonicalizer &canonicalizer) { + ExprPtr byproduct = ex(1); + + for (auto &tensor : tensors_) { + auto nondefault_canonizer_ptr = + TensorCanonicalizer::nondefault_instance_ptr(tensor->_label()); + const TensorCanonicalizer &tensor_canonizer = + nondefault_canonizer_ptr ? *nondefault_canonizer_ptr : canonicalizer; + + auto bp = canonicalizer.apply(*tensor); + + if (bp) { + byproduct *= bp; + } + } + + return byproduct; +} + +} // namespace sequant diff --git a/SeQuant/core/tensor_network_v2.hpp b/SeQuant/core/tensor_network_v2.hpp new file mode 100644 index 000000000..64f5e0c21 --- /dev/null +++ b/SeQuant/core/tensor_network_v2.hpp @@ -0,0 +1,332 @@ +// +// Created by Eduard Valeyev on 2019-02-02. +// + +#ifndef SEQUANT_TENSOR_NETWORK_V2_H +#define SEQUANT_TENSOR_NETWORK_V2_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// forward declarations +namespace bliss { +class Graph; +} + +namespace sequant { + +/// @brief A (non-directed) graph view of a sequence of AbstractTensor objects + +/// @note The main role of this is to canonize itself. Since Tensor objects can +/// be connected by multiple Index'es (thus edges are colored), what is +/// canonized is actually the graph of indices (roughly the dual of the tensor +/// graph), with Tensor objects represented by one or more vertices. +class TensorNetworkV2 { + public: + friend class TensorNetworkV2Accessor; + + enum class Origin { + Bra = 1, + Ket, + Aux, + }; + + class Vertex { + public: + Vertex(Origin origin, std::size_t terminal_idx, std::size_t index_slot, + Symmetry terminal_symm); + + Origin getOrigin() const; + std::size_t getTerminalIndex() const; + std::size_t getIndexSlot() const; + Symmetry getTerminalSymmetry() const; + + bool operator<(const Vertex &rhs) const; + bool operator==(const Vertex &rhs) const; + + private: + Origin origin; + std::size_t terminal_idx; + std::size_t index_slot; + Symmetry terminal_symm; + }; + + // clang-format off + /// @brief Edge in a TensorNetworkV2 = the Index annotating it + a pair of indices to identify which Tensor terminals it's connected to + + /// @note tensor terminals in a sequence of tensors are indexed as follows: + /// - >0 for bra terminals (i.e. "+7" indicated connection to a bra terminal + /// of 7th tensor object in the sequence) + /// - <0 for ket terminals + /// - 0 if free (not attached to any tensor objects) + /// - position records the terminal's location in the sequence of bra/ket + /// terminals (always 0 for symmetric/antisymmetric tensors) Terminal indices + /// are sorted by the tensor index (i.e. by the absolute value of the terminal + /// index), followed by position + // clang-format on + class Edge { + public: + Edge() = default; + explicit Edge(Vertex vertex) : first(std::move(vertex)), second() {} + Edge(Vertex vertex, Index index) + : first(std::move(vertex)), second(), index(std::move(index)) {} + + Edge &connect_to(Vertex vertex) { + assert(!second.has_value()); + + if (!first.has_value()) { + // unconnected Edge + first = std::move(vertex); + } else { + second = std::move(vertex); + if (second < first) { + // Ensure first <= second + std::swap(first, second); + } + } + return *this; + } + + bool operator<(const Edge &other) const { + if (vertex_count() != other.vertex_count()) { + // Ensure external indices (edges that are only attached to a tensor on + // one side) always come before internal ones + return vertex_count() < other.vertex_count(); + } + + if (!(first == other.first)) { + return first < other.first; + } + + if (second < other.second) { + return second < other.second; + } + + return index.space() < other.index.space(); + } + + bool operator==(const Edge &other) const { + return first == other.first && second == other.second; + } + + const Vertex &first_vertex() const { + assert(first.has_value()); + return first.value(); + } + const Vertex &second_vertex() const { + assert(second.has_value()); + return second.value(); + } + + /// @return the number of attached terminals (0, 1, or 2) + std::size_t vertex_count() const { + return second.has_value() ? 2 : (first.has_value() ? 1 : 0); + } + + const Index &idx() const { return index; } + + private: + std::optional first; + std::optional second; + Index index; + }; + + struct Graph { + /// The type used to encode the color of a vertex. The restriction of this + /// being as 32-bit integer comes from how BLISS is trying to convert these + /// into RGB values. + using VertexColor = std::uint32_t; + + std::unique_ptr bliss_graph; + std::vector vertex_labels; + std::vector vertex_colors; + std::vector vertex_types; + + Graph() = default; + + std::size_t vertex_to_index_idx(std::size_t vertex) const; + std::size_t vertex_to_tensor_idx(std::size_t vertex) const; + }; + + TensorNetworkV2(const Expr &expr) { + if (expr.size() > 0) { + for (const ExprPtr &subexpr : expr) { + add_expr(*subexpr); + } + } else { + add_expr(expr); + } + + init_edges(); + } + + TensorNetworkV2(const ExprPtr &expr) : TensorNetworkV2(*expr) {} + + template < + typename ExprPtrRange, + typename = std::enable_if_t && + !std::is_base_of_v>> + TensorNetworkV2(const ExprPtrRange &exprptr_range) { + static_assert( + std::is_base_of_v); + for (const ExprPtr ¤t : exprptr_range) { + add_expr(*current); + } + + init_edges(); + } + + /// @return const reference to the sequence of tensors + /// @note the order of tensors may be different from that provided as input + const auto &tensors() const { return tensors_; } + + using NamedIndexSet = container::set; + + /// @param cardinal_tensor_labels move all tensors with these labels to the + /// front before canonicalizing indices + /// @param fast if true (default), does fast canonicalization that is only + /// optimal if all tensors are distinct; set to false to perform complete + /// canonicalization + /// @param named_indices specifies the indices that cannot be renamed, i.e. + /// their labels are meaningful; default is nullptr, which results in external + /// indices treated as named indices + /// @return byproduct of canonicalization (e.g. phase); if none, returns + /// nullptr + ExprPtr canonicalize( + const container::vector &cardinal_tensor_labels = {}, + bool fast = true, const NamedIndexSet *named_indices = nullptr); + + /// Factorizes tensor network + /// @return sequence of binary products; each element encodes the tensors to + /// be + /// multiplied (values >0 refer to the tensors in tensors(), + /// values <0 refer to the elements of this sequence. E.g. sequences + /// @c {{0,1},{-1,2},{-2,3}} , @c {{0,2},{1,3},{-1,-2}} , @c + /// {{3,1},{2,-1},{0,-2}} encode the following respective + /// factorizations @c (((T0*T1)*T2)*T3) , @c ((T0*T2)*(T1*T3)) , and + /// @c (((T3*T1)*T2)*T0) . + container::svector> factorize(); + + /// accessor for the Edge object sequence + /// @return const reference to the sequence container of Edge objects, sorted + /// by their Index's full label + /// @sa Edge + const auto &edges() const { + assert(have_edges_); + return edges_; + } + + /// @brief Returns a range of external indices, i.e. those indices that do not + /// connect tensors + const auto &ext_indices() const { + assert(have_edges_); + return ext_indices_; + } + + /// @brief converts the network into a Bliss graph whose vertices are indices + /// and tensor vertex representations + /// @param[in] named_indices pointer to the set of named indices (ordinarily, + /// this includes all external indices); + /// default is nullptr, which means use all external indices for + /// named indices + /// @return The created Graph object + + /// @note Rules for constructing the graph: + /// - Indices with protoindices are connected to their protoindices, + /// either directly or (if protoindices are symmetric) via a protoindex + /// vertex. + /// - Indices are colored by their space, which in general encodes also + /// the space of the protoindices. + /// - An anti/symmetric n-body tensor has 2 terminals, each connected to + /// each other + to n index vertices. + /// - A nonsymmetric n-body tensor has n terminals, each connected to 2 + /// indices and 1 tensor vertex which is connected to all n terminal + /// indices. + /// - tensor vertices are colored by the label+rank+symmetry of the + /// tensor; terminal vertices are colored by the color of its tensor, + /// with the color of symm/antisymm terminals augmented by the + /// terminal's type (bra/ket). + Graph create_graph(const NamedIndexSet *named_indices = nullptr) const; + + private: + // source tensors and indices + container::svector tensors_; + + container::vector edges_; + bool have_edges_ = false; + // ext indices do not connect tensors + // sorted by *label* (not full label) of the corresponding value (Index) + // this ensures that proto indices are not considered and all internal indices + // have unique labels (not full labels) + NamedIndexSet ext_indices_; + + /// initializes edges_ and ext_indices_ + void init_edges(); + + /// Canonicalizes the network graph representation + /// Note: The explicit order of tensors and labelling of indices + /// remains undefined. + void canonicalize_graph(const NamedIndexSet &named_indices); + + /// Canonicalizes every individual tensor for itself, taking into account only + /// tensor blocks + /// @returns The byproduct of the canonicalizations + ExprPtr canonicalize_individual_tensor_blocks( + const NamedIndexSet &named_indices); + + /// Canonicalizes every individual tensor for itself + /// @returns The byproduct of the canonicalizations + ExprPtr canonicalize_individual_tensors(const NamedIndexSet &named_indices); + + ExprPtr do_individual_canonicalization( + const TensorCanonicalizer &canonicalizer); + + void add_expr(const Expr &expr) { + ExprPtr clone = expr.clone(); + + auto tensor_ptr = std::dynamic_pointer_cast(clone); + if (!tensor_ptr) { + throw std::invalid_argument( + "TensorNetworkV2::TensorNetworkV2: tried to add non-tensor to " + "network"); + } + + tensors_.push_back(std::move(tensor_ptr)); + } +}; + +template +std::basic_ostream &operator<<( + std::basic_ostream &stream, TensorNetworkV2::Origin origin) { + switch (origin) { + case TensorNetworkV2::Origin::Bra: + stream << "Bra"; + break; + case TensorNetworkV2::Origin::Ket: + stream << "Ket"; + break; + case TensorNetworkV2::Origin::Aux: + stream << "Aux"; + break; + } + return stream; +} + +} // namespace sequant + +#endif // SEQUANT_TENSOR_NETWORK_H diff --git a/SeQuant/core/utility/indices.hpp b/SeQuant/core/utility/indices.hpp index 5617cd591..55977e825 100644 --- a/SeQuant/core/utility/indices.hpp +++ b/SeQuant/core/utility/indices.hpp @@ -1,16 +1,31 @@ #ifndef SEQUANT_CORE_UTILITY_INDICES_HPP #define SEQUANT_CORE_UTILITY_INDICES_HPP +#include #include #include #include +#include #include +#include +#include #include namespace sequant { namespace detail { +template +struct not_in { + const Range& range; + + not_in(const Range& range) : range(range) {} + + template + bool operator()(const T& element) const { + return std::find(range.begin(), range.end(), element) == range.end(); + } +}; /// This function is equal to std::remove in case the given container /// contains none or only a single occurrence of the given element. @@ -36,9 +51,10 @@ template > struct IndexGroups { Container bra; Container ket; + Container aux; bool operator==(const IndexGroups& other) const { - return bra == other.bra && ket == other.ket; + return bra == other.bra && ket == other.ket && aux == other.aux; } bool operator!=(const IndexGroups& other) const { @@ -46,6 +62,13 @@ struct IndexGroups { } }; +/// A composite type for holding tensor-of-tensor indices +template > +struct TensorOfTensorIndices { + Container outer; + Container inner; +}; + template > IndexGroups get_unique_indices(const ExprPtr& expr); @@ -59,6 +82,36 @@ IndexGroups get_unique_indices(const Variable&) { return {}; } +/// @returns Lists of non-contracted indices arising when contracting the two +/// given tensors in the order bra, ket, auxiliary +template > +IndexGroups get_uncontracted_indices(const Tensor& t1, + const Tensor& t2) { + static_assert(std::is_same_v); + + IndexGroups groups; + + // Bra indices + std::copy_if(t1.bra().begin(), t1.bra().end(), std::back_inserter(groups.bra), + detail::not_in{t2.ket()}); + std::copy_if(t2.bra().begin(), t2.bra().end(), std::back_inserter(groups.bra), + detail::not_in{t1.ket()}); + + // Ket indices + std::copy_if(t1.ket().begin(), t1.ket().end(), std::back_inserter(groups.ket), + detail::not_in{t2.bra()}); + std::copy_if(t2.ket().begin(), t2.ket().end(), std::back_inserter(groups.ket), + detail::not_in{t1.bra()}); + + // Auxiliary indices + std::copy_if(t1.aux().begin(), t1.aux().end(), std::back_inserter(groups.aux), + detail::not_in{t2.aux()}); + std::copy_if(t2.aux().begin(), t2.aux().end(), std::back_inserter(groups.aux), + detail::not_in{t1.aux()}); + + return groups; +} + /// Obtains the set of unique (non-repeated) indices used in the given tensor template > IndexGroups get_unique_indices(const Tensor& tensor) { @@ -86,6 +139,17 @@ IndexGroups get_unique_indices(const Tensor& tensor) { } } + for (const Index& current : tensor.aux()) { + if (encounteredIndices.find(current) == encounteredIndices.end()) { + groups.aux.push_back(current); + encounteredIndices.insert(current); + } else { + detail::remove_one(groups.bra, current); + detail::remove_one(groups.ket, current); + detail::remove_one(groups.aux, current); + } + } + return groups; } @@ -131,6 +195,18 @@ IndexGroups get_unique_indices(const Product& product) { detail::remove_one(groups.ket, current); } } + + // Same for aux indices + for (Index& current : currentGroups.aux) { + if (encounteredIndices.find(current) == encounteredIndices.end()) { + encounteredIndices.insert(current); + groups.aux.push_back(std::move(current)); + } else { + detail::remove_one(groups.bra, current); + detail::remove_one(groups.ket, current); + detail::remove_one(groups.aux, current); + } + } } return groups; @@ -156,6 +232,45 @@ IndexGroups get_unique_indices(const ExprPtr& expr) { } } +template , typename Rng> +TensorOfTensorIndices tot_indices(Rng const& idxs) { + using ranges::not_fn; + using ranges::views::concat; + using ranges::views::filter; + using ranges::views::join; + using ranges::views::transform; + + // Container indep_idxs; + + TensorOfTensorIndices result; + auto& outer = result.outer; + + for (auto&& i : idxs | transform(&Index::proto_indices) | join) + if (!ranges::contains(outer, i)) outer.emplace_back(i); + + for (auto&& i : idxs | filter(not_fn(&Index::has_proto_indices))) + if (!ranges::contains(outer, i)) outer.emplace_back(i); + + auto& inner = result.inner; + for (auto&& i : idxs | filter(&Index::has_proto_indices)) + inner.emplace_back(i); + + return result; +} + +/// +/// Does the numeric comparison of the index suffixes using less-than operator. +/// +/// \param idx1 +/// \param idx2 +/// \return True if the numeric suffix of \c idx1 is less than that of \c idx2. +/// +inline bool suffix_compare(Index const& idx1, Index const& idx2) { + auto&& s1 = idx1.suffix(); + auto&& s2 = idx2.suffix(); + return (s1 && s2) && s1.value() < s2.value(); +} + } // namespace sequant #endif // SEQUANT_CORE_UTILITY_INDICES_HPP diff --git a/SeQuant/core/vertex_painter.cpp b/SeQuant/core/vertex_painter.cpp new file mode 100644 index 000000000..06b23cce2 --- /dev/null +++ b/SeQuant/core/vertex_painter.cpp @@ -0,0 +1,159 @@ +#include +#include + +namespace sequant { + +VertexPainter::VertexPainter( + const TensorNetworkV2::NamedIndexSet &named_indices) + : used_colors_(), named_indices_(named_indices) {} + +VertexPainter::Color VertexPainter::operator()(const AbstractTensor &tensor) { + Color color = to_color(hash::value(label(tensor))); + + return ensure_uniqueness(color, tensor); +} + +VertexPainter::Color VertexPainter::operator()(const BraGroup &group) { + Color color = to_color(group.id + 0xff); + + return ensure_uniqueness(color, group); +} + +VertexPainter::Color VertexPainter::operator()(const KetGroup &group) { + Color color = to_color(group.id + 0xff00); + + return ensure_uniqueness(color, group); +} + +VertexPainter::Color VertexPainter::operator()(const AuxGroup &group) { + Color color = to_color(group.id + 3 * 0xff0000); + + return ensure_uniqueness(color, group); +} + +VertexPainter::Color VertexPainter::operator()(const ParticleGroup &group) { + Color color = to_color(group.id); + + return ensure_uniqueness(color, group); +} + +VertexPainter::Color VertexPainter::operator()(const Index &idx) { + auto it = named_indices_.find(idx); + + std::size_t pre_color; + if (it == named_indices_.end()) { + // anonymous index + pre_color = idx.color(); + } else { + pre_color = static_cast( + std::distance(named_indices_.begin(), it)); + } + // shift + pre_color += 0xaa; + + return ensure_uniqueness(to_color(pre_color), idx); +} + +VertexPainter::Color VertexPainter::operator()(const ProtoBundle &bundle) { + Color color = to_color(Index::proto_indices_color(bundle)); + + return ensure_uniqueness(color, bundle); +} + +VertexPainter::Color VertexPainter::to_color(std::size_t color) const { + // Due to the way we compute the input color, different colors might only + // differ by a value of 1. This is fine for the algorithmic purpose (after + // all, colors need only be different - by how much is irrelevant), but + // sometimes we'll want to use those colors as actual colors to show to a + // human being. In those cases, having larger differences makes it easier to + // recognize different colors. Therefore, we hash-combine with an + // arbitrarily chosen salt with the goal that this will uniformly spread out + // all input values and therefore increase color differences. + constexpr std::size_t salt = 0x43d2c59cb15b73f0; + hash::combine(color, salt); + + if constexpr (sizeof(Color) >= sizeof(std::size_t)) { + return color; + } + + // Need to somehow fit the color into a lower precision integer. In the + // general case, this is necessarily a lossy conversion. We make the + // assumption that the input color is + // - a hash, or + // - computed from some object ID + // In the first case, we assume that the used hash function has a uniform + // distribution or if there is a bias, the bias is towards lower numbers. + // This allows us to simply reuse the lower x bits of the hash as a new hash + // (where x == CHAR_BIT * sizeof(VertexColor)). In the second case we assume + // that such values never exceed the possible value range of VertexColor so + // that again, we can simply take the lower x bits of color and in this case + // even retain the numeric value representing the color. Handily, this is + // exactly what happens when we perform a conversion into a narrower type. + // We only have to make sure that the underlying types are unsigned as + // otherwise the behavior is undefined. + static_assert(sizeof(Color) < sizeof(std::size_t)); + static_assert(std::is_unsigned_v, + "Narrowing conversion are undefined for signed integers"); + static_assert(std::is_unsigned_v, + "Narrowing conversion are undefined for signed integers"); + return static_cast(color); +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const AbstractTensor &tensor) { + return std::holds_alternative(data) && + label(*std::get(data)) == label(tensor); +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const BraGroup &group) { + return std::holds_alternative(data) && + std::get(data).id == group.id; +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const KetGroup &group) { + return std::holds_alternative(data) && + std::get(data).id == group.id; +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const AuxGroup &group) { + return std::holds_alternative(data) && + std::get(data).id == group.id; +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const ParticleGroup &group) { + return std::holds_alternative(data) && + std::get(data).id == group.id; +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const Index &idx) { + if (!std::holds_alternative(data)) { + return false; + } + + const Index &lhs = std::get(data); + + auto it1 = named_indices_.find(lhs); + auto it2 = named_indices_.find(idx); + + if (it1 != it2) { + // Either one index is named and the other is not or both are named, but + // are different indices + return false; + } + + return lhs.color() == idx.color(); +} + +bool VertexPainter::may_have_same_color(const VertexData &data, + const ProtoBundle &bundle) { + return std::holds_alternative(data) && + Index::proto_indices_color(*std::get(data)) == + Index::proto_indices_color(bundle); +} + +} // namespace sequant diff --git a/SeQuant/core/vertex_painter.hpp b/SeQuant/core/vertex_painter.hpp new file mode 100644 index 000000000..896285707 --- /dev/null +++ b/SeQuant/core/vertex_painter.hpp @@ -0,0 +1,106 @@ +#ifndef SEQUANT_VERTEX_PAINTER_H +#define SEQUANT_VERTEX_PAINTER_H + +#include +#include +#include +#include + +#include +#include + +namespace sequant { + +using ProtoBundle = + std::decay_t().proto_indices())>; + +struct BraGroup { + explicit BraGroup(std::size_t id) : id(id) {} + + std::size_t id; +}; +struct KetGroup { + explicit KetGroup(std::size_t id) : id(id) {} + + std::size_t id; +}; +struct AuxGroup { + explicit AuxGroup(std::size_t id) : id(id) {} + + std::size_t id; +}; +struct ParticleGroup { + explicit ParticleGroup(std::size_t id) : id(id) {} + + std::size_t id; +}; + +/// Can be used to assign unique colors to a set of objects. The class +/// automatically ensures that there are no accidental color duplications for +/// objects that actually should have different colors (i.e. this is more than a +/// hash function). It is intended to be used to determine the vertex colors in +/// a colored graph representing a tensor network. +class VertexPainter { + public: + using Color = TensorNetworkV2::Graph::VertexColor; + using VertexData = + std::variant; + using ColorMap = container::map; + + VertexPainter(const TensorNetworkV2::NamedIndexSet &named_indices); + + const ColorMap &used_colors() const; + + Color operator()(const AbstractTensor &tensor); + Color operator()(const BraGroup &group); + Color operator()(const KetGroup &group); + Color operator()(const AuxGroup &group); + Color operator()(const ParticleGroup &group); + Color operator()(const Index &idx); + Color operator()(const ProtoBundle &bundle); + + private: + ColorMap used_colors_; + const TensorNetworkV2::NamedIndexSet &named_indices_; + + Color to_color(std::size_t color) const; + + template + Color ensure_uniqueness(Color color, const T &val) { + auto it = used_colors_.find(color); + while (it != used_colors_.end() && !may_have_same_color(it->second, val)) { + // Color collision: val was computed to have the same color + // as another object, but these objects do not compare equal (for + // the purpose of color assigning). + // -> Need to modify color until conflict is resolved. + color++; + it = used_colors_.find(color); + } + + if (it == used_colors_.end()) { + // We have not yet seen this color before -> add it to cache + if constexpr (std::is_same_v || + std::is_same_v) { + used_colors_[color] = &val; + } else { + used_colors_[color] = val; + } + } + + return color; + } + + bool may_have_same_color(const VertexData &data, + const AbstractTensor &tensor); + bool may_have_same_color(const VertexData &data, const BraGroup &group); + bool may_have_same_color(const VertexData &data, const KetGroup &group); + bool may_have_same_color(const VertexData &data, const AuxGroup &group); + bool may_have_same_color(const VertexData &data, const ParticleGroup &group); + bool may_have_same_color(const VertexData &data, const Index &idx); + bool may_have_same_color(const VertexData &data, const ProtoBundle &bundle); +}; + +} // namespace sequant + +#endif diff --git a/SeQuant/core/vertex_type.hpp b/SeQuant/core/vertex_type.hpp new file mode 100644 index 000000000..3a82d80d0 --- /dev/null +++ b/SeQuant/core/vertex_type.hpp @@ -0,0 +1,18 @@ +#ifndef SEQUANT_VERTEX_TYPE_H +#define SEQUANT_VERTEX_TYPE_H + +namespace sequant { + +enum class VertexType { + Index, + SPBundle, + TensorBra, + TensorKet, + TensorAux, + TensorCore, + Particle, +}; + +} + +#endif diff --git a/SeQuant/core/wick.impl.hpp b/SeQuant/core/wick.impl.hpp index 84c69094e..59e0a8fd0 100644 --- a/SeQuant/core/wick.impl.hpp +++ b/SeQuant/core/wick.impl.hpp @@ -7,7 +7,9 @@ #include #include +#include #include +#include #ifdef SEQUANT_HAS_EXECUTION_HEADER #include @@ -238,7 +240,7 @@ inline bool apply_index_replacement_rules( const auto &factor = *it; if (factor->is()) { auto &tensor = factor->as(); - assert(ranges::none_of(tensor.const_braket(), [](const Index &idx) { + assert(ranges::none_of(tensor.const_indices(), [](const Index &idx) { return idx.tag().has_value(); })); } @@ -363,7 +365,7 @@ void reduce_wick_impl(std::shared_ptr &expr, std::set all_indices; ranges::for_each(*expr, [&all_indices](const auto &factor) { if (factor->template is()) { - ranges::for_each(factor->template as().braket(), + ranges::for_each(factor->template as().indices(), [&all_indices](const Index &idx) { [[maybe_unused]] auto result = all_indices.insert(idx); @@ -640,14 +642,13 @@ ExprPtr WickTheorem::compute(const bool count_only, // ordinal can be computed by counting std::size_t nop_ord = 0; for (size_t v = 0; v != n; ++v) { - if (vtypes[v] == TensorNetwork::VertexType::TensorCore && + if (vtypes[v] == VertexType::TensorCore && (std::find(nop_labels_begin, nop_labels_end, vlabels[v]) != nop_labels_end)) { auto insertion_result = nop_vidx_ord.emplace(v, nop_ord++); assert(insertion_result.second); } - if (vtypes[v] == TensorNetwork::VertexType::Index && - !input_->empty()) { + if (vtypes[v] == VertexType::Index && !input_->empty()) { auto &idx = (tn_edges.begin() + v)->idx(); auto idx_it_in_opseq = ranges::find_if( opseq_view, @@ -882,7 +883,7 @@ ExprPtr WickTheorem::compute(const bool count_only, auto exclude_index_vertex_pair = [&tn_tensors, &tn_edges](size_t v1, size_t v2) { // v1 and v2 are vertex indices and also index the edges in the - // TensorNetwork + // WickGraph assert(v1 < tn_edges.size()); assert(v2 < tn_edges.size()); const auto &edge1 = *(tn_edges.begin() + v1); diff --git a/SeQuant/domain/eval/eval.cpp b/SeQuant/domain/eval/eval.cpp index 1723e08b7..1f7157945 100644 --- a/SeQuant/domain/eval/eval.cpp +++ b/SeQuant/domain/eval/eval.cpp @@ -47,7 +47,8 @@ EvalExprBTAS::annot_t const& EvalExprBTAS::annot() const noexcept { } EvalExprBTAS::EvalExprBTAS(Tensor const& t) noexcept - : EvalExpr{t}, annot_{index_hash(t.const_braket()) | ranges::to} {} + : EvalExpr{t}, + annot_{index_hash(t.const_indices()) | ranges::to} {} EvalExprBTAS::EvalExprBTAS(Constant const& c) noexcept : EvalExpr{c} {} @@ -59,7 +60,7 @@ EvalExprBTAS::EvalExprBTAS(EvalExprBTAS const& left, // : EvalExpr{left, right, op} { if (result_type() == ResultType::Tensor) { assert(!tot() && "Tensor of tensor not supported in BTAS"); - annot_ = index_hash(as_tensor().const_braket()) | ranges::to; + annot_ = index_hash(as_tensor().const_indices()) | ranges::to; } } diff --git a/SeQuant/domain/mbpt/convention.cpp b/SeQuant/domain/mbpt/convention.cpp index f2ea93952..b88494735 100644 --- a/SeQuant/domain/mbpt/convention.cpp +++ b/SeQuant/domain/mbpt/convention.cpp @@ -8,7 +8,14 @@ #include #include -#include +#include + +#include +#include +#include +#include +#include +#include namespace sequant { namespace mbpt { diff --git a/SeQuant/domain/mbpt/op.ipp b/SeQuant/domain/mbpt/op.ipp index 7cfa6fa1b..ba0f97cdb 100644 --- a/SeQuant/domain/mbpt/op.ipp +++ b/SeQuant/domain/mbpt/op.ipp @@ -5,6 +5,7 @@ #ifndef SEQUANT_DOMAIN_MBPT_OP_IPP #define SEQUANT_DOMAIN_MBPT_OP_IPP +#include #include namespace sequant { @@ -149,13 +150,19 @@ ExprPtr Operator::clone() const { return ex(*this); } -// Expresses general operators in human interpretable form. for example: \hat{T}_2 is a particle conserving 2-body excitation operator -// a non-particle conserving operator \hat{R}_2_1 implies that two particles are created followed by a single hole creation. -// conversely \hat{R}_1_2 implies the that only one particle is annihilated followed by two holes being created. -// The rule being, that for non-particle conserving operators, the first position indicates where the quasiparticle is going to and the second position indicates where it comes from. -// for the case of adjoint operators, the adjoint is represented by the symbol ⁺ and superscripting the quasi-particle numbers. for example: hat{R⁺}^{1,2}} -// For operators in which one or more quasi-particles has only partial coverage in the particle_space or hole_space, this notation is unsuitable, and we default to -// level printing of the operator. +// Expresses general operators in human interpretable form. for example: +// \hat{T}_2 is a particle conserving 2-body excitation operator a non-particle +// conserving operator \hat{R}_2_1 implies that two particles are created +// followed by a single hole creation. conversely \hat{R}_1_2 implies the that +// only one particle is annihilated followed by two holes being created. The +// rule being, that for non-particle conserving operators, the first position +// indicates where the quasiparticle is going to and the second position +// indicates where it comes from. for the case of adjoint operators, the adjoint +// is represented by the symbol ⁺ and superscripting the quasi-particle numbers. +// for example: hat{R⁺}^{1,2}} For operators in which one or more +// quasi-particles has only partial coverage in the particle_space or +// hole_space, this notation is unsuitable, and we default to level printing of +// the operator. template std::wstring Operator::to_latex() const { return sequant::to_latex(*this); diff --git a/SeQuant/domain/mbpt/spin.cpp b/SeQuant/domain/mbpt/spin.cpp index 79e1304ee..8fe29abfa 100644 --- a/SeQuant/domain/mbpt/spin.cpp +++ b/SeQuant/domain/mbpt/spin.cpp @@ -199,7 +199,7 @@ ExprPtr remove_spin(const ExprPtr& expr) { } } Tensor result(tensor.label(), bra(std::move(b)), ket(std::move(k)), - tensor.symmetry(), tensor.braket_symmetry()); + tensor.aux(), tensor.symmetry(), tensor.braket_symmetry()); return std::make_shared(std::move(result)); }; @@ -287,7 +287,7 @@ ExprPtr expand_antisymm(const Tensor& tensor, bool skip_spinsymm) { assert(tensor.bra_rank() == tensor.ket_rank()); // Return non-symmetric tensor if rank is 1 if (tensor.bra_rank() == 1) { - Tensor new_tensor(tensor.label(), tensor.bra(), tensor.ket(), + Tensor new_tensor(tensor.label(), tensor.bra(), tensor.ket(), tensor.aux(), Symmetry::nonsymm, tensor.braket_symmetry(), tensor.particle_symmetry()); return std::make_shared(new_tensor); @@ -320,7 +320,7 @@ ExprPtr expand_antisymm(const Tensor& tensor, bool skip_spinsymm) { do { // N.B. must copy auto new_tensor = Tensor(tensor.label(), bra(bra_list), ket(ket_list), - Symmetry::nonsymm); + tensor.aux(), Symmetry::nonsymm); if (spin_symm_tensor(new_tensor)) { auto new_tensor_product = std::make_shared(); @@ -525,7 +525,8 @@ ExprPtr symmetrize_expr(const Product& product) { auto S = Tensor{}; if (A_is_nconserving) { - S = Tensor(L"S", A_tensor.bra(), A_tensor.ket(), Symmetry::nonsymm); + S = Tensor(L"S", A_tensor.bra(), A_tensor.ket(), A_tensor.aux(), + Symmetry::nonsymm); } else { // A is N-nonconserving auto n = std::min(A_tensor.bra_rank(), A_tensor.ket_rank()); container::svector bra_list(A_tensor.bra().begin(), @@ -533,7 +534,7 @@ ExprPtr symmetrize_expr(const Product& product) { container::svector ket_list(A_tensor.ket().begin(), A_tensor.ket().begin() + n); S = Tensor(L"S", bra(std::move(bra_list)), ket(std::move(ket_list)), - Symmetry::nonsymm); + A_tensor.aux(), Symmetry::nonsymm); } // Generate replacement maps from a list of Index type (could be a bra or a @@ -1080,7 +1081,7 @@ Tensor swap_spin(const Tensor& t) { k.at(i) = spin_flipped_idx(t.ket().at(i)); } - return {t.label(), bra(std::move(b)), ket(std::move(k)), + return {t.label(), bra(std::move(b)), ket(std::move(k)), t.aux(), t.symmetry(), t.braket_symmetry(), t.particle_symmetry()}; } @@ -1118,7 +1119,8 @@ ExprPtr merge_tensors(const Tensor& O1, const Tensor& O2) { assert(O1.symmetry() == O2.symmetry()); auto b = ranges::views::concat(O1.bra(), O2.bra()); auto k = ranges::views::concat(O1.ket(), O2.ket()); - return ex(Tensor(O1.label(), bra(b), ket(k), O1.symmetry())); + auto a = ranges::views::concat(O1.aux(), O2.aux()); + return ex(Tensor(O1.label(), bra(b), ket(k), aux(a), O1.symmetry())); } std::vector open_shell_A_op(const Tensor& A) { @@ -1143,8 +1145,8 @@ std::vector open_shell_A_op(const Tensor& A) { make_spinbeta); ranges::for_each(spin_bra, [](const Index& i) { i.reset_tag(); }); ranges::for_each(spin_ket, [](const Index& i) { i.reset_tag(); }); - result.at(i) = - ex(Tensor(L"A", spin_bra, spin_ket, Symmetry::antisymm)); + result.at(i) = ex( + Tensor(L"A", spin_bra, spin_ket, A.aux(), Symmetry::antisymm)); // std::wcout << to_latex(result.at(i)) << " "; } // std::wcout << "\n" << std::endl; diff --git a/SeQuant/domain/mbpt/vac_av.ipp b/SeQuant/domain/mbpt/vac_av.ipp index 0ce50c527..6e5657ccb 100644 --- a/SeQuant/domain/mbpt/vac_av.ipp +++ b/SeQuant/domain/mbpt/vac_av.ipp @@ -124,8 +124,7 @@ ExprPtr vac_av( } else if (expr.is() || expr.is()) { return expr; // vacuum is normalized } - throw std::invalid_argument( - "mpbt::*::vac_av(expr): unknown expression type"); + throw std::invalid_argument("mpbt::*::vac_av(expr): unknown expression type"); } ExprPtr vac_av( diff --git a/SeQuant/external/bliss/graph.hh b/SeQuant/external/bliss/graph.hh index c0db33dae..ec5d987fb 100644 --- a/SeQuant/external/bliss/graph.hh +++ b/SeQuant/external/bliss/graph.hh @@ -35,6 +35,7 @@ class AbstractGraph; #include #include #include +#include #include "bignum.hh" #include "heap.hh" #include "kqueue.hh" @@ -693,6 +694,8 @@ class Graph : public AbstractGraph { auto int_to_rgb = [](unsigned int i) { std::basic_stringstream stream; + // Set locale of this stream to C to avoid any kind of thousands separator + stream.imbue(std::locale::classic()); stream << std::setfill(Char('0')) << std::setw(6) << std::hex << ((i << 8) >> 8); return stream.str(); @@ -711,7 +714,8 @@ class Graph : public AbstractGraph { } else os << vnum; if (rgb_colors) { - os << "\"; color=\"#" << int_to_rgb(v.color) << "\"];\n"; + auto color = int_to_rgb(v.color); + os << "\"; style=filled; color=\"#" << color << "\"; fillcolor=\"#" << color << "80\"; penwidth=2];\n"; } else { os << ":" << v.color << "\"];\n"; } diff --git a/examples/eval/btas/main.cpp b/examples/eval/btas/main.cpp index ae3f7f11c..c6ced0cd6 100644 --- a/examples/eval/btas/main.cpp +++ b/examples/eval/btas/main.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include diff --git a/examples/eval/ta/main.cpp b/examples/eval/ta/main.cpp index 2f961fbfa..187c53745 100644 --- a/examples/eval/ta/main.cpp +++ b/examples/eval/ta/main.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include diff --git a/examples/osstcc/osstcc.cpp b/examples/osstcc/osstcc.cpp index e2a243f4a..51a0742f8 100644 --- a/examples/osstcc/osstcc.cpp +++ b/examples/osstcc/osstcc.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include diff --git a/examples/stcc/stcc.cpp b/examples/stcc/stcc.cpp index 3369ecb73..551371769 100644 --- a/examples/stcc/stcc.cpp +++ b/examples/stcc/stcc.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include diff --git a/examples/stcc_rigorous/stcc_rigorous.cpp b/examples/stcc_rigorous/stcc_rigorous.cpp index a3aa1751b..b1573d963 100644 --- a/examples/stcc_rigorous/stcc_rigorous.cpp +++ b/examples/stcc_rigorous/stcc_rigorous.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include diff --git a/examples/tensor_network_graphs/tensor_network_graphs.cpp b/examples/tensor_network_graphs/tensor_network_graphs.cpp new file mode 100644 index 000000000..3a6fd7062 --- /dev/null +++ b/examples/tensor_network_graphs/tensor_network_graphs.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +using namespace sequant; + +std::wstring from_utf8(std::string_view str) { + std::wstring_convert> converter; + return converter.from_bytes(std::string(str)); +} + +std::optional to_network(const ExprPtr &expr) { + if (expr.is()) { + return TensorNetwork({expr}); + } else if (expr.is()) { + for (const ExprPtr &factor : expr.as().factors()) { + if (!factor.is()) { + return {}; + } + } + + return TensorNetwork(expr.as().factors()); + } else { + return {}; + } +} + +std::optional to_network_v2(const ExprPtr &expr) { + if (expr.is()) { + return TensorNetworkV2({expr}); + } else if (expr.is()) { + for (const ExprPtr &factor : expr.as().factors()) { + if (!factor.is()) { + return {}; + } + } + + return TensorNetworkV2(expr.as().factors()); + } else { + return {}; + } +} + +void print_help() { + std::wcout << "Helper to generate dot (GraphViz) representations of tensor " + "network graphs.\n"; + std::wcout << "Usage:\n"; + std::wcout + << " [options] [ [... [] ] ]\n"; + std::wcout << "Options:\n"; + std::wcout << " --help Shows this help message\n"; + std::wcout << " --v2 Use TensorNetworkV2\n"; + std::wcout << " --no-named Treat all indices as unnamed (even if they are " + "external)\n"; +} + +int main(int argc, char **argv) { + set_locale(); + sequant::set_default_context(Context( + mbpt::make_sr_spaces(), Vacuum::SingleProduct, IndexSpaceMetric::Unit, + BraKetSymmetry::conjugate, SPBasis::spinorbital)); + + bool use_named_indices = true; + bool use_tnv2 = false; + const TensorNetwork::named_indices_t empty_named_indices; + + if (argc <= 1) { + print_help(); + return 0; + } + + for (std::size_t i = 1; i < static_cast(argc); ++i) { + std::wstring current = from_utf8(argv[i]); + if (current == L"--help") { + print_help(); + return 0; + } else if (current == L"--no-named") { + use_named_indices = false; + continue; + } else if (current == L"--v2") { + use_tnv2 = true; + continue; + } + + ExprPtr expr; + try { + expr = parse_expr(current); + } catch (const ParseError &e) { + std::wcout << "Failed to parse expression '" << current + << "': " << e.what() << std::endl; + return 1; + } + assert(expr); + + if (!use_tnv2) { + std::optional network = to_network(expr); + if (!network.has_value()) { + std::wcout << "Failed to construct tensor network for input '" + << current << "'" << std::endl; + return 2; + } + + auto [graph, vlabels, vcolors, vtypes] = network->make_bliss_graph( + use_named_indices ? nullptr : &empty_named_indices); + std::wcout << "Graph for '" << current << "'\n"; + graph->write_dot(std::wcout, vlabels); + } else { + std::optional network = to_network_v2(expr); + if (!network.has_value()) { + std::wcout << "Failed to construct tensor network for input '" + << current << "'" << std::endl; + return 2; + } + + TensorNetworkV2::Graph graph = network->create_graph( + use_named_indices ? nullptr : &empty_named_indices); + std::wcout << "Graph for '" << current << "'\n"; + graph.bliss_graph->write_dot(std::wcout, graph.vertex_labels); + } + } +} diff --git a/python/src/sequant/_sequant.cc b/python/src/sequant/_sequant.cc index 5990c1eb8..1f553bb43 100644 --- a/python/src/sequant/_sequant.cc +++ b/python/src/sequant/_sequant.cc @@ -25,8 +25,10 @@ inline std::vector make_index(std::vector labels) { std::shared_ptr make_tensor(std::wstring label, std::vector bra, - std::vector ket) { - return std::make_shared(label, make_index(bra), make_index(ket)); + std::vector ket, + std::vector auxiliary) { + return std::make_shared(label, make_index(bra), make_index(ket), + make_index(auxiliary)); } std::shared_ptr make_constant(py::float_ number) { @@ -135,9 +137,16 @@ PYBIND11_MODULE(_sequant, m) { .def_property_readonly("label", &Tensor::label) .def_property_readonly("bra", &Tensor::bra) .def_property_readonly("ket", &Tensor::ket) - .def_property_readonly("braket", [](const Tensor &t) { - auto braket = t.braket(); - return std::vector(braket.begin(), braket.end()); + .def_property_readonly("auxikiary", &Tensor::auxiliary) + .def_property_readonly("braket", + [](const Tensor &t) { + auto braket = t.braket(); + return std::vector(braket.begin(), + braket.end()); + }) + .def_property_readonly("indices", [](const Tensor &t) { + auto indices = t.indices(); + return std::vector(indices.begin(), indices.end()); }); py::class_>(m, "zRational") diff --git a/python/src/sequant/mbpt.h b/python/src/sequant/mbpt.h index 07afc34cb..77db07d04 100644 --- a/python/src/sequant/mbpt.h +++ b/python/src/sequant/mbpt.h @@ -7,6 +7,7 @@ #include #include #include +#include #include "python.h" diff --git a/tests/unit/catch2_sequant.hpp b/tests/unit/catch2_sequant.hpp new file mode 100644 index 000000000..55f22e589 --- /dev/null +++ b/tests/unit/catch2_sequant.hpp @@ -0,0 +1,170 @@ +#ifndef SEQUANT_TESTS_CATCH2_SEQUANT_H +#define SEQUANT_TESTS_CATCH2_SEQUANT_H + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace Catch { + +// Make sure Catch uses proper string representation for SeQuant types + +template <> +struct StringMaker { + static std::string convert(const sequant::Expr &expr, + bool include_canonical = true) { + std::string str; + try { + str = sequant::to_string(sequant::deparse(expr, true)); + } catch (const std::exception &) { + // deparse doesn't support all kinds of expressions -> fall back to LaTeX + // representation + str = sequant::to_string(sequant::to_latex(expr)); + } + + if (include_canonical) { + sequant::ExprPtr clone = expr.clone(); + canonicalize(clone); + simplify(clone); + std::string canon_str = + StringMaker::convert(*clone, false); + + if (canon_str != str) { + str += " (canonicalized: " + canon_str + ")"; + } + } + + return str; + } +}; +template <> +struct StringMaker { + static std::string convert(const sequant::ExprPtr &expr) { + return StringMaker::convert(*expr); + } +}; +template <> +struct StringMaker { + static std::string convert(const sequant::Tensor &tensor) { + return StringMaker::convert(tensor); + } +}; + +template <> +struct StringMaker { + static std::string convert(const sequant::Index &idx) { + return sequant::to_string(idx.full_label()); + } +}; + +} // namespace Catch + +namespace { + +/// Converts the given expression-like object into an actual ExprPtr. +/// It accepts either an actual expression object (as Expr & or ExprPtr) or +/// a (w)string-like object which will then be parsed to yield the actual +/// expression object. +template +sequant::ExprPtr to_expression(T &&expression) { + if constexpr (std::is_convertible_v) { + return sequant::parse_expr( + sequant::to_wstring(std::string(std::forward(expression))), + sequant::Symmetry::nonsymm); + } else if constexpr (std::is_convertible_v) { + return sequant::parse_expr(std::wstring(std::forward(expression)), + sequant::Symmetry::nonsymm); + } else if constexpr (std::is_convertible_v) { + // Clone in order to not have to worry about later modification + return expression.clone(); + } else { + static_assert(std::is_convertible_v, + "Invalid type for expression"); + + // Clone in order to not have to worry about later modification + return expression->clone(); + } +} + +template +class ExpressionMatcher : public Catch::Matchers::MatcherGenericBase { + public: + template + ExpressionMatcher(T &&expression) + : m_expr(to_expression(std::forward(expression))) { + assert(m_expr); + Subclass::pre_comparison(m_expr); + } + + bool match(const sequant::ExprPtr &expr) const { return match(*expr); } + + bool match(const sequant::Expr &expr) const { + // Never modify the expression that we are trying to check in order to avoid + // side-effects + sequant::ExprPtr clone = expr.clone(); + + Subclass::pre_comparison(clone); + + return *clone == *m_expr; + } + + std::string describe() const override { + return Subclass::comparison_requirement() + ": " + + Catch::Detail::stringify(m_expr); + } + + protected: + sequant::ExprPtr m_expr; +}; + +/// Matches that the tested expression is equivalent to the given one. Two +/// expressions are considered equivalent if they both have the same canonical +/// form (i.e. they are the same expression after canonicalization) +struct EquivalentToMatcher : ExpressionMatcher { + using ExpressionMatcher::ExpressionMatcher; + + static void pre_comparison(sequant::ExprPtr &expr) { + sequant::canonicalize(expr); + sequant::simplify(expr); + } + + static std::string comparison_requirement() { return "Equivalent to"; } +}; + +/// Matches that the tested expression simplifies (**without** +/// re-canonicalization!) to the same form as the given one +struct SimplifiesToMatcher : ExpressionMatcher { + using ExpressionMatcher::ExpressionMatcher; + + static void pre_comparison(sequant::ExprPtr &expr) { + sequant::rapid_simplify(expr); + } + + static std::string comparison_requirement() { return "Simplifies to"; } +}; + +} // namespace + +template +EquivalentToMatcher EquivalentTo(T &&expression) { + return EquivalentToMatcher(std::forward(expression)); +} + +template +SimplifiesToMatcher SimplifiesTo(T &&expression) { + return SimplifiesToMatcher(std::forward(expression)); +} + +#endif diff --git a/tests/unit/test_asy_cost.cpp b/tests/unit/test_asy_cost.cpp index 338a870d9..167c81f3d 100644 --- a/tests/unit/test_asy_cost.cpp +++ b/tests/unit/test_asy_cost.cpp @@ -3,6 +3,8 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_binary_node.cpp b/tests/unit/test_binary_node.cpp index 4a1b53fa2..27e04e149 100644 --- a/tests/unit/test_binary_node.cpp +++ b/tests/unit/test_binary_node.cpp @@ -2,6 +2,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_bliss.cpp b/tests/unit/test_bliss.cpp index eba806aea..2dadef649 100644 --- a/tests/unit/test_bliss.cpp +++ b/tests/unit/test_bliss.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include diff --git a/tests/unit/test_cache_manager.cpp b/tests/unit/test_cache_manager.cpp index a0595058b..744cc06ea 100644 --- a/tests/unit/test_cache_manager.cpp +++ b/tests/unit/test_cache_manager.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + namespace sequant { struct TestCacheManager {}; diff --git a/tests/unit/test_canonicalize.cpp b/tests/unit/test_canonicalize.cpp index 7e5d868a7..77284c749 100644 --- a/tests/unit/test_canonicalize.cpp +++ b/tests/unit/test_canonicalize.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -8,6 +10,7 @@ #include #include #include +#include #include #include @@ -29,25 +32,63 @@ TEST_CASE("Canonicalizer", "[algorithms]") { auto op = ex(L"g", bra{L"p_1", L"p_2"}, ket{L"p_3", L"p_4"}, Symmetry::nonsymm); canonicalize(op); - REQUIRE(to_latex(op) == L"{g^{{p_3}{p_4}}_{{p_1}{p_2}}}"); + REQUIRE_THAT(op, SimplifiesTo("g{p1,p2;p3,p4}")); } { auto op = ex(L"g", bra{L"p_2", L"p_1"}, ket{L"p_3", L"p_4"}, Symmetry::nonsymm); canonicalize(op); - REQUIRE(to_latex(op) == L"{g^{{p_4}{p_3}}_{{p_1}{p_2}}}"); + REQUIRE_THAT(op, SimplifiesTo("g{p1,p2;p4,p3}")); } { auto op = ex(L"g", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, Symmetry::nonsymm); canonicalize(op); - REQUIRE(to_latex(op) == L"{g^{{p_4}{p_3}}_{{p_1}{p_2}}}"); + REQUIRE_THAT(op, SimplifiesTo("g{p1,p2;p4,p3}")); } { auto op = ex(L"g", bra{L"p_2", L"p_1"}, ket{L"p_4", L"p_3"}, Symmetry::nonsymm); canonicalize(op); - REQUIRE(to_latex(op) == L"{g^{{p_3}{p_4}}_{{p_1}{p_2}}}"); + REQUIRE_THAT(op, SimplifiesTo("g{p1,p2;p3,p4}")); + } + { + auto op = ex(L"g", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, + Symmetry::symm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("g{p1,p2;p3,p4}:S")); + } + { + auto op = ex(L"g", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, + Symmetry::antisymm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("-g{p1,p2;p3,p4}:A")); + } + + // aux indices + { + auto op = ex(L"B", bra{L"p_1"}, ket{L"p_2"}, aux{L"p_3"}, + Symmetry::nonsymm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("B{p1;p2;p3}")); + } + { + auto op = ex(L"B", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, + aux{L"p_5"}, Symmetry::nonsymm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("B{p1,p2;p4,p3;p5}")); + } + { + auto op = ex(L"B", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, + aux{L"p_5"}, Symmetry::symm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("B{p1,p2;p3,p4;p5}:S")); + } + { + auto op = ex(L"B", bra{L"p_1", L"p_2"}, ket{L"p_4", L"p_3"}, + aux{L"p_5"}, Symmetry::antisymm); + canonicalize(op); + REQUIRE_THAT(op, SimplifiesTo("-B{p1,p2;p3,p4;p5}:A")); } } @@ -61,9 +102,9 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"i_1", L"i_2"}, ket{L"a_5", L"a_2"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{{S^{{i_1}{i_2}}_{{a_1}{a_3}}}{f^{{i_3}}_{{" - L"a_2}}}{t^{{a_3}}_{{i_3}}}{t^{{a_1}{a_2}}_{{i_1}{i_2}}}}"); + REQUIRE_THAT( + input, + SimplifiesTo("S{a1,a2;i1,i2} f{a3;i3} t{i3;a2} t{i1,i2;a1,a3}")); } { auto input = @@ -74,9 +115,9 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"i_5", L"i_2"}, ket{L"a_1", L"a_2"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{{S^{{i_1}{i_2}}_{{a_1}{a_3}}}{f^{{i_3}}_{{" - L"a_2}}}{t^{{a_2}}_{{i_2}}}{t^{{a_1}{a_3}}_{{i_1}{i_3}}}}"); + REQUIRE_THAT( + input, + SimplifiesTo("S{a1,a3;i1,i2} f{a2;i3} t{i_2;a_2} t{i1,i3;a1,a3}")); } { // Product containing Variables auto q2 = ex(L"q2"); @@ -91,9 +132,10 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"i_5", L"i_2"}, ket{L"a_1", L"a_2"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{{p}{q1}{{q2}^*}{S^{{i_1}{i_2}}_{{a_1}{a_3}}}{f^{{i_3}}_{{" - L"a_2}}}{t^{{a_2}}_{{i_2}}}{t^{{a_1}{a_3}}_{{i_1}{i_3}}}}"); + REQUIRE_THAT( + input, + SimplifiesTo( + "p q1 q2^* S{a1,a3;i1,i2} f{a2;i3} t {i2;a2} t{i1,i3;a1,a3}")); } { // Product containing adjoint of a Tensor auto f2 = ex(L"f", bra{L"i_5", L"i_2"}, ket{L"a_1", L"a_2"}, @@ -105,9 +147,9 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"f", bra{L"a_5"}, ket{L"i_5"}, Symmetry::nonsymm) * ex(L"t", bra{L"i_1"}, ket{L"a_5"}, Symmetry::nonsymm) * f2; canonicalize(input1); - REQUIRE(to_latex(input1) == - L"{{S^{{i_1}{i_2}}_{{a_1}{a_3}}}{f^{{i_3}}_{{a_2}}}{f⁺^{{i_1}{i_" - L"3}}_{{a_1}{a_3}}}{t^{{a_2}}_{{i_2}}}}"); + REQUIRE_THAT( + input1, + SimplifiesTo("S{a1,a2;i1,i3} f{a3;i2} f⁺{a1,a2;i1,i2} t{i3;a3}")); auto input2 = ex(L"S", bra{L"a_1", L"a_2"}, ket{L"i_1", L"i_2"}, Symmetry::nonsymm) * @@ -115,11 +157,24 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"i_1"}, ket{L"a_5"}, Symmetry::nonsymm) * f2 * ex(L"w") * ex(rational{1, 2}); canonicalize(input2); - REQUIRE(to_latex(input2) == - L"{{{\\frac{1}{2}}}{w}{S^{{i_1}{i_2}}_{{a_1}{a_3}}}{f^{{i_3}}_{{" - L"a_2}}}{f⁺^{{i_1}{i_3}}_{{a_1}{a_3}}}{t^{{a_2}}_{{i_2}}}}"); + REQUIRE_THAT( + input2, + SimplifiesTo( + "1/2 w S{a1,a2;i1,i3} f{a3;i2} f⁺{a1,a2;i1,i2} t{i3;a3}")); } } + { + auto input = ex(rational{1, 2}) * + ex(L"B", bra{L"p_2"}, ket{L"p_4"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"B", bra{L"p_1"}, ket{L"p_3"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"t", bra{L"p_4"}, ket{L"p_2"}, Symmetry::nonsymm) * + ex(L"t", bra{L"p_3"}, ket{L"p_1"}, Symmetry::nonsymm); + canonicalize(input); + REQUIRE_THAT(input, + SimplifiesTo("1/2 t{p1;p3} t{p2;p4} B{p3;p1;p5} B{p4;p2;p5}")); + } SECTION("sum of products") { { @@ -135,11 +190,9 @@ TEST_CASE("Canonicalizer", "[algorithms]") { Symmetry::nonsymm) * ex(L"t", bra{L"p_3"}, ket{L"p_1"}, Symmetry::nonsymm) * ex(L"t", bra{L"p_4"}, ket{L"p_2"}, Symmetry::nonsymm); + simplify(input); canonicalize(input); - REQUIRE(to_latex(input) == - L"{ " - L"\\bigl({{g^{{p_1}{p_4}}_{{p_2}{p_3}}}{t^{{p_2}}_{{p_1}}}{t^{{" - L"p_3}}_{{p_4}}}}\\bigr) }"); + REQUIRE_THAT(input, SimplifiesTo("g{p3,p4;p1,p2} t{p1;p3} t{p2;p4}")); } // CASE 2: Symmetric tensors @@ -156,10 +209,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"p_3"}, ket{L"p_1"}, Symmetry::nonsymm) * ex(L"t", bra{L"p_4"}, ket{L"p_2"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{ " - L"\\bigl({{g^{{p_1}{p_4}}_{{p_2}{p_3}}}{t^{{p_2}}_{{p_1}}}{t^{{p_" - L"3}}_{{p_4}}}}\\bigr) }"); + REQUIRE_THAT(input, SimplifiesTo("g{p2,p3;p1,p4}:S t{p1;p2} t{p4;p3}")); } // Case 3: Anti-symmetric tensors @@ -176,11 +226,7 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"p_3"}, ket{L"p_1"}, Symmetry::nonsymm) * ex(L"t", bra{L"p_4"}, ket{L"p_2"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{ " - L"\\bigl({{\\bar{g}^{{p_1}{p_4}}_{{p_2}{p_3}}}{t^{{p_2}}_{{p_1}}}" - L"{t^{{p_" - L"3}}_{{p_4}}}}\\bigr) }"); + REQUIRE_THAT(input, SimplifiesTo("g{p2,p3;p1,p4}:A t{p1;p2} t{p4;p3}")); } // Case 4: permuted indices @@ -200,10 +246,8 @@ TEST_CASE("Canonicalizer", "[algorithms]") { Symmetry::antisymm); canonicalize(input); REQUIRE(input->size() == 1); - REQUIRE(to_latex(input) == - L"{ " - "\\bigl({{\\bar{g}^{{i_1}{a_3}}_{{i_3}{i_4}}}{t^{{i_3}}_{{a_2}}}{" - "\\bar{t}^{{i_2}{i_4}}_{{a_1}{a_3}}}}\\bigr) }"); + REQUIRE_THAT(input, + SimplifiesTo("g{i3,i4;i1,a3}:A t{a2;i3} t{a1,a3;i2,i4}:A")); } // Case 4: permuted indices from CCSD R2 biorthogonal configuration @@ -224,10 +268,8 @@ TEST_CASE("Canonicalizer", "[algorithms]") { canonicalize(input); REQUIRE(input->size() == 1); - REQUIRE(to_latex(input) == - L"{ " - L"\\bigl({{g^{{a_3}{i_1}}_{{i_3}{i_4}}}{t^{{i_3}}_{{a_2}}}{t^{{i_" - L"4}{i_2}}_{{a_1}{a_3}}}}\\bigr) }"); + REQUIRE_THAT(input, + SimplifiesTo("g{i3,i4;i1,a3} t{a2;i4} t{a1,a3;i3,i2}")); } { // Case 5: CCSDT R3: S3 * F * T3 @@ -247,10 +289,10 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"a_1", L"a_2", L"a_3"}, ket{L"i_2", L"i_4", L"i_3"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{ \\bigl( - " - L"{{{8}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_" - L"4}}}{t^{{i_1}{i_4}{i_2}}_{{a_1}{a_2}{a_3}}}}\\bigr) }"); + REQUIRE_THAT( + input, + SimplifiesTo( + "-8 S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i1,i4,i2}")); } { @@ -270,17 +312,20 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ket{L"i_2", L"i_4", L"i_3"}, Symmetry::nonsymm); canonicalize(term1); canonicalize(term2); - REQUIRE(to_latex(term1) == - L"{{{-4}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_3}{i_4}}}{f^{{i_4}}_{{i_" - L"2}}}{t^{{i_1}{i_2}{i_3}}_{{a_1}{a_2}{a_3}}}}"); - REQUIRE(to_latex(term2) == - L"{{{-4}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_3}{i_4}}}{f^{{i_4}}_{{i_" - L"2}}}{t^{{i_1}{i_2}{i_3}}_{{a_1}{a_2}{a_3}}}}"); + REQUIRE_THAT( + term1, + SimplifiesTo( + "-4 S{i1,i3,i4;a1,a2,a3} f{i2;i4} t{a1,a2,a3;i1,i2,i3}")); + REQUIRE_THAT( + term2, + SimplifiesTo( + "-4 S{i1,i3,i4;a1,a2,a3} f{i2;i4} t{a1,a2,a3;i1,i2,i3}")); auto sum_of_terms = term1 + term2; simplify(sum_of_terms); - REQUIRE(to_latex(sum_of_terms) == - L"{{{-8}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_" - L"4}}}{t^{{i_1}{i_4}{i_2}}_{{a_1}{a_2}{a_3}}}}"); + REQUIRE_THAT( + sum_of_terms, + SimplifiesTo( + "-8 S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i1,i4,i2}")); } { // Terms 2 and 4 from spin-traced result @@ -298,11 +343,37 @@ TEST_CASE("Canonicalizer", "[algorithms]") { ex(L"t", bra{L"a_1", L"a_2", L"a_3"}, ket{L"i_2", L"i_3", L"i_4"}, Symmetry::nonsymm); canonicalize(input); - REQUIRE(to_latex(input) == - L"{ " - L"\\bigl({{{4}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}" - L"}_{{i_4}}}{t^{{i_4}{i_1}{i_2}}_{{a_1}{a_2}{a_3}}}}\\bigr) }"); + REQUIRE_THAT( + input, SimplifiesTo( + "4 S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i4,i1,i2}")); } } + + // Case 6: Case 4 w/ aux indices + { + auto input = + ex(rational{4, 3}) * + ex(L"B", bra{L"i_3"}, ket{L"a_3"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"B", bra{L"i_4"}, ket{L"i_1"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"t", bra{L"a_2"}, ket{L"i_3"}, Symmetry::nonsymm) * + ex(L"t", bra{L"a_1", L"a_3"}, ket{L"i_4", L"i_2"}, + Symmetry::nonsymm) - + ex(rational{1, 3}) * + ex(L"B", bra{L"i_3"}, ket{L"i_1"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"B", bra{L"i_4"}, ket{L"a_3"}, aux{L"p_5"}, + Symmetry::nonsymm) * + ex(L"t", bra{L"a_2"}, ket{L"i_4"}, Symmetry::nonsymm) * + ex(L"t", bra{L"a_1", L"a_3"}, ket{L"i_3", L"i_2"}, + Symmetry::nonsymm); + + canonicalize(input); + simplify(input); + REQUIRE_THAT( + input, + SimplifiesTo("t{a2;i4} t{a1,a3;i3,i2} B{i3;i1;p5} B{i4;a3;p5}")); + } } } diff --git a/tests/unit/test_eval_btas.cpp b/tests/unit/test_eval_btas.cpp index cb365c48c..9282551c2 100644 --- a/tests/unit/test_eval_btas.cpp +++ b/tests/unit/test_eval_btas.cpp @@ -1,6 +1,8 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_eval_expr.cpp b/tests/unit/test_eval_expr.cpp index 785faf6ab..5a8b6b85c 100644 --- a/tests/unit/test_eval_expr.cpp +++ b/tests/unit/test_eval_expr.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -9,6 +11,7 @@ #include #include #include +#include #include #include diff --git a/tests/unit/test_eval_node.cpp b/tests/unit/test_eval_node.cpp index 224d6e4e8..af8cb2898 100644 --- a/tests/unit/test_eval_node.cpp +++ b/tests/unit/test_eval_node.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -9,6 +11,7 @@ #include #include #include +#include #include #include @@ -20,21 +23,8 @@ #include -#include namespace { -// validates if x is constructible from tspec using parse_expr -auto validate_eval_expr = [](const sequant::EvalExpr& x, - std::wstring_view tspec) -> bool { - return x.to_latex() == - sequant::parse_expr(tspec, sequant::Symmetry::antisymm)->to_latex(); -}; - -auto validate_tensor = [](const sequant::Tensor& x, - std::wstring_view tspec) -> bool { - return x.to_latex() == sequant::parse_expr(tspec, x.symmetry())->to_latex(); -}; - auto eval_node(sequant::ExprPtr const& expr) { return sequant::eval_node(expr); } @@ -84,24 +74,24 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto node1 = eval_node(p1); - REQUIRE(validate_tensor(node(node1, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_THAT(node(node1, {}).as_tensor(), EquivalentTo("I{a1,a2;i1,i2}:A")); REQUIRE(node(node1, {R}).as_constant() == Constant{rational{1, 16}}); - REQUIRE( - validate_tensor(node(node1, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_THAT(node(node1, {L}).as_tensor(), + EquivalentTo("I{a1,a2;i1,i2}:A")); - REQUIRE( - validate_tensor(node(node1, {L, L}).as_tensor(), L"I_{a1,a2}^{a3,a4}")); + REQUIRE_THAT(node(node1, {L, L}).as_tensor(), + EquivalentTo("I{a1,a2;a3,a4}:A")); - REQUIRE( - validate_tensor(node(node1, {L, R}).as_tensor(), L"t_{a3,a4}^{i1,i2}")); + REQUIRE_THAT(node(node1, {L, R}).as_tensor(), + EquivalentTo("t{a3,a4;i1,i2}:A")); - REQUIRE(validate_tensor(node(node1, {L, L, L}).as_tensor(), - L"g_{i3,i4}^{a3,a4}")); + REQUIRE_THAT(node(node1, {L, L, L}).as_tensor(), + EquivalentTo("g{i3,i4;a3,a4}:A")); - REQUIRE(validate_tensor(node(node1, {L, L, R}).as_tensor(), - L"t_{a1,a2}^{i3,i4}")); + REQUIRE_THAT(node(node1, {L, L, R}).as_tensor(), + EquivalentTo("t{a1,a2;i3,i4}:A")); // 1/16 * A * (B * C) auto node2p = Product{p1->as().scalar(), {}}; @@ -111,24 +101,24 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto const node2 = eval_node(ex(node2p)); - REQUIRE(validate_tensor(node(node2, {}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_THAT(node(node2, {}).as_tensor(), EquivalentTo("I{a1,a2;i1,i2}:N")); - REQUIRE( - validate_tensor(node(node2, {L}).as_tensor(), L"I_{a1,a2}^{i1,i2}")); + REQUIRE_THAT(node(node2, {L}).as_tensor(), + EquivalentTo("I{a1,a2;i1,i2}:N")); REQUIRE(node(node2, {R}).as_constant() == Constant{rational{1, 16}}); - REQUIRE( - validate_tensor(node(node2, {L, L}).as_tensor(), L"g{i3,i4; a3,a4}")); + REQUIRE_THAT(node(node2, {L, L}).as_tensor(), + EquivalentTo("g{i3,i4; a3,a4}:A")); - REQUIRE(validate_tensor(node(node2, {L, R}).as_tensor(), - L"I{a1,a2,a3,a4;i3,i4,i1,i2}")); + REQUIRE_THAT(node(node2, {L, R}).as_tensor(), + EquivalentTo("I{a1,a2,a3,a4;i3,i4,i1,i2}:A")); - REQUIRE( - validate_tensor(node(node2, {L, R, L}).as_tensor(), L"t{a1,a2;i3,i4}")); + REQUIRE_THAT(node(node2, {L, R, L}).as_tensor(), + EquivalentTo("t{a1,a2;i3,i4}:A")); - REQUIRE( - validate_tensor(node(node2, {L, R, R}).as_tensor(), L"t{a3,a4;i1,i2}")); + REQUIRE_THAT(node(node2, {L, R, R}).as_tensor(), + EquivalentTo("t{a3,a4;i1,i2}:A")); } SECTION("sum") { @@ -140,20 +130,18 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto const node1 = eval_node(sum1); REQUIRE(node1->op_type() == EvalOp::Sum); REQUIRE(node1.left()->op_type() == EvalOp::Sum); - REQUIRE(validate_tensor(node1.left()->as_tensor(), L"I^{i1,i2}_{a1,a2}")); - REQUIRE(validate_tensor(node1.left().left()->as_tensor(), - L"X^{i1,i2}_{a1,a2}")); - REQUIRE(validate_tensor(node1.left().right()->as_tensor(), - L"Y^{i1,i2}_{a1,a2}")); + REQUIRE_THAT(node1.left()->as_tensor(), EquivalentTo("I{a1,a2;i1,i2}:A")); + REQUIRE_THAT(node1.left().left()->as_tensor(), + EquivalentTo("X{a1,a2;i1,i2}:A")); + REQUIRE_THAT(node1.left().right()->as_tensor(), + EquivalentTo("Y{a1,a2;i1,i2}:A")); REQUIRE(node1.right()->op_type() == EvalOp::Prod); - REQUIRE( - (validate_tensor(node1.right()->as_tensor(), L"I_{a2,a1}^{i1,i2}") || - validate_tensor(node1.right()->as_tensor(), L"I_{a1,a2}^{i2,i1}"))); - REQUIRE(validate_tensor(node1.right().left()->as_tensor(), - L"g_{i3,a1}^{i1,i2}")); - REQUIRE( - validate_tensor(node1.right().right()->as_tensor(), L"t_{a2}^{i3}")); + REQUIRE_THAT(node1.right()->as_tensor(), EquivalentTo("I{a2,a1;i1,i2}:N")); + REQUIRE_THAT(node1.right().left()->as_tensor(), + EquivalentTo("g{i3,a1;i1,i2}:A")); + REQUIRE_THAT(node1.right().right()->as_tensor(), + EquivalentTo("t{a2;i3}:A")); } SECTION("variable") { @@ -185,8 +173,8 @@ TEST_CASE("TEST EVAL_NODE", "[EvalNode]") { auto prod2 = parse_expr(L"a * t{i1;a1}"); auto node3 = eval_node(prod2); - REQUIRE(validate_eval_expr(node(node3, {}), L"I{i1;a1}")); - REQUIRE(validate_eval_expr(node(node3, {R}), L"t{i1;a1}")); + REQUIRE_THAT(node(node3, {}).as_tensor(), EquivalentTo("I{i1;a1}")); + REQUIRE_THAT(node(node3, {R}).as_tensor(), EquivalentTo("t{i1;a1}")); REQUIRE(node(node3, {L}).as_variable() == Variable{L"a"}); } diff --git a/tests/unit/test_eval_ta.cpp b/tests/unit/test_eval_ta.cpp index 699296554..9e7368ad1 100644 --- a/tests/unit/test_eval_ta.cpp +++ b/tests/unit/test_eval_ta.cpp @@ -1,6 +1,8 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_export.cpp b/tests/unit/test_export.cpp index 3a66add1b..727796511 100644 --- a/tests/unit/test_export.cpp +++ b/tests/unit/test_export.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include @@ -7,22 +9,6 @@ #include #include -namespace Catch { - -// Note: Again, template specialization doesn't seem to be used from inside -// ::Catch::Details::stringify for some reason -template <> -struct StringMaker { - static std::string convert(const sequant::ExprPtr &expr) { - using convert_type = std::codecvt_utf8; - std::wstring_convert converter; - - return converter.to_bytes(sequant::deparse(expr, false)); - } -}; - -} // namespace Catch - std::vector> twoElectronIntegralSymmetries() { // Symmetries of spin-summed (skeleton) two-electron integrals return { diff --git a/tests/unit/test_expr.cpp b/tests/unit/test_expr.cpp index cee9e4db0..b858ce737 100644 --- a/tests/unit/test_expr.cpp +++ b/tests/unit/test_expr.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -641,6 +643,8 @@ TEST_CASE("Expr", "[elements]") { REQUIRE_NOTHROW(hash_value(ex5_init)); REQUIRE(hash_value(ex5_init) != hash_value(ex(1))); + REQUIRE(hash_value(ex(1)) == hash_value(ex(1))); + auto hasher = [](const std::shared_ptr &) -> unsigned int { return 0; }; diff --git a/tests/unit/test_fusion.cpp b/tests/unit/test_fusion.cpp index 4b582c4c6..f3bca6f90 100644 --- a/tests/unit/test_fusion.cpp +++ b/tests/unit/test_fusion.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_index.cpp b/tests/unit/test_index.cpp index ffdfab480..1d9f5fbec 100644 --- a/tests/unit/test_index.cpp +++ b/tests/unit/test_index.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -132,9 +134,15 @@ TEST_CASE("Index", "[elements][index]") { // compare by qns, then tag, then space, then label, then proto indices Index i1(L"i_1"); Index i2(L"i_2"); + Index i3(L"i_11"); REQUIRE(i1 < i2); REQUIRE(!(i2 < i1)); REQUIRE(!(i1 < i1)); + REQUIRE(i1 < i3); + REQUIRE(!(i3 < i1)); + REQUIRE(i2 < i3); + REQUIRE(!(i3 < i2)); + REQUIRE(!(i3 < i3)); Index a1(L"a_2"); REQUIRE(i1 < a1); REQUIRE(!(a1 < i1)); diff --git a/tests/unit/test_iterator.cpp b/tests/unit/test_iterator.cpp index 3ea1d0a65..a0093d2f6 100644 --- a/tests/unit/test_iterator.cpp +++ b/tests/unit/test_iterator.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_latex.cpp b/tests/unit/test_latex.cpp index a0bd275b5..b9c49ca2b 100644 --- a/tests/unit/test_latex.cpp +++ b/tests/unit/test_latex.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include diff --git a/tests/unit/test_main.cpp b/tests/unit/test_main.cpp index d3f5bb3fb..688d69ef6 100644 --- a/tests/unit/test_main.cpp +++ b/tests/unit/test_main.cpp @@ -6,10 +6,13 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include #include +#include #include #include #include diff --git a/tests/unit/test_math.cpp b/tests/unit/test_math.cpp index 789e4b3ce..1216a52d2 100644 --- a/tests/unit/test_math.cpp +++ b/tests/unit/test_math.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_mbpt.cpp b/tests/unit/test_mbpt.cpp index 48cea7c71..d18dcf995 100644 --- a/tests/unit/test_mbpt.cpp +++ b/tests/unit/test_mbpt.cpp @@ -9,13 +9,17 @@ #include #include #include +#include #include +#include #include #include #include #include #include +#include +#include "catch2_sequant.hpp" #include "test_config.hpp" #include diff --git a/tests/unit/test_mbpt_cc.cpp b/tests/unit/test_mbpt_cc.cpp index b46bccbef..d82a86c71 100644 --- a/tests/unit/test_mbpt_cc.cpp +++ b/tests/unit/test_mbpt_cc.cpp @@ -7,6 +7,7 @@ #include #include +#include "catch2_sequant.hpp" #include "test_config.hpp" TEST_CASE("SR-TCC", "[mbpt/cc]") { diff --git a/tests/unit/test_meta.cpp b/tests/unit/test_meta.cpp index d36892c7f..bfcaf966c 100644 --- a/tests/unit/test_meta.cpp +++ b/tests/unit/test_meta.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include diff --git a/tests/unit/test_op.cpp b/tests/unit/test_op.cpp index 21a696831..4f5e84f41 100644 --- a/tests/unit/test_op.cpp +++ b/tests/unit/test_op.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_optimize.cpp b/tests/unit/test_optimize.cpp index 7a42edcf4..3bae4b5de 100644 --- a/tests/unit/test_optimize.cpp +++ b/tests/unit/test_optimize.cpp @@ -1,5 +1,7 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -27,15 +29,18 @@ sequant::ExprPtr extract(sequant::ExprPtr expr, TEST_CASE("TEST_OPTIMIZE", "[optimize]") { using namespace sequant; - // for optimization tests, set occupied and unoccupied index extents + // for optimization tests, set index space sizes { auto reg = get_default_context().mutable_index_space_registry(); auto occ = reg->retrieve_ptr(L"i"); auto uocc = reg->retrieve_ptr(L"a"); + auto aux = reg->retrieve_ptr(L"x"); assert(occ); assert(uocc); + assert(aux); occ->approximate_size(10); uocc->approximate_size(100); + aux->approximate_size(4); assert(uocc->approximate_size() == 100); } @@ -143,6 +148,37 @@ TEST_CASE("TEST_OPTIMIZE", "[optimize]") { REQUIRE(extract(res6, {3, 0}) == prod6.at(3)); REQUIRE(extract(res6, {3, 1}) == prod6.at(5)); REQUIRE(extract(res6, {4}) == prod6.at(4)); + + // + // single-term optimization including tensors with auxiliary indices + // + auto prod7 = parse_expr( + L"DF{a_1;a_3;x_1} " // T1 + "DF{a_2;i_1;x_1} " // T2 + "t{a_3;i_2}" // T3 + ) + ->as(); + auto res7 = single_term_opt(prod7); + + // this is the one we want to find + // (T1 T3) T2: V^2 O^1 A^1 + V^2 O^2 A^1 best if nvirt > nocc and nvirt > + // nact + REQUIRE(extract(res7, {0, 0}) == prod7.at(0)); + REQUIRE(extract(res7, {0, 1}) == prod7.at(2)); + REQUIRE(extract(res7, {1}) == prod7.at(1)); + + auto prod8 = parse_expr( + L"T1{i_1;i_2;x_1,x_2,x_3,x_4} T2{i_2;i_1;x_5,x_6,x_7,x_8} " + L"T3{i_3;;x_1,x_2,x_3,x_4} T4{i_4;;x_5,x_6,x_7,x_8}") + ->as(); + auto res8 = single_term_opt(prod8); + + // this is the one we want to find + // (T1 T3)(T2 T4) + REQUIRE(extract(res8, {0, 0}) == prod8.at(0)); + REQUIRE(extract(res8, {0, 1}) == prod8.at(2)); + REQUIRE(extract(res8, {1, 0}) == prod8.at(1)); + REQUIRE(extract(res8, {1, 1}) == prod8.at(3)); } SECTION("Ensure single-value sums/products are not discarded") { diff --git a/tests/unit/test_parse.cpp b/tests/unit/test_parse.cpp index 8d34bdd6b..d9a91140f 100644 --- a/tests/unit/test_parse.cpp +++ b/tests/unit/test_parse.cpp @@ -1,6 +1,8 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include @@ -79,9 +81,22 @@ ParseErrorMatcher parseErrorMatches(std::size_t offset, std::size_t length, TEST_CASE("parse_expr", "[parse]") { using namespace sequant; - // use a minimal spinbasis registry + auto ctx_resetter = set_scoped_default_context( - Context(mbpt::make_min_sr_spaces(), Vacuum::SingleProduct)); + Context(mbpt::make_sr_spaces(), Vacuum::SingleProduct)); + + SECTION("Scalar tensor") { + auto expr = parse_expr(L"t{}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().empty()); + REQUIRE(expr->as().ket().empty()); + REQUIRE(expr->as().aux().empty()); + + REQUIRE(expr == parse_expr(L"t{;}")); + REQUIRE(expr == parse_expr(L"t{;;}")); + REQUIRE(expr == parse_expr(L"t^{}_{}")); + REQUIRE(expr == parse_expr(L"t_{}^{}")); + } SECTION("Tensor") { auto expr = parse_expr(L"t{i1;a1}"); REQUIRE(expr->is()); @@ -90,10 +105,12 @@ TEST_CASE("parse_expr", "[parse]") { REQUIRE(expr->as().bra().at(0).label() == L"i_1"); REQUIRE(expr->as().ket().size() == 1); REQUIRE(expr->as().ket().at(0) == L"a_1"); + REQUIRE(expr->as().aux().empty()); REQUIRE(expr == parse_expr(L"t_{i1}^{a1}")); REQUIRE(expr == parse_expr(L"t^{a1}_{i1}")); REQUIRE(expr == parse_expr(L"t{i_1; a_1}")); + REQUIRE(expr == parse_expr(L"t{i_1; a_1;}")); REQUIRE(expr == parse_expr(L"t_{i_1}^{a_1}")); expr = parse_expr(L"t{i1,i2;a1,a2}"); @@ -103,6 +120,7 @@ TEST_CASE("parse_expr", "[parse]") { REQUIRE(expr->as().ket().size() == 2); REQUIRE(expr->as().ket().at(0).label() == L"a_1"); REQUIRE(expr->as().ket().at(1).label() == L"a_2"); + REQUIRE(expr->as().aux().empty()); REQUIRE(expr == parse_expr(L"+t{i1, i2; a1, a2}")); REQUIRE(parse_expr(L"-t{i1;a1}")->is()); @@ -128,15 +146,45 @@ TEST_CASE("parse_expr", "[parse]") { auto expr1 = parse_expr(L"t{a↓1;i↑1}"); REQUIRE(expr1->as().bra().at(0).label() == L"a↓_1"); REQUIRE(expr1->as().ket().at(0).label() == L"i↑_1"); + + // Auxiliary indices + expr = parse_expr(L"t{;;i1}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().empty()); + REQUIRE(expr->as().ket().empty()); + REQUIRE(expr->as().aux().size() == 1); + REQUIRE(expr->as().aux()[0].label() == L"i_1"); + + // All index groups at once + expr = parse_expr(L"t{i1,i2;a1;x1,x2}"); + REQUIRE(expr->is()); + REQUIRE(expr->as().bra().size() == 2); + REQUIRE(expr->as().bra().at(0).label() == L"i_1"); + REQUIRE(expr->as().bra().at(1).label() == L"i_2"); + REQUIRE(expr->as().ket().size() == 1); + REQUIRE(expr->as().ket().at(0).label() == L"a_1"); + REQUIRE(expr->as().aux().size() == 2); + REQUIRE(expr->as().aux().at(0).label() == L"x_1"); + REQUIRE(expr->as().aux().at(1).label() == L"x_2"); } SECTION("Tensor with symmetry annotation") { auto expr1 = parse_expr(L"t{a1;i1}:A"); - auto expr2 = parse_expr(L"t{a1;i1}:S"); - auto expr3 = parse_expr(L"t{a1;i1}:N"); - REQUIRE(expr1->as().symmetry() == sequant::Symmetry::antisymm); - REQUIRE(expr2->as().symmetry() == sequant::Symmetry::symm); - REQUIRE(expr3->as().symmetry() == sequant::Symmetry::nonsymm); + auto expr2 = parse_expr(L"t{a1;i1}:S-C"); + auto expr3 = parse_expr(L"t{a1;i1}:N-S-N"); + + const Tensor& t1 = expr1->as(); + const Tensor& t2 = expr2->as(); + const Tensor& t3 = expr3->as(); + + REQUIRE(t1.symmetry() == Symmetry::antisymm); + + REQUIRE(t2.symmetry() == Symmetry::symm); + REQUIRE(t2.braket_symmetry() == BraKetSymmetry::conjugate); + + REQUIRE(t3.symmetry() == Symmetry::nonsymm); + REQUIRE(t3.braket_symmetry() == BraKetSymmetry::symm); + REQUIRE(t3.particle_symmetry() == ParticleSymmetry::nonsymm); } SECTION("Constant") { @@ -166,9 +214,8 @@ TEST_CASE("parse_expr", "[parse]") { REQUIRE(parse_expr(L"α^*")->is()); REQUIRE(parse_expr(L"β^*")->is()); REQUIRE(parse_expr(L"b^*")->is()); - // Currently the conjugated "property" really just is part of the - // variable's name - REQUIRE(parse_expr(L"b^*")->as().label() == L"b^*"); + REQUIRE(parse_expr(L"b^*")->as().conjugated()); + REQUIRE(parse_expr(L"b^*")->as().label() == L"b"); } SECTION("Product") { @@ -324,13 +371,16 @@ TEST_CASE("deparse", "[parse]") { using namespace sequant; std::vector expressions = { - L"t{a_1,a_2;a_3,a_4}:N", + L"t{a_1,a_2;a_3,a_4}:N-C-S", L"42", L"1/2", - L"-1/4 t{a_1,i_1;a_2,i_2}:S", + L"-1/4 t{a_1,i_1;a_2,i_2}:S-N-S", L"a + b - 4 specialVariable", - L"variable + A{a_1;i_1}:N * B{i_1;a_1}:A", - L"1/2 (a + b) * c"}; + L"variable + A{a_1;i_1}:N-N-S * B{i_1;a_1}:A-C-S", + L"1/2 (a + b) * c", + L"T1{}:N-N-N + T2{;;x_1}:N-N-N * T3{;;x_1}:N-N-N + T4{a_1;;x_2}:S-C-S * " + L"T5{;a_1;x_2}:S-S-S", + L"q1 * q2^* * q3"}; for (const std::wstring& current : expressions) { ExprPtr expression = parse_expr(current); diff --git a/tests/unit/test_runtime.cpp b/tests/unit/test_runtime.cpp index beb05fb13..5b249ff05 100644 --- a/tests/unit/test_runtime.cpp +++ b/tests/unit/test_runtime.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_space.cpp b/tests/unit/test_space.cpp index b6cf69f0e..8c171240e 100644 --- a/tests/unit/test_space.cpp +++ b/tests/unit/test_space.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include diff --git a/tests/unit/test_spin.cpp b/tests/unit/test_spin.cpp index f6074422c..18b7d55d4 100644 --- a/tests/unit/test_spin.cpp +++ b/tests/unit/test_spin.cpp @@ -10,6 +10,7 @@ // the old spin attribute representation is used to avoid changing the tests #include +#include "catch2_sequant.hpp" #include "test_config.hpp" #include @@ -23,9 +24,13 @@ #include #include #include +#include #include #include +#include +#include "test_config.hpp" + #include #include #include @@ -135,19 +140,19 @@ TEST_CASE("Spin", "[spin]") { // proto REQUIRE_NOTHROW(make_spinalpha(p_i)); REQUIRE(make_spinalpha(p_i).label() == L"p↑"); - REQUIRE(make_spinalpha(p_i).full_label() == L"p↑i↑"); + REQUIRE(make_spinalpha(p_i).full_label() == L"p↑"); REQUIRE(make_spinalpha(p_i).to_latex() == L"{p↑^{{i↑}}}"); REQUIRE_NOTHROW(make_spinalpha(p1_i)); REQUIRE(make_spinalpha(p1_i).label() == L"p↑_1"); - REQUIRE(make_spinalpha(p1_i).full_label() == L"p↑_1i↑"); + REQUIRE(make_spinalpha(p1_i).full_label() == L"p↑_1"); REQUIRE(make_spinalpha(p1_i).to_latex() == L"{p↑_1^{{i↑}}}"); REQUIRE_NOTHROW(make_spinalpha(p_i1)); REQUIRE(make_spinalpha(p_i1).label() == L"p↑"); - REQUIRE(make_spinalpha(p_i1).full_label() == L"p↑i↑_1"); + REQUIRE(make_spinalpha(p_i1).full_label() == L"p↑"); REQUIRE(make_spinalpha(p_i1).to_latex() == L"{p↑^{{i↑_1}}}"); REQUIRE_NOTHROW(make_spinalpha(p1_i1)); REQUIRE(make_spinalpha(p1_i1).label() == L"p↑_1"); - REQUIRE(make_spinalpha(p1_i1).full_label() == L"p↑_1i↑_1"); + REQUIRE(make_spinalpha(p1_i1).full_label() == L"p↑_1"); REQUIRE(make_spinalpha(p1_i1).to_latex() == L"{p↑_1^{{i↑_1}}}"); } @@ -265,9 +270,8 @@ TEST_CASE("Spin", "[spin]") { auto result = spintrace(expr); REQUIRE(result->is()); canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - {{g^{{p_4}{p_3}}_{{p_1}{p_2}}}} + " - L"{{g^{{p_3}{p_4}}_{{p_1}{p_2}}}}\\bigr) }"); + REQUIRE(result->size() == 2); + REQUIRE_THAT(result, EquivalentTo("- g{p1,p2;p4,p3} + g{p1,p2;p3,p4}")); } SECTION("Product") { @@ -319,15 +323,13 @@ TEST_CASE("Spin", "[spin]") { expand(result); rapid_simplify(result); canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl({{f^{{a_1}}_{{i_1}}}{t^{{i_1}}_{{a_1}}}} + " - L"{{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_1}{i_2}}_{{a_1}{a_2}}}} - " - L"{{{\\frac{1}{2}}}{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_2}{i_1}}_{{" - L"a_1}{a_2}}}} - " - L"{{{\\frac{1}{2}}}{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_2}}_{{a_1}}}" - L"{t^{{i_1}}_{{a_2}}}} + " - L"{{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_1}}_{{a_1}}}{t^{{i_2}}_{{a_" - L"2}}}}\\bigr) }"); + REQUIRE(result->is()); + REQUIRE(result->size() == 5); + REQUIRE_THAT( + result, + EquivalentTo("-1/2 g{i1,i2;a1,a2} t{a1,a2;i2,i1} + g{i1,i2;a1,a2} " + "t{a1,a2;i1,i2} + f{i1;a1} t{a1;i1} - 1/2 g{i1,i2;a1,a2} " + "t{a1;i2} t{a2;i1} + g{i1,i2;a1,a2} t{a1;i1} t{a2;i2}")); } // Sum SECTION("Expand Antisymmetrizer"){// 0-body @@ -547,21 +549,15 @@ SECTION("Expand Symmetrizer") { REQUIRE(result->size() == 6); result->canonicalize(); rapid_simplify(result); - REQUIRE( - to_latex(result) == - L"{ " - L"\\bigl({{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_4}}_{{a_2}}}{t^{{" - L"i_3}}_{{a_4}}}{t^{{i_1}}_{{a_5}}}{t^{{i_5}{i_2}}_{{a_1}{a_3}}}} + " - L"{{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_4}}_{{a_1}}}{t^{{i_2}}_{{" - L"a_4}}}{t^{{i_3}}_{{a_5}}}{t^{{i_1}{i_5}}_{{a_2}{a_3}}}} + " - L"{{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_4}}_{{a_1}}}{t^{{i_3}}_{{" - L"a_4}}}{t^{{i_2}}_{{a_5}}}{t^{{i_5}{i_1}}_{{a_2}{a_3}}}} + " - L"{{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_5}}_{{a_3}}}{t^{{i_2}}_{{" - L"a_4}}}{t^{{i_1}}_{{a_5}}}{t^{{i_3}{i_4}}_{{a_1}{a_2}}}} + " - L"{{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_4}}_{{a_3}}}{t^{{i_2}}_{{" - L"a_4}}}{t^{{i_1}}_{{a_5}}}{t^{{i_5}{i_3}}_{{a_1}{a_2}}}} + " - L"{{{4}}{g^{{a_4}{a_5}}_{{i_4}{i_5}}}{t^{{i_5}}_{{a_2}}}{t^{{i_3}}_{{" - L"a_4}}}{t^{{i_1}}_{{a_5}}}{t^{{i_2}{i_4}}_{{a_1}{a_3}}}}\\bigr) }"); + REQUIRE_THAT( + result, + EquivalentTo( + "4 g{i4,i5;a4,a5} t{a2;i4} t{a4;i3} t{a5;i1} t{a1,a3;i5,i2} + " + "4 g{i4,i5;a4,a5} t{a1;i5} t{a4;i2} t{a5;i3} t{a2,a3;i4,i1} + " + "4 g{i4,i5;a4,a5} t{a3;i4} t{a4;i1} t{a5;i2} t{a1,a2;i3,i5} + " + "4 g{i4,i5;a4,a5} t{a1;i5} t{a4;i3} t{a5;i2} t{a2,a3;i1,i4} + " + "4 g{i4,i5;a4,a5} t{a2;i4} t{a4;i1} t{a5;i3} t{a1,a3;i2,i5} + " + "4 g{i4,i5;a4,a5} t{a3;i4} t{a4;i2} t{a5;i1} t{a1,a2;i5,i3}")); } } @@ -580,9 +576,8 @@ SECTION("Symmetrize expression") { ex(L"t", bra{L"a_3"}, ket{L"i_1"}); auto result = factorize_S(input, {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}}, true); - REQUIRE(to_latex(result) == - L"{{S^{{a_1}{a_2}}_{{i_1}{i_2}}}{g^{{i_2}{a_3}}_{{a_1}{a_2}}}{t^{" - L"{i_1}}_{{a_3}}}}"); + REQUIRE_THAT(result, + EquivalentTo("S{i1,i2;a1,a2} g{a1,a2;i2,a3}:S t{a3;i1}")); } { @@ -598,9 +593,10 @@ SECTION("Symmetrize expression") { ex(L"t", bra{L"a_1"}, ket{L"i_4"}) * ex(L"t", bra{L"a_3"}, ket{L"i_1"}); auto result = factorize_S(input, {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}}); - REQUIRE(to_latex(result) == - L"{{S^{{a_2}{a_1}}_{{i_3}{i_4}}}{g^{{i_4}{a_3}}_{{i_1}{i_2}}}{t^{" - L"{i_1}}_{{a_1}}}{t^{{i_2}}_{{a_2}}}{t^{{i_3}}_{{a_3}}}}"); + REQUIRE_THAT( + result, + EquivalentTo( + "S{i3,i4;a2,a1} g{i1,i2;i4,a3}:S t{a1;i1} t{a2;i2} t{a3;i3}")); } { @@ -620,10 +616,9 @@ SECTION("Symmetrize expression") { ex(L"t", bra{L"a_2", L"a_4"}, ket{L"i_2", L"i_1"}); auto result = factorize_S(input, {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}}, true); - REQUIRE( - to_latex(result) == - L"{{{2}}{S^{{a_1}{a_2}}_{{i_1}{i_2}}}{g^{{a_3}{a_4}}_{{i_3}{i_4}}}{t^" - L"{{i_3}}_{{a_4}}}{t^{{i_4}}_{{a_2}}}{t^{{i_1}{i_2}}_{{a_1}{a_3}}}}"); + REQUIRE(result->is() == false); + REQUIRE_THAT(result, EquivalentTo("2 S{i1,i2;a1,a2} g{i3,i4;a3,a4}:S " + "t{a4;i3} t{a2;i4} t{a1,a3;i1,i2}")); } } @@ -639,16 +634,18 @@ SECTION("Transform expression") { expand(result); rapid_simplify(result); canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - {{g^{{a_2}{i_1}}_{{a_1}{i_2}}}{t^{{i_2}}_{{a_2}}}} + " - L"{{{2}}{g^{{i_1}{a_2}}_{{a_1}{i_2}}}{t^{{i_2}}_{{a_2}}}}\\bigr) }"); + REQUIRE_THAT( + result, + EquivalentTo("- g{a1,i2;a2,i1} t{a2;i2} + 2 g{a1,i2;i1,a2} t{a2;i2}")); container::map idxmap = {{Index{L"i_1"}, Index{L"i_2"}}, {Index{L"i_2"}, Index{L"i_1"}}}; auto transformed_result = transform_expr(result, idxmap); - REQUIRE(to_latex(transformed_result) == - L"{ \\bigl( - {{g^{{a_2}{i_2}}_{{a_1}{i_1}}}{t^{{i_1}}_{{a_2}}}} + " - L"{{{2}}{g^{{i_2}{a_2}}_{{a_1}{i_1}}}{t^{{i_1}}_{{a_2}}}}\\bigr) }"); + REQUIRE(transformed_result->is()); + REQUIRE(transformed_result->size() == 2); + REQUIRE_THAT( + transformed_result, + EquivalentTo("- g{a1,i1;a2,i2} t{a2;i1} + 2 g{a1,i1;i2,a2} t{a2;i1}")); } SECTION("Swap bra kets") { @@ -697,9 +694,9 @@ SECTION("Closed-shell spintrace CCD") { const auto input = ex(ExprPtrList{parse_expr( L"1/4 g{i_1,i_2;a_1,a_2} t{a_1,a_2;i_1,i_2}", Symmetry::antisymm)}); auto result = closed_shell_CC_spintrace(input); - REQUIRE(result == parse_expr(L"2 g{i_1,i_2;a_1,a_2} t{a_1,a_2;i_1,i_2} - " - L"g{i_1,i_2;a_1,a_2} t{a_1,a_2;i_2,i_1}", - Symmetry::nonsymm)); + REQUIRE_THAT(result, + EquivalentTo(L"- g{i_1,i_2;a_1,a_2} t{a_1,a_2;i_2,i_1} + " + L"2 g{i_1,i_2;a_1,a_2} t{a_1,a_2;i_1,i_2}")); } { // CSV (aka PNO) Index i1(L"i_1", {L"i", 0b01}); @@ -779,10 +776,7 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - {{f^{{i_1}}_{{i_2}}}{t^{{i_2}}_{{a_1}}}}\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("- f{i2;i1} t{a1;i2}")); } { @@ -793,10 +787,7 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl({{f^{{a_2}}_{{a_1}}}{t^{{i_1}}_{{a_2}}}}\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("f{a1;a2} t{a2;i1}")); } { @@ -810,14 +801,10 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{g^{{a_2}{i_1}}_{{i_2}{i_3}}}{t^{{i_3}{i_2}}_{{a_1}{a_2}}" - L"}} + " - L"{{g^{{a_2}{i_1}}_{{i_2}{i_3}}}{t^{{i_2}{i_3}}_{{a_1}{a_2}}}}" - L"\\bigr) }"); + REQUIRE_THAT( + result, + EquivalentTo( + "g{i2,i3;i1,a2} t{a1,a2;i3,i2} - 2 g{i2,i3;a2,i1} t{a1,a2;i3,i2}")); } { @@ -831,14 +818,9 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{g^{{a_3}{a_2}}_{{a_1}{i_2}}}{t^{{i_1}{i_2}}_{{a_2}{a_3}}}} + " - L"{{{2}}{g^{{a_3}{a_2}}_{{a_1}{i_2}}}{t^{{i_2}{i_1}}_{{a_2}{a_3}}" - L"}}\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("- g{a1,i2;a3,a2} t{a2,a3;i1,i2} + 2 " + "g{a1,i2;a3,a2} t{a2,a3;i2,i1}")); } { @@ -851,12 +833,9 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE( - to_latex(result) == - L"{ \\bigl( - {{f^{{a_2}}_{{i_2}}}{t^{{i_2}{i_1}}_{{a_1}{a_2}}}} + " - L"{{{2}}{f^{{a_2}}_{{i_2}}}{t^{{i_1}{i_2}}_{{a_1}{a_2}}}}\\bigr) }"); + REQUIRE_THAT( + result, + EquivalentTo("-f{i2;a2} t{a1,a2;i2,i1} + 2 f{i2;a2} t{a1,a2;i1,i2}")); } { @@ -869,14 +848,8 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ " - L"\\bigl({{{2}}{g^{{a_3}{a_2}}_{{a_1}{i_2}}}{t^{{i_1}}_{{a_3}}}{t^{" - L"{i_2}}_{{a_2}}}} - " - L"{{g^{{a_3}{a_2}}_{{a_1}{i_2}}}{t^{{i_1}}_{{a_2}}}{t^{{i_2}}_{{a_" - L"3}}}}\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("2 g{a1,i2;a3,a2} t{a2;i2} t{a3;i1} - " + "g{a1,i2;a3,a2} t{a2;i1} t{a3;i2}")); } { @@ -889,14 +862,8 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ " - L"\\bigl({{g^{{a_2}{i_1}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_1}}}{t^{{i_" - L"3}}_{{a_2}}}} - " - L"{{{2}}{g^{{a_2}{i_1}}_{{i_2}{i_3}}}{t^{{i_3}}_{{a_1}}}{t^{{i_2}" - L"}_{{a_2}}}}\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("g{i2,i3;i1,a2} t{a1;i3} t{a2;i2} - 2 " + "g{i2,i3;a2,i1} t{a1;i3} t{a2;i2}")); } { @@ -909,12 +876,7 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{f^{{a_2}}_{{i_2}}}{t^{{i_2}}_{{a_1}}}{t^{{i_1}}_{{a_2}}}}" - L"\\bigr) }"); + REQUIRE_THAT(result, EquivalentTo("-f{i2;a2} t{a1;i2} t{a2;i1}")); } { @@ -929,14 +891,9 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_3}}_{{a_1}}}{t^{{" - L"i_2}{i_1}}_{{a_2}{a_3}}}} + " - L"{{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_1}}}{t^{{i_3}{" - L"i_1}}_{{a_2}{a_3}}}}\\bigr) }"); + REQUIRE_THAT(result, + EquivalentTo("g{i2,i3;a2,a3} t{a1;i3} t{a2,a3;i1,i2} - 2 " + "g{i2,i3;a2,a3} t{a1;i3} t{a2,a3;i2,i1}")); } { @@ -951,14 +908,9 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_1}}_{{a_3}}}{t^{{i_3}" - L"{i_2}}_{{a_1}{a_2}}}} + " - L"{{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_1}}_{{a_3}}}{t^{{i_2}{i_3}" - L"}_{{a_1}{a_2}}}}\\bigr) }"); + REQUIRE_THAT(result, + EquivalentTo("g{i2,i3;a2,a3} t{a2;i1} t{a1,a3;i3,i2} - 2 " + "g{i2,i3;a2,a3} t{a3;i1} t{a1,a2;i3,i2}")); } { @@ -973,18 +925,11 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_3}}}{t^{{i_1}" - L"{i_3}}_{{a_1}{a_2}}}} + " - L"{{{4}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_2}}}{t^{{i_1}" - L"{i_3}}_{{a_1}{a_3}}}} - " - L"{{{2}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_2}}}{t^{{i_3}" - L"{i_1}}_{{a_1}{a_3}}}} + " - L"{{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_3}}_{{a_2}}}{t^{{i_2}{i_1}" - L"}_{{a_1}{a_3}}}}\\bigr) }"); + REQUIRE_THAT(result, + EquivalentTo("-2 g{i2,i3;a2,a3} t{a3;i2} t{a1,a2;i1,i3} + 4 " + "g{i2,i3;a2,a3} t{a2;i2} t{a1,a3;i1,i3} + " + "g{i2,i3;a2,a3} t{a3;i2} t{a1,a2;i3,i1} - 2 " + "g{i2,i3;a2,a3} t{a2;i2} t{a1,a3;i3,i1}")); } { @@ -998,14 +943,9 @@ SECTION("Closed-shell spintrace CCSD") { auto result = ex(rational{1, 2}) * spintrace(input, {{L"i_1", L"a_1"}}); expand(result); - rapid_simplify(result); - canonicalize(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_3}}_{{a_1}}}{t^{{i_2}" - L"}_{{a_2}}}{t^{{i_1}}_{{a_3}}}} + " - L"{{g^{{a_2}{a_3}}_{{i_2}{i_3}}}{t^{{i_2}}_{{a_1}}}{t^{{i_3}}_{{" - L"a_2}}}{t^{{i_1}}_{{a_3}}}}\\bigr) }"); + REQUIRE_THAT(result, + EquivalentTo("-2 g{i2,i3;a2,a3} t{a1;i3} t{a2;i2} t{a3;i1} + " + "g{i2,i3;a2,a3} t{a1;i3} t{a2;i1} t{a3;i2}")); } } // CCSD R1 @@ -1024,16 +964,13 @@ SECTION("Closed-shell spintrace CCSDT terms") { result = closed_shell_spintrace( input, {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}, {L"i_3", L"a_3"}}); simplify(result); - REQUIRE(to_latex(result) == - L"{ \\bigl( - " - L"{{{2}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_4}}" - L"}{t^{{i_2}{i_1}{i_4}}_{{a_1}{a_2}{a_3}}}} + " - L"{{{4}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_4}}" - L"}{t^{{i_1}{i_2}{i_4}}_{{a_1}{a_2}{a_3}}}} + " - L"{{{2}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_4}}" - L"}{t^{{i_4}{i_1}{i_2}}_{{a_1}{a_2}{a_3}}}} - " - L"{{{4}}{S^{{a_1}{a_2}{a_3}}_{{i_1}{i_2}{i_3}}}{f^{{i_3}}_{{i_4}}" - L"}{t^{{i_1}{i_4}{i_2}}_{{a_1}{a_2}{a_3}}}}\\bigr) }"); + REQUIRE(result->size() == 4); + REQUIRE_THAT( + result, + EquivalentTo("2 S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i4,i1,i2} - 4 " + "S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i1,i4,i2} + 4 " + "S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i1,i2,i4} - 2 " + "S{i1,i2,i3;a1,a2,a3} f{i4;i3} t{a1,a2,a3;i2,i1,i4}")); } { // f * t3 @@ -1356,9 +1293,11 @@ SECTION("Open-shell spin-tracing") { auto result = open_shell_spintrace(input, {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}}); REQUIRE(result.size() == 3); - REQUIRE(to_latex(result[0]) == + REQUIRE( + toUtf8(to_latex(result[0])) == + toUtf8( L"{{{-\\frac{1}{2}}}{\\bar{g}^{{i↑_1}{i↑_2}}_{{a↑_1}{i↑_3}}}{t^{{" - L"i↑_3}}_{{a↑_2}}}}"); + L"i↑_3}}_{{a↑_2}}}}")); REQUIRE(to_latex(result[1]) == L"{{{-\\frac{1}{2}}}{g^{{i↑_1}{i↓_2}}_{{a↑_1}{i↓_1}}}{t^{{i↓_1}}_" L"{{a↓_2}}}}"); @@ -1379,18 +1318,12 @@ SECTION("Open-shell spin-tracing") { REQUIRE(to_latex(result[0]) == L"{{{\\frac{1}{12}}}{f^{{a↑_4}}_{{a↑_1}}}{\\bar{t}^{{i↑_1}{i↑_2}{" L"i↑_3}}_{{a↑_2}{a↑_3}{a↑_4}}}}"); - REQUIRE(to_latex(result[1]) == - L"{ \\bigl( - " - L"{{{\\frac{1}{12}}}{f^{{a↑_3}}_{{a↑_1}}}{t^{{i↑_1}{i↑_2}{i↓_3}}_" - L"{{a↑_2}{a↑_3}{a↓_3}}}} + " - L"{{{\\frac{1}{12}}}{f^{{a↑_3}}_{{a↑_1}}}{t^{{i↑_2}{i↑_1}{i↓_3}}_" - L"{{a↑_2}{a↑_3}{a↓_3}}}}\\bigr) }"); - REQUIRE(to_latex(result[2]) == - L"{ \\bigl( - " - L"{{{\\frac{1}{12}}}{f^{{a↑_2}}_{{a↑_1}}}{t^{{i↑_1}{i↓_3}{i↓_2}}_" - L"{{a↑_2}{a↓_2}{a↓_3}}}} + " - L"{{{\\frac{1}{12}}}{f^{{a↑_2}}_{{a↑_1}}}{t^{{i↑_1}{i↓_2}{i↓_3}}_{{" - L"a↑_2}{a↓_2}{a↓_3}}}}\\bigr) }"); + REQUIRE_THAT(result[1], + EquivalentTo("-1/12 f{a↑1;a↑3} t{a↑2,a↑3,a↓3;i↑_1,i↑2,i↓3} + " + "1/12 f{a↑1;a↑3} t{a↑2,a↑3,a↓3;i↑2,i↑1,i↓3}")); + REQUIRE_THAT(result[2], + EquivalentTo("-1/12 f{a↑1;a↑2} t{a↑2,a↓2,a↓3;i↑1,i↓3,i↓2} + " + "1/12 f{a↑1;a↑2} t{a↑2,a↓2,a↓3;i↑1,i↓2,i↓3}")); REQUIRE(to_latex(result[3]) == L"{{{\\frac{1}{12}}}{f^{{a↓_4}}_{{a↓_1}}}{\\bar{t}^{{i↓_1}{i↓_2}{" L"i↓_3}}_{{a↓_2}{a↓_3}{a↓_4}}}}"); @@ -1411,11 +1344,9 @@ SECTION("Open-shell spin-tracing") { ex(g) * ex(t3); auto result = expand_A_op(input); result->visit(reset_idx_tags); - canonicalize(result); - rapid_simplify(result); - REQUIRE(to_latex(result) == - L"{{{\\frac{1}{3}}}{\\bar{g}^{{i↑_1}{i↑_2}}_{{i↑_3}{i↑_4}}}{t^{{" - L"i↑_3}{i↑_4}{i↓_3}}_{{a↑_1}{a↑_2}{a↓_3}}}}"); + REQUIRE_THAT( + result, + EquivalentTo("1/3 g{i↑3,i↑4;i↑1,i↑2}:A t{a↑1,a↑2,a↓3;i↑3,i↑4,i↓3}:N")); g = Tensor(L"g", bra{i4A, i5A}, ket{i1A, i2A}, Symmetry::antisymm); t3 = @@ -1425,11 +1356,9 @@ SECTION("Open-shell spin-tracing") { ex(t3); result = expand_A_op(input); result->visit(reset_idx_tags); - canonicalize(result); - rapid_simplify(result); - REQUIRE(to_latex(result) == - L"{{{\\frac{1}{3}}}{\\bar{g}^{{i↑_1}{i↑_2}}_{{i↑_3}{i↑_4}}}{t^{{" - L"i↑_3}{i↑_4}{i↓_3}}_{{a↑_1}{a↑_2}{a↓_3}}}}"); + REQUIRE_THAT( + result, + EquivalentTo("1/3 g{i↑3,i↑4;i↑1,i↑2}:A t{a↑_1,a↑2,a↓3;i↑3,i↑4,i↓3}:N")); } // CCSDT R3 10 aaa, bbb diff --git a/tests/unit/test_string.cpp b/tests/unit/test_string.cpp index c7f586152..841699519 100644 --- a/tests/unit/test_string.cpp +++ b/tests/unit/test_string.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include diff --git a/tests/unit/test_tensor.cpp b/tests/unit/test_tensor.cpp index 9b48de88b..7fd4a8fc0 100644 --- a/tests/unit/test_tensor.cpp +++ b/tests/unit/test_tensor.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -32,6 +34,7 @@ TEST_CASE("Tensor", "[elements]") { REQUIRE(!t1); REQUIRE(t1.bra_rank() == 0); REQUIRE(t1.ket_rank() == 0); + REQUIRE(t1.aux_rank() == 0); REQUIRE(t1.rank() == 0); REQUIRE(t1.symmetry() == Symmetry::invalid); REQUIRE(t1.braket_symmetry() == BraKetSymmetry::invalid); @@ -43,33 +46,41 @@ TEST_CASE("Tensor", "[elements]") { REQUIRE(t2); REQUIRE(t2.bra_rank() == 1); REQUIRE(t2.ket_rank() == 1); + REQUIRE(t2.aux_rank() == 0); REQUIRE(t2.rank() == 1); + REQUIRE(t2.const_indices().size() == 2); REQUIRE(t2.symmetry() == Symmetry::nonsymm); REQUIRE(t2.braket_symmetry() == BraKetSymmetry::conjugate); REQUIRE(t2.particle_symmetry() == ParticleSymmetry::symm); REQUIRE(t2.label() == L"F"); - REQUIRE_NOTHROW(Tensor(L"N", bra{L"i_1"}, ket{})); - auto t3 = Tensor(L"N", bra{L"i_1"}, ket{}); + REQUIRE_NOTHROW(Tensor(L"N", bra{L"i_1"}, ket{}, aux{L"a_1"})); + auto t3 = Tensor(L"N", bra{L"i_1"}, ket{}, aux{L"a_1"}); REQUIRE(t3); REQUIRE(t3.bra_rank() == 1); REQUIRE(t3.ket_rank() == 0); + REQUIRE(t3.aux_rank() == 1); REQUIRE_THROWS(t3.rank()); + REQUIRE(t3.const_indices().size() == 2); REQUIRE(t3.symmetry() == Symmetry::nonsymm); REQUIRE(t3.braket_symmetry() == BraKetSymmetry::conjugate); REQUIRE(t3.particle_symmetry() == ParticleSymmetry::symm); REQUIRE(t3.label() == L"N"); REQUIRE_NOTHROW(Tensor(L"g", bra{Index{L"i_1"}, Index{L"i_2"}}, - ket{Index{L"i_3"}, Index{L"i_4"}}, Symmetry::nonsymm, + ket{Index{L"i_3"}, Index{L"i_4"}}, + aux{Index{L"i_5"}}, Symmetry::nonsymm, BraKetSymmetry::symm, ParticleSymmetry::nonsymm)); auto t4 = Tensor(L"g", bra{Index{L"i_1"}, Index{L"i_2"}}, - ket{Index{L"i_3"}, Index{L"i_4"}}, Symmetry::nonsymm, - BraKetSymmetry::symm, ParticleSymmetry::nonsymm); + ket{Index{L"i_3"}, Index{L"i_4"}}, aux{Index{L"i_5"}}, + Symmetry::nonsymm, BraKetSymmetry::symm, + ParticleSymmetry::nonsymm); REQUIRE(t4); REQUIRE(t4.bra_rank() == 2); REQUIRE(t4.ket_rank() == 2); + REQUIRE(t4.aux_rank() == 1); REQUIRE(t4.rank() == 2); + REQUIRE(t4.const_indices().size() == 5); REQUIRE(t4.symmetry() == Symmetry::nonsymm); REQUIRE(t4.braket_symmetry() == BraKetSymmetry::symm); REQUIRE(t4.particle_symmetry() == ParticleSymmetry::nonsymm); @@ -78,7 +89,8 @@ TEST_CASE("Tensor", "[elements]") { SECTION("index transformation") { auto t = Tensor(L"g", bra{Index{L"i_1"}, Index{L"i_2"}}, - ket{Index{L"i_3"}, Index{L"i_4"}}, Symmetry::antisymm); + ket{Index{L"i_3"}, Index{L"i_4"}}, aux{Index{L"i_5"}}, + Symmetry::antisymm); std::map idxmap = {{Index{L"i_1"}, Index{L"i_2"}}, {Index{L"i_2"}, Index{L"i_1"}}}; REQUIRE(t.transform_indices(idxmap)); @@ -91,13 +103,15 @@ TEST_CASE("Tensor", "[elements]") { REQUIRE(!t.ket()[0].tag().has_value()); REQUIRE(!t.ket()[1].tag().has_value()); REQUIRE(t == Tensor(L"g", bra{Index{L"i_2"}, Index{L"i_1"}}, - ket{Index{L"i_3"}, Index{L"i_4"}}, Symmetry::antisymm)); + ket{Index{L"i_3"}, Index{L"i_4"}}, aux{Index{L"i_5"}}, + Symmetry::antisymm)); // tagged indices are protected, so no replacements the second goaround REQUIRE(!t.transform_indices(idxmap)); t.reset_tags(); REQUIRE(t.transform_indices(idxmap)); REQUIRE(t == Tensor(L"g", bra{Index{L"i_1"}, Index{L"i_2"}}, - ket{Index{L"i_3"}, Index{L"i_4"}}, Symmetry::antisymm)); + ket{Index{L"i_3"}, Index{L"i_4"}}, aux{Index{L"i_5"}}, + Symmetry::antisymm)); t.reset_tags(); REQUIRE(!t.bra()[0].tag().has_value()); REQUIRE(!t.bra()[1].tag().has_value()); @@ -115,12 +129,24 @@ TEST_CASE("Tensor", "[elements]") { REQUIRE_NOTHROW(t2_hash = hash_value(t2)); REQUIRE(t1_hash != t2_hash); + auto t3 = Tensor(L"F", bra{L"i_2"}, ket{L"i_1"}, aux{L"i_3"}); + size_t t3_hash; + REQUIRE_NOTHROW(t3_hash = hash_value(t3)); + REQUIRE(t2_hash != t3_hash); + REQUIRE(t1_hash != t3_hash); + } // SECTION("hash") SECTION("latex") { auto t1 = Tensor(L"F", bra{L"i_1"}, ket{L"i_2"}); REQUIRE(to_latex(t1) == L"{F^{{i_2}}_{{i_1}}}"); + auto t2 = Tensor(L"F", bra{L"i_1"}, ket{L"i_2"}, aux{L"i_3"}); + REQUIRE(to_latex(t2) == L"{F^{{i_2}}_{{i_1}}[{i_3}]}"); + + auto t3 = Tensor(L"F", bra{L"i_1"}, ket{L"i_2"}, aux{L"i_3", L"i_4"}); + REQUIRE(to_latex(t3) == L"{F^{{i_2}}_{{i_1}}[{i_3},{i_4}]}"); + auto h1 = ex(L"F", bra{L"i_1"}, ket{L"i_2"}) * ex(cre({L"i_1"}), ann({L"i_2"})); REQUIRE(to_latex(h1) == @@ -141,7 +167,7 @@ TEST_CASE("Tensor", "[elements]") { REQUIRE(to_latex(t1) == L"{t^{{i_1}}_{{a_1}}}"); auto h1 = ex(L"F", bra{L"i_1"}, ket{L"i_2"}) * - ex(cre({L"i_1"}), ann({L"i_2"})); + ex(cre{L"i_1"}, ann{L"i_2"}); h1 = adjoint(h1); REQUIRE(to_latex(h1) == L"{{\\tilde{a}^{{i_2}}_{{i_1}}}{F^{{i_1}}_{{i_2}}}}"); diff --git a/tests/unit/test_tensor_network.cpp b/tests/unit/test_tensor_network.cpp index 173cb0c09..3e5ab6960 100644 --- a/tests/unit/test_tensor_network.cpp +++ b/tests/unit/test_tensor_network.cpp @@ -4,6 +4,8 @@ #include +#include "catch2_sequant.hpp" + #include #include #include @@ -13,17 +15,23 @@ #include #include #include +#include #include +#include #include +#include #include +#include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -126,9 +134,9 @@ TEST_CASE("TensorNetwork", "[elements]") { // to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) << // std::endl; REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) == - L"{F^{{i_1}}_{{i_2}}}"); + L"{F^{{i_2}}_{{i_1}}}"); REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) == - L"{\\tilde{a}^{{i_2}}_{{i_1}}}"); + L"{\\tilde{a}^{{i_1}}_{{i_2}}}"); REQUIRE(tn.idxrepl().size() == 2); } @@ -208,139 +216,202 @@ TEST_CASE("TensorNetwork", "[elements]") { std::basic_ostringstream oss; REQUIRE_NOTHROW(graph->write_dot(oss, vlabels)); // std::wcout << "oss.str() = " << std::endl << oss.str() << std::endl; - REQUIRE(oss.str() == - L"graph g {\n" - "v0 [label=\"{a_1}\"; color=\"#64f,acf\"];\n" - "v0 -- v29\n" - "v0 -- v58\n" - "v1 [label=\"{a_2}\"; color=\"#64f,acf\"];\n" - "v1 -- v29\n" - "v1 -- v58\n" - "v2 [label=\"{a_3}\"; color=\"#64f,acf\"];\n" - "v2 -- v33\n" - "v2 -- v54\n" - "v3 [label=\"{a_4}\"; color=\"#64f,acf\"];\n" - "v3 -- v33\n" - "v3 -- v54\n" - "v4 [label=\"{a_5}\"; color=\"#64f,acf\"];\n" - "v4 -- v37\n" - "v4 -- v50\n" - "v5 [label=\"{a_6}\"; color=\"#64f,acf\"];\n" - "v5 -- v37\n" - "v5 -- v50\n" - "v6 [label=\"{a_7}\"; color=\"#64f,acf\"];\n" - "v6 -- v22\n" - "v6 -- v41\n" - "v7 [label=\"{a_8}\"; color=\"#64f,acf\"];\n" - "v7 -- v22\n" - "v7 -- v41\n" - "v8 [label=\"{i_1}\"; color=\"#9c2,a20\"];\n" - "v8 -- v30\n" - "v8 -- v57\n" - "v9 [label=\"{i_2}\"; color=\"#9c2,a20\"];\n" - "v9 -- v30\n" - "v9 -- v57\n" - "v10 [label=\"{i_3}\"; color=\"#9c2,a20\"];\n" - "v10 -- v34\n" - "v10 -- v53\n" - "v11 [label=\"{i_4}\"; color=\"#9c2,a20\"];\n" - "v11 -- v34\n" - "v11 -- v53\n" - "v12 [label=\"{i_5}\"; color=\"#9c2,a20\"];\n" - "v12 -- v38\n" - "v12 -- v49\n" - "v13 [label=\"{i_6}\"; color=\"#9c2,a20\"];\n" - "v13 -- v38\n" - "v13 -- v49\n" - "v14 [label=\"{i_7}\"; color=\"#9c2,a20\"];\n" - "v14 -- v21\n" - "v14 -- v42\n" - "v15 [label=\"{i_8}\"; color=\"#9c2,a20\"];\n" - "v15 -- v21\n" - "v15 -- v42\n" - "v16 [label=\"{\\kappa_1}\"; color=\"#712,6de\"];\n" - "v16 -- v25\n" - "v16 -- v46\n" - "v17 [label=\"{\\kappa_2}\"; color=\"#712,6de\"];\n" - "v17 -- v25\n" - "v17 -- v46\n" - "v18 [label=\"{\\kappa_3}\"; color=\"#712,6de\"];\n" - "v18 -- v26\n" - "v18 -- v45\n" - "v19 [label=\"{\\kappa_4}\"; color=\"#712,6de\"];\n" - "v19 -- v26\n" - "v19 -- v45\n" - "v20 [label=\"A\"; color=\"#518,020\"];\n" - "v20 -- v23\n" - "v21 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v21 -- v23\n" - "v22 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v22 -- v23\n" - "v23 [label=\"bka\"; color=\"#518,020\"];\n" - "v24 [label=\"g\"; color=\"#2e0,351\"];\n" - "v24 -- v27\n" - "v25 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v25 -- v27\n" - "v26 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v26 -- v27\n" - "v27 [label=\"bka\"; color=\"#2e0,351\"];\n" - "v28 [label=\"t\"; color=\"#43,e44\"];\n" - "v28 -- v31\n" - "v29 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v29 -- v31\n" - "v30 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v30 -- v31\n" - "v31 [label=\"bka\"; color=\"#43,e44\"];\n" - "v32 [label=\"t\"; color=\"#43,e44\"];\n" - "v32 -- v35\n" - "v33 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v33 -- v35\n" - "v34 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v34 -- v35\n" - "v35 [label=\"bka\"; color=\"#43,e44\"];\n" - "v36 [label=\"t\"; color=\"#43,e44\"];\n" - "v36 -- v39\n" - "v37 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v37 -- v39\n" - "v38 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v38 -- v39\n" - "v39 [label=\"bka\"; color=\"#43,e44\"];\n" - "v40 [label=\"ã\"; color=\"#cbf,be5\"];\n" - "v40 -- v43\n" - "v41 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v41 -- v43\n" - "v42 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v42 -- v43\n" - "v43 [label=\"bka\"; color=\"#cbf,be5\"];\n" - "v44 [label=\"ã\"; color=\"#cbf,be5\"];\n" - "v44 -- v47\n" - "v45 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v45 -- v47\n" - "v46 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v46 -- v47\n" - "v47 [label=\"bka\"; color=\"#cbf,be5\"];\n" - "v48 [label=\"ã\"; color=\"#cbf,be5\"];\n" - "v48 -- v51\n" - "v49 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v49 -- v51\n" - "v50 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v50 -- v51\n" - "v51 [label=\"bka\"; color=\"#cbf,be5\"];\n" - "v52 [label=\"ã\"; color=\"#cbf,be5\"];\n" - "v52 -- v55\n" - "v53 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v53 -- v55\n" - "v54 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v54 -- v55\n" - "v55 [label=\"bka\"; color=\"#cbf,be5\"];\n" - "v56 [label=\"ã\"; color=\"#cbf,be5\"];\n" - "v56 -- v59\n" - "v57 [label=\"bra2a\"; color=\"#eaa,2ab\"];\n" - "v57 -- v59\n" - "v58 [label=\"ket2a\"; color=\"#5a8,fd3\"];\n" - "v58 -- v59\n" - "v59 [label=\"bka\"; color=\"#cbf,be5\"];\n" - "}\n"); + const std::wstring actual = oss.str(); + const std::wstring expected = + L"graph g {\n" + L"v0 [label=\"{a_1}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v0 -- v29\n" + L"v0 -- v58\n" + L"v1 [label=\"{a_2}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v1 -- v29\n" + L"v1 -- v58\n" + L"v2 [label=\"{a_3}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v2 -- v33\n" + L"v2 -- v54\n" + L"v3 [label=\"{a_4}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v3 -- v33\n" + L"v3 -- v54\n" + L"v4 [label=\"{a_5}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v4 -- v37\n" + L"v4 -- v50\n" + L"v5 [label=\"{a_6}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v5 -- v37\n" + L"v5 -- v50\n" + L"v6 [label=\"{a_7}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v6 -- v22\n" + L"v6 -- v41\n" + L"v7 [label=\"{a_8}\"; style=filled; color=\"#2603c0\"; " + L"fillcolor=\"#2603c080\"; penwidth=2];\n" + L"v7 -- v22\n" + L"v7 -- v41\n" + L"v8 [label=\"{i_1}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v8 -- v30\n" + L"v8 -- v57\n" + L"v9 [label=\"{i_2}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v9 -- v30\n" + L"v9 -- v57\n" + L"v10 [label=\"{i_3}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v10 -- v34\n" + L"v10 -- v53\n" + L"v11 [label=\"{i_4}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v11 -- v34\n" + L"v11 -- v53\n" + L"v12 [label=\"{i_5}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v12 -- v38\n" + L"v12 -- v49\n" + L"v13 [label=\"{i_6}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v13 -- v38\n" + L"v13 -- v49\n" + L"v14 [label=\"{i_7}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v14 -- v21\n" + L"v14 -- v42\n" + L"v15 [label=\"{i_8}\"; style=filled; color=\"#103109\"; " + L"fillcolor=\"#10310980\"; penwidth=2];\n" + L"v15 -- v21\n" + L"v15 -- v42\n" + L"v16 [label=\"{\\kappa_1}\"; style=filled; color=\"#0d4103\"; " + L"fillcolor=\"#0d410380\"; penwidth=2];\n" + L"v16 -- v25\n" + L"v16 -- v46\n" + L"v17 [label=\"{\\kappa_2}\"; style=filled; color=\"#0d4103\"; " + L"fillcolor=\"#0d410380\"; penwidth=2];\n" + L"v17 -- v25\n" + L"v17 -- v46\n" + L"v18 [label=\"{\\kappa_3}\"; style=filled; color=\"#0d4103\"; " + L"fillcolor=\"#0d410380\"; penwidth=2];\n" + L"v18 -- v26\n" + L"v18 -- v45\n" + L"v19 [label=\"{\\kappa_4}\"; style=filled; color=\"#0d4103\"; " + L"fillcolor=\"#0d410380\"; penwidth=2];\n" + L"v19 -- v26\n" + L"v19 -- v45\n" + L"v20 [label=\"A\"; style=filled; color=\"#bd2ec1\"; " + L"fillcolor=\"#bd2ec180\"; penwidth=2];\n" + L"v20 -- v23\n" + L"v21 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v21 -- v23\n" + L"v22 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v22 -- v23\n" + L"v23 [label=\"bka\"; style=filled; color=\"#bd2ec1\"; " + L"fillcolor=\"#bd2ec180\"; penwidth=2];\n" + L"v24 [label=\"g\"; style=filled; color=\"#120912\"; " + L"fillcolor=\"#12091280\"; penwidth=2];\n" + L"v24 -- v27\n" + L"v25 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v25 -- v27\n" + L"v26 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v26 -- v27\n" + L"v27 [label=\"bka\"; style=filled; color=\"#120912\"; " + L"fillcolor=\"#12091280\"; penwidth=2];\n" + L"v28 [label=\"t\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v28 -- v31\n" + L"v29 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v29 -- v31\n" + L"v30 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v30 -- v31\n" + L"v31 [label=\"bka\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v32 [label=\"t\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v32 -- v35\n" + L"v33 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v33 -- v35\n" + L"v34 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v34 -- v35\n" + L"v35 [label=\"bka\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v36 [label=\"t\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v36 -- v39\n" + L"v37 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v37 -- v39\n" + L"v38 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v38 -- v39\n" + L"v39 [label=\"bka\"; style=filled; color=\"#4b7e1b\"; " + L"fillcolor=\"#4b7e1b80\"; penwidth=2];\n" + L"v40 [label=\"ã\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v40 -- v43\n" + L"v41 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v41 -- v43\n" + L"v42 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v42 -- v43\n" + L"v43 [label=\"bka\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v44 [label=\"ã\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v44 -- v47\n" + L"v45 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v45 -- v47\n" + L"v46 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v46 -- v47\n" + L"v47 [label=\"bka\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v48 [label=\"ã\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v48 -- v51\n" + L"v49 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v49 -- v51\n" + L"v50 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v50 -- v51\n" + L"v51 [label=\"bka\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v52 [label=\"ã\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v52 -- v55\n" + L"v53 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v53 -- v55\n" + L"v54 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v54 -- v55\n" + L"v55 [label=\"bka\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v56 [label=\"ã\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"v56 -- v59\n" + L"v57 [label=\"bra2a\"; style=filled; color=\"#6ecb7d\"; " + L"fillcolor=\"#6ecb7d80\"; penwidth=2];\n" + L"v57 -- v59\n" + L"v58 [label=\"ket2a\"; style=filled; color=\"#cfd472\"; " + L"fillcolor=\"#cfd47280\"; penwidth=2];\n" + L"v58 -- v59\n" + L"v59 [label=\"bka\"; style=filled; color=\"#e024b7\"; " + L"fillcolor=\"#e024b780\"; penwidth=2];\n" + L"}\n"; + + REQUIRE(actual == expected); // compute automorphism group { @@ -356,21 +427,23 @@ TEST_CASE("TensorNetwork", "[elements]") { &save_aut); std::basic_ostringstream oss; bliss::print_auts(aut_generators, oss, decltype(vlabels){}); - REQUIRE(oss.str() == - L"(14,15)\n" - "(6,7)\n" - "(18,19)\n" - "(16,17)\n" - "(8,9)\n" - "(0,1)\n" - "(10,11)\n" - "(12,13)\n" - "(2,3)\n" - "(4,5)\n" - "(2,4)(3,5)(10,12)(11,13)(32,36)(33,37)(34,38)(35,39)(48,52)(49," - "53)(50,54)(51,55)\n" - "(0,2)(1,3)(8,10)(9,11)(28,32)(29,33)(30,34)(31,35)(52,56)(53,57)" - "(54,58)(55,59)\n"); + const std::wstring actual = oss.str(); + const std::wstring expected = + L"(18,19)\n" + L"(16,17)\n" + L"(14,15)\n" + L"(6,7)\n" + L"(12,13)\n" + L"(4,5)\n" + L"(10,11)\n" + L"(8,9)\n" + L"(2,3)\n" + L"(0,1)\n" + L"(0,2)(1,3)(8,10)(9,11)(28,32)(29,33)(30,34)(31,35)(52,56)(53,57)(" + L"54,58)(55,59)\n" + L"(2,4)(3,5)(10,12)(11,13)(32,36)(33,37)(34,38)(35,39)(48,52)(49,53)(" + L"50,54)(51,55)\n"; + REQUIRE(actual == expected); // change to 1 to user vertex labels rather than indices if (0) { @@ -576,4 +649,715 @@ TEST_CASE("TensorNetwork", "[elements]") { } // SECTION("misc1") -} // TEST_CASE("Tensor") +} // TEST_CASE("TensorNetwork") + +template +std::vector to_tensors(const Container& cont) { + std::vector tensors; + + std::transform(cont.begin(), cont.end(), std::back_inserter(tensors), + [](const auto& tensor) { + auto casted = + std::dynamic_pointer_cast(tensor); + REQUIRE(casted != nullptr); + return casted; + }); + return tensors; +} + +template +sequant::ExprPtr to_product(const Container& container) { + return sequant::ex(to_tensors(container)); +} + +namespace sequant { +class TensorNetworkV2Accessor { + public: + auto get_canonical_bliss_graph( + sequant::TensorNetworkV2 tn, + const sequant::TensorNetwork::named_indices_t* named_indices = nullptr) { + tn.canonicalize_graph(named_indices ? *named_indices : tn.ext_indices_); + tn.init_edges(); + auto graph = tn.create_graph(named_indices); + return std::make_pair(std::move(graph.bliss_graph), graph.vertex_labels); + } +}; +} // namespace sequant + +TEST_CASE("TensorNetworkV2", "[elements]") { + using namespace sequant; + using namespace sequant::mbpt; + using sequant::Context; + namespace t = sequant::mbpt::tensor; + namespace o = sequant::mbpt::op; + + sequant::set_default_context(Context( + mbpt::make_sr_spaces(), Vacuum::SingleProduct, IndexSpaceMetric::Unit, + BraKetSymmetry::conjugate, SPBasis::spinorbital)); + + SECTION("Edges") { + using Vertex = TensorNetworkV2::Vertex; + using Edge = TensorNetworkV2::Edge; + using Origin = TensorNetworkV2::Origin; + + Vertex v1(Origin::Bra, 0, 1, Symmetry::antisymm); + Vertex v2(Origin::Bra, 0, 0, Symmetry::antisymm); + Vertex v3(Origin::Ket, 1, 0, Symmetry::symm); + Vertex v4(Origin::Ket, 1, 3, Symmetry::symm); + Vertex v5(Origin::Bra, 3, 0, Symmetry::nonsymm); + Vertex v6(Origin::Bra, 3, 2, Symmetry::nonsymm); + Vertex v7(Origin::Ket, 3, 1, Symmetry::nonsymm); + Vertex v8(Origin::Ket, 5, 0, Symmetry::symm); + + const Index dummy(L"a_1"); + + Edge e1(v1, dummy); + e1.connect_to(v4); + Edge e2(v2, dummy); + e2.connect_to(v3); + Edge e3(v3, dummy); + e3.connect_to(v5); + Edge e4(v4, dummy); + e4.connect_to(v6); + + Edge e5(v8, dummy); + e5.connect_to(v6); + Edge e6(v8, dummy); + e6.connect_to(v7); + + Edge e7(v4, dummy); + + // Due to tensor symmetries, these edges are considered equal + REQUIRE(e1 == e2); + REQUIRE(!(e1 < e2)); + REQUIRE(!(e2 < e1)); + + // Smallest terminal index wins + REQUIRE(!(e1 == e3)); + REQUIRE(e1 < e3); + REQUIRE(!(e3 < e1)); + + // For non-symmetric tensors the connection slot is taken into account + REQUIRE(!(e3 == e4)); + REQUIRE(e3 < e4); + REQUIRE(!(e4 < e3)); + + REQUIRE(!(e5 == e6)); + REQUIRE(e5 < e6); + REQUIRE(!(e6 < e5)); + + // Unconnected edges always come before fully connected ones + REQUIRE(!(e7 == e1)); + REQUIRE(e7 < e1); + REQUIRE(!(e1 < e7)); + } + + SECTION("constructors") { + { // with Tensors + auto t1 = ex(L"F", bra{L"i_1"}, ket{L"i_2"}); + auto t2 = ex(L"t", bra{L"i_1"}, ket{L"i_2"}); + auto t1_x_t2 = t1 * t2; + REQUIRE_NOTHROW(TensorNetworkV2(*t1_x_t2)); + + auto t1_x_t2_p_t2 = t1 * (t2 + t2); // can only use a flat tensor product + REQUIRE_THROWS_AS(TensorNetworkV2(*t1_x_t2_p_t2), std::logic_error); + } + + { // with NormalOperators + constexpr const auto V = Vacuum::SingleProduct; + auto t1 = ex(cre({L"i_1"}), ann({L"i_2"}), V); + auto t2 = ex(cre({L"i_2"}), ann({L"i_1"}), V); + auto t1_x_t2 = t1 * t2; + REQUIRE_NOTHROW(TensorNetworkV2(*t1_x_t2)); + } + + { // with Tensors and NormalOperators + auto tmp = t::A(nₚ(-2)) * t::H_(2) * t::T_(2) * t::T_(2); + REQUIRE_NOTHROW(TensorNetworkV2(tmp->as().factors())); + } + + } // SECTION("constructors") + + SECTION("accessors") { + { + constexpr const auto V = Vacuum::SingleProduct; + auto t1 = ex(L"F", bra{L"i_1"}, ket{L"i_2"}); + auto t2 = ex(cre({L"i_1"}), ann({L"i_3"}), V); + auto t1_x_t2 = t1 * t2; + REQUIRE_NOTHROW(TensorNetworkV2(*t1_x_t2)); + TensorNetworkV2 tn(*t1_x_t2); + + // edges + auto edges = tn.edges(); + REQUIRE(edges.size() == 3); + + // ext indices + auto ext_indices = tn.ext_indices(); + REQUIRE(ext_indices.size() == 2); + + // tensors + auto tensors = tn.tensors(); + REQUIRE(size(tensors) == 2); + REQUIRE(std::dynamic_pointer_cast(tensors[0])); + REQUIRE(std::dynamic_pointer_cast(tensors[1])); + REQUIRE(*std::dynamic_pointer_cast(tensors[0]) == *t1); + REQUIRE(*std::dynamic_pointer_cast(tensors[1]) == *t2); + } + } // SECTION("accessors") + + SECTION("canonicalizer") { + { // with no external indices, hence no named indices whatsoever + Index::reset_tmp_index(); + constexpr const auto V = Vacuum::SingleProduct; + auto t1 = ex(L"F", bra{L"i_1"}, ket{L"i_2"}); + auto t2 = ex(cre({L"i_1"}), ann({L"i_2"}), V); + auto t1_x_t2 = t1 * t2; + TensorNetworkV2 tn(*t1_x_t2); + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), false); + + REQUIRE(size(tn.tensors()) == 2); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[0])); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[1])); + // std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) << + // std::endl; std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) << + // std::endl; + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) == + L"{F^{{i_2}}_{{i_1}}}"); + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) == + L"{\\tilde{a}^{{i_1}}_{{i_2}}}"); + } + + { + Index::reset_tmp_index(); + constexpr const auto V = Vacuum::SingleProduct; + auto t1 = ex(L"F", bra{L"i_2"}, ket{L"i_17"}); + auto t2 = ex(cre({L"i_2"}), ann({L"i_3"}), V); + auto t1_x_t2 = t1 * t2; + + // with all external named indices + SECTION("implicit") { + TensorNetworkV2 tn(*t1_x_t2); + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), false); + + REQUIRE(size(tn.tensors()) == 2); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[0])); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[1])); + // std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) << + // std::endl; std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) << + // std::endl; + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) == + L"{\\tilde{a}^{{i_1}}_{{i_3}}}"); + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) == + L"{F^{{i_{17}}}_{{i_1}}}"); + } + + // with explicit named indices + SECTION("explicit") { + Index::reset_tmp_index(); + TensorNetworkV2 tn(*t1_x_t2); + + using named_indices_t = TensorNetworkV2::NamedIndexSet; + named_indices_t indices{Index{L"i_17"}}; + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), false, + &indices); + + REQUIRE(size(tn.tensors()) == 2); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[0])); + REQUIRE(std::dynamic_pointer_cast(tn.tensors()[1])); + // std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) + // << std::endl; std::wcout << + // to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) + // << std::endl; + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[1])) == + L"{\\tilde{a}^{{i_2}}_{{i_1}}}"); + REQUIRE(to_latex(std::dynamic_pointer_cast(tn.tensors()[0])) == + L"{F^{{i_{17}}}_{{i_2}}}"); + } + } + + SECTION("particle non-conserving") { + const auto input1 = parse_expr(L"P{;a1,a3}"); + const auto input2 = parse_expr(L"P{a1,a3;}"); + const std::wstring expected1 = L"{{P^{{a_1}{a_3}}_{}}}"; + const std::wstring expected2 = L"{{P^{}_{{a_1}{a_3}}}}"; + + for (int variant : {1, 2}) { + for (bool fast : {true, false}) { + TensorNetworkV2 tn( + std::vector{variant == 1 ? input1 : input2}); + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), fast); + REQUIRE(tn.tensors().size() == 1); + auto result = ex(to_tensors(tn.tensors())); + REQUIRE(to_latex(result) == (variant == 1 ? expected1 : expected2)); + } + } + } + + SECTION("non-symmetric") { + const auto input = + parse_expr(L"A{i7,i3;i9,i12}:A I1{i7,i3;;x5}:N I2{;i9,i12;x5}:N") + .as() + .factors(); + const std::wstring expected = + L"A{i_1,i_2;i_3,i_4}:A * I1{i_1,i_2;;x_1}:N * I2{;i_3,i_4;x_1}:N"; + + for (bool fast : {true, false}) { + TensorNetworkV2 tn(input); + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), fast); + const auto result = ex(to_tensors(tn.tensors())); + REQUIRE_THAT(result, SimplifiesTo(expected)); + } + } + + SECTION("particle-1,2-symmetry") { + const std::vector> pairs = { + {L"S{i_1,i_2,i_3;a_1,a_2,a_3}:N * f{i_4;i_2}:N * " + L"t{a_1,a_2,a_3;i_4,i_3,i_1}:N", + L"S{i_1,i_2,i_3;a_1,a_2,a_3}:N * f{i_4;i_1}:N * " + L"t{a_1,a_2,a_3;i_2,i_3,i_4}:N"}, + {L"Γ{o_2,o_4;o_1,o_3}:N * g{i_1,o_1;o_2,e_1}:N * " + L"t{o_3,e_1;o_4,i_1}:N", + L"Γ{o_2,o_4;o_1,o_3}:N * g{i_1,o_3;o_4,e_1}:N * " + L"t{o_1,e_1;o_2,i_1}:N"}}; + for (const auto& pair : pairs) { + const auto first = parse_expr(pair.first).as().factors(); + const auto second = parse_expr(pair.second).as().factors(); + + TensorNetworkV2Accessor accessor; + auto [first_graph, first_labels] = + accessor.get_canonical_bliss_graph(TensorNetworkV2(first)); + auto [second_graph, second_labels] = + accessor.get_canonical_bliss_graph(TensorNetworkV2(second)); + if (first_graph->cmp(*second_graph) != 0) { + std::wstringstream stream; + stream << "First graph:\n"; + first_graph->write_dot(stream, first_labels, true); + stream << "Second graph:\n"; + second_graph->write_dot(stream, second_labels, true); + stream << "TN graph:\n"; + auto [wick_graph, labels, d1, d2] = + TensorNetwork(first).make_bliss_graph(); + wick_graph->write_dot(stream, labels, true); + + FAIL(to_string(stream.str())); + } + + TensorNetworkV2 tn1(first); + TensorNetworkV2 tn2(second); + + tn1.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), false); + tn2.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), false); + + REQUIRE(tn1.tensors().size() == tn2.tensors().size()); + for (std::size_t i = 0; i < tn1.tensors().size(); ++i) { + auto t1 = std::dynamic_pointer_cast(tn1.tensors()[i]); + auto t2 = std::dynamic_pointer_cast(tn2.tensors()[i]); + REQUIRE(t1); + REQUIRE(t2); + REQUIRE(to_latex(t1) == to_latex(t2)); + } + } + } + + SECTION("miscellaneous") { + const std::vector> inputs = { + {L"g{i_1,a_1;i_2,i_3}:A * I{i_2,i_3;i_1,a_1}:A", + L"g{i_1,a_1;i_2,i_3}:A * I{i_2,i_3;i_1,a_1}:A"}, + {L"g{a_1,i_1;i_2,i_3}:A * I{i_2,i_3;i_1,a_1}:A", + L"-1 g{i_1,a_1;i_2,i_3}:A * I{i_2,i_3;i_1,a_1}:A"}, + + {L"g{i_1,a_1;i_2,i_3}:N * I{i_2,i_3;i_1,a_1}:N", + L"g{i_1,a_1;i_2,i_3}:N * I{i_2,i_3;i_1,a_1}:N"}, + {L"g{a_1,i_1;i_2,i_3}:N * I{i_2,i_3;i_1,a_1}:N", + L"g{i_1,a_1;i_2,i_3}:N * I{i_2,i_3;a_1,i_1}:N"}, + }; + + for (const auto& [input, expected] : inputs) { + const auto input_tensors = parse_expr(input).as().factors(); + + TensorNetworkV2 tn(input_tensors); + ExprPtr factor = tn.canonicalize( + TensorCanonicalizer::cardinal_tensor_labels(), true); + + ExprPtr prod = to_product(tn.tensors()); + if (factor) { + prod = ex( + prod.as().scale(factor.as().value())); + } + + REQUIRE_THAT(prod, SimplifiesTo(expected)); + } + } + +#ifndef SEQUANT_SKIP_LONG_TESTS + SECTION("Exhaustive SRCC example") { + // Note: the exact canonical form written here is implementation-defined + // and doesn't actually matter What does, is that all equivalent ways of + // writing it down, canonicalizes to the same exact form + const Product expectedExpr = + parse_expr( + L"A{i1,i2;a1,a2} g{i3,i4;a3,a4} t{a1,a3;i3,i4} t{a2,a4;i1,i2}", + Symmetry::antisymm) + .as(); + + const auto expected = expectedExpr.factors(); + + TensorNetworkV2Accessor accessor; + const auto [canonical_graph, canonical_graph_labels] = + accessor.get_canonical_bliss_graph(TensorNetworkV2(expected)); + + std::wcout << "Canonical graph:\n"; + canonical_graph->write_dot(std::wcout, canonical_graph_labels); + std::wcout << std::endl; + + std::vector indices; + for (std::size_t i = 0; i < expected.size(); ++i) { + const Tensor& tensor = expected[i].as(); + for (const Index& idx : tensor.indices()) { + if (std::find(indices.begin(), indices.end(), idx) == indices.end()) { + indices.push_back(idx); + } + } + } + std::sort(indices.begin(), indices.end()); + + const auto original_indices = indices; + + // Make sure to clone all expressions in order to not accidentally + // modify the ones in expected (even though they are const... the + // pointer-like semantics of expressions messes with const semantics) + std::remove_const_t factors; + for (const auto& factor : expected) { + factors.push_back(factor.clone()); + } + std::sort(factors.begin(), factors.end()); + + const auto is_occ = [](const Index& idx) { + return idx.space() == Index(L"i_1").space(); + }; + + // Iterate over all tensor permutations and all permutations of possible + // index name swaps + REQUIRE(std::is_sorted(factors.begin(), factors.end())); + REQUIRE(std::is_sorted(indices.begin(), indices.end())); + REQUIRE(std::is_partitioned(indices.begin(), indices.end(), is_occ)); + REQUIRE(std::partition_point(indices.begin(), indices.end(), is_occ) == + indices.begin() + 4); + std::size_t total_variations = 0; + do { + do { + do { + total_variations++; + + // Compute index replacements + container::map idxrepl; + for (std::size_t i = 0; i < indices.size(); ++i) { + REQUIRE(original_indices[i].space() == indices[i].space()); + + idxrepl.insert( + std::make_pair(original_indices.at(i), indices.at(i))); + } + + // Apply index replacements to a copy of the current tensor + // permutation + auto copy = factors; + for (ExprPtr& expr : copy) { + expr.as().transform_indices(idxrepl); + reset_tags(expr.as()); + } + + TensorNetworkV2 tn(copy); + + // At the heart of our canonicalization lies the fact that we can + // always create the uniquely defined canonical graph for a given + // network + const auto [current_graph, current_graph_labels] = + accessor.get_canonical_bliss_graph(tn); + if (current_graph->cmp(*canonical_graph) != 0) { + std::wcout << "Canonical graph for " << deparse(ex(copy)) + << ":\n"; + current_graph->write_dot(std::wcout, current_graph_labels); + std::wcout << std::endl; + } + REQUIRE(current_graph->cmp(*canonical_graph) == 0); + + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), + false); + + std::vector actual; + std::transform(tn.tensors().begin(), tn.tensors().end(), + std::back_inserter(actual), [](const auto& t) { + assert(std::dynamic_pointer_cast(t)); + return std::dynamic_pointer_cast(t); + }); + + // The canonical graph must not change due to the other + // canonicalization steps we perform + REQUIRE(accessor.get_canonical_bliss_graph(TensorNetworkV2(actual)) + .first->cmp(*canonical_graph) == 0); + + REQUIRE(actual.size() == expected.size()); + + if (!std::equal(expected.begin(), expected.end(), actual.begin())) { + std::wostringstream sstream; + sstream + << "Expected all tensors to be equal (actual == expected), " + "but got:\n"; + for (std::size_t i = 0; i < expected.size(); ++i) { + std::wstring equality = + actual[i] == expected[i] ? L" == " : L" != "; + + sstream << deparse(actual[i]) << equality + << deparse(expected[i]) << "\n"; + } + sstream << "\nInput was " << deparse(ex(factors)) + << "\n"; + FAIL(to_string(sstream.str())); + } + } while (std::next_permutation(indices.begin() + 4, indices.end())); + } while (std::next_permutation(indices.begin(), indices.begin() + 4)); + } while (std::next_permutation(factors.begin(), factors.end())); + + // 4! (tensors) * 4! (internal indices) * 4! (external indices) + REQUIRE(total_variations == 24 * 24 * 24); + } +#endif + + SECTION("idempotency") { + const std::vector inputs = { + L"F{i1;i8} g{i8,i9;i1,i7}", + L"A{i7,i3;i9,i12}:A I1{i7,i3;;x5}:N I2{;i9,i12;x5}:N", + L"f{i4;i1}:N t{a1,a2,a3;i2,i3,i4}:N S{i1,i2,i3;a1,a2,a3}:N", + L"P{a1,a3;} k{i8;i2}", + L"L{x6;;x2} P{;a1,a3}", + }; + + for (const std::wstring& current : inputs) { + auto factors1 = parse_expr(current).as().factors(); + auto factors2 = parse_expr(current).as().factors(); + + TensorNetworkV2 reference_tn(factors1); + reference_tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), + false); + + TensorNetworkV2 check_tn(factors2); + check_tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), + false); + + REQUIRE(to_latex(to_product(reference_tn.tensors())) == + to_latex(to_product(check_tn.tensors()))); + + for (bool fast : {true, false, true, true, false, false, true}) { + reference_tn.canonicalize( + TensorCanonicalizer::cardinal_tensor_labels(), fast); + + REQUIRE(to_latex(to_product(reference_tn.tensors())) == + to_latex(to_product(check_tn.tensors()))); + } + } + } // SECTION("idempotency") + + } // SECTION("canonicalizer") + + SECTION("misc1") { + if (false) { + Index::reset_tmp_index(); + // TN1 from manuscript + auto g = ex(L"g", bra{L"i_3", L"i_4"}, ket{L"a_3", L"a_4"}, + Symmetry::antisymm); + auto ta = ex(L"t", bra{L"a_1", L"a_3"}, ket{L"i_1", L"i_2"}, + Symmetry::antisymm); + auto tb = ex(L"t", bra{L"a_2", L"a_4"}, ket{L"i_3", L"i_4"}, + Symmetry::antisymm); + + auto tmp = g * ta * tb; + // std::wcout << "TN1 = " << to_latex(tmp) << std::endl; + TensorNetworkV2 tn(tmp->as().factors()); + + // make graph + // N.B. treat all indices as dummy so that the automorphism ignores the + using named_indices_t = TensorNetworkV2::NamedIndexSet; + named_indices_t indices{}; + REQUIRE_NOTHROW(tn.create_graph(&indices)); + TensorNetworkV2::Graph graph = tn.create_graph(&indices); + + // create dot + { + std::basic_ostringstream oss; + REQUIRE_NOTHROW(graph.bliss_graph->write_dot(oss, graph.vertex_labels)); + // std::wcout << "oss.str() = " << std::endl << oss.str() << + // std::endl; + } + + bliss::Stats stats; + graph.bliss_graph->set_splitting_heuristic(bliss::Graph::shs_fsm); + + std::vector> aut_generators; + auto save_aut = [&aut_generators](const unsigned int n, + const unsigned int* aut) { + aut_generators.emplace_back(aut, aut + n); + }; + graph.bliss_graph->find_automorphisms( + stats, &bliss::aut_hook, &save_aut); + CHECK(aut_generators.size() == + 2); // there are 2 generators, i1<->i2, i3<->i4 + + std::basic_ostringstream oss; + bliss::print_auts(aut_generators, oss, graph.vertex_labels); + CHECK(oss.str() == L"({i_3},{i_4})\n({i_1},{i_2})\n"); + // std::wcout << oss.str() << std::endl; + } + + // profile canonicalization for synthetic tests in + // DOI 10.1016/j.cpc.2018.02.014 + if (false) { + for (auto testcase : {0, 1, 2, 3}) { + // - testcase=0,2 are "equivalent" and correspond to the "frustrated" + // case in Section 5.3 of DOI 10.1016/j.cpc.2018.02.014 + // - testcase=1 corresponds to the "frustrated" case in Section 5.4 of + // DOI 10.1016/j.cpc.2018.02.014 + // - testcase=3 corresponds to the "No symmetry dummy" + // case in Section 5.1 of DOI 10.1016/j.cpc.2018.02.014 + if (testcase == 0) + std::wcout << "canonicalizing network with 1 totally-symmetric " + "tensor with " + "N indices and 1 asymmetric tensor with N indices\n"; + else if (testcase == 3) + std::wcout << "canonicalizing network with 1 asymmetric tensor with " + "N indices and 1 asymmetric tensor with N indices\n"; + else + std::wcout << "canonicalizing network with n equivalent asymmetric " + "tensors with N/n indices each and 1 asymmetric tensor " + "with N indices\n"; + + std::wcout << "N,n,min_time,geommean_time,max_time\n"; + + for (auto N : + {1, 2, 4, 8, 16, 32, 64, 128, 256}) { // total number of indices + + int n; + switch (testcase) { + case 0: + n = 1; + break; + case 1: + n = N / 2; + break; + case 2: + n = N; + break; + case 3: + n = 1; + break; + default: + abort(); + } + if (n == 0 || n > N) continue; + + auto ctx_resetter = set_scoped_default_context( + (static_cast(N) > Index::min_tmp_index()) + ? Context(get_default_context()) + .set_first_dummy_index_ordinal(N + 1) + : get_default_context()); + + // make list of indices + std::vector indices; + for (auto i = 0; i != N; ++i) { + std::wostringstream oss; + oss << "i_" << i; + indices.emplace_back(oss.str()); + } + std::random_device rd; + + // randomly sample connectivity between bra and ket tensors + const auto S = 10; // how many samples to take + + auto product_time = + 1.; // product of all times, need to get geometric mean + auto min_time = std::numeric_limits::max(); // total time for + // all samples + auto max_time = std::numeric_limits::min(); // total time for + // all samples + for (auto s = 0; s != S; ++s) { + // make tensors of independently (and randomly) permuted + // contravariant and covariant indices + auto contravariant_indices = indices; + auto covariant_indices = indices; + + std::shuffle(contravariant_indices.begin(), + contravariant_indices.end(), std::mt19937{rd()}); + std::shuffle(covariant_indices.begin(), covariant_indices.end(), + std::mt19937{rd()}); + + auto utensors = + covariant_indices | ranges::views::chunk(N / n) | + ranges::views::transform([&](const auto& idxs) { + return ex( + L"u", bra(idxs), ket{}, + (testcase == 3 + ? Symmetry::nonsymm + : ((n == 1) ? Symmetry::symm : Symmetry::nonsymm))); + }) | + ranges::to_vector; + CHECK(utensors.size() == static_cast(n)); + auto dtensors = + contravariant_indices | ranges::views::chunk(N) | + ranges::views::transform([&](const auto& idxs) { + return ex(L"d", bra{}, ket(idxs), Symmetry::nonsymm); + }) | + ranges::to_vector; + CHECK(dtensors.size() == 1); + + ExprPtr expr; + for (auto g = 0; g != n; ++g) { + if (g == 0) + expr = utensors[0] * dtensors[0]; + else + expr = expr * utensors[g]; + } + + TensorNetworkV2 tn(expr->as().factors()); + + // produce misc data for publication + if (false && s == 0) { + std::wcout << "N=" << N << " n=" << n << " expr:\n" + << expr->to_latex() << std::endl; + + // make graph + REQUIRE_NOTHROW(tn.create_graph()); + TensorNetworkV2::Graph graph = tn.create_graph(); + + // create dot + std::basic_ostringstream oss; + REQUIRE_NOTHROW( + graph.bliss_graph->write_dot(oss, graph.vertex_labels)); + std::wcout << "bliss graph:" << std::endl + << oss.str() << std::endl; + } + + sequant::TimerPool<> timer; + timer.start(); + tn.canonicalize(TensorCanonicalizer::cardinal_tensor_labels(), + false); + timer.stop(); + const auto elapsed_seconds = timer.read(); + product_time *= elapsed_seconds; + min_time = std::min(min_time, elapsed_seconds); + max_time = std::max(max_time, elapsed_seconds); + } + + const auto geommean_time = std::pow(product_time, 1. / S); + std::wcout << N << "," << n << "," << min_time << "," << geommean_time + << "," << max_time << "\n"; + } + } + } + + } // SECTION("misc1") + +} // TEST_CASE("TensorNetworkV2") diff --git a/tests/unit/test_utilities.cpp b/tests/unit/test_utilities.cpp index 42a77ff4b..54be913e5 100644 --- a/tests/unit/test_utilities.cpp +++ b/tests/unit/test_utilities.cpp @@ -1,34 +1,67 @@ #include #include +#include "catch2_sequant.hpp" + #include #include #include #include #include +#include #include #include #include #include #include +#include #include +#include +#include -namespace Catch { +sequant::Tensor parse_tensor(std::wstring_view str) { + return sequant::parse_expr(str)->as(); +} -// Note: For some reason this template specialization is never used. It works -// for custom types but not for sequant::Index. -template <> -struct StringMaker { - static std::string convert(const sequant::Index& idx) { - using convert_type = std::codecvt_utf8; - std::wstring_convert converter; +TEST_CASE("TEST GET_UNCONCTRACTED_INDICES", "[utilities]") { + using namespace sequant; - return converter.to_bytes(sequant::to_latex(idx)); + SECTION("dot_product") { + std::vector> inputs = { + {L"t{}", L"t{}"}, + {L"t{i1}", L"t{;i1}"}, + {L"t{;i1}", L"t{i1}"}, + {L"t{i1;a1}", L"t{a1;i1}"}, + {L"t{i1;a1;x1}", L"t{a1;i1;x1}"}, + {L"t{;i1;x1}", L"t{i1;;x1}"}, + {L"t{i1;;x1}", L"t{;i1;x1}"}, + {L"t{;;x1}", L"t{;;x1}"}, + }; + + for (auto [left, right] : inputs) { + auto [bra, ket, aux] = + get_uncontracted_indices(parse_tensor(left), parse_tensor(right)); + + REQUIRE(bra.size() == 0); + REQUIRE(ket.size() == 0); + REQUIRE(aux.size() == 0); + } } -}; -} // namespace Catch + SECTION("partial_contraction") { + auto [bra, ket, aux] = get_uncontracted_indices>( + parse_tensor(L"t{i1,i2;a1,a2;x1,x2}"), parse_tensor(L"t{a1;i2;x2}")); + + std::vector expectedBra = {Index(L"i_1")}; + std::vector expectedKet = {Index(L"a_2")}; + std::vector expectedAux = {Index(L"x_1")}; + + REQUIRE_THAT(bra, Catch::Matchers::UnorderedEquals(expectedBra)); + REQUIRE_THAT(ket, Catch::Matchers::UnorderedEquals(expectedKet)); + REQUIRE_THAT(aux, Catch::Matchers::UnorderedEquals(expectedAux)); + } +} TEST_CASE("get_unique_indices", "[utilities]") { using namespace sequant; @@ -53,13 +86,14 @@ TEST_CASE("get_unique_indices", "[utilities]") { REQUIRE(indices == get_unique_indices(expression->as())); } SECTION("Tensor") { - auto expression = parse_expr(L"t{i1;a1,a2}"); + auto expression = parse_expr(L"t{i1;a1,a2;x1}"); auto indices = get_unique_indices(expression); REQUIRE_THAT(indices.bra, UnorderedEquals(std::vector{{L"i_1"}})); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"a_1", L"a_2"}})); + REQUIRE_THAT(indices.aux, UnorderedEquals(std::vector{{L"x_1"}})); REQUIRE(indices == get_unique_indices(expression->as())); expression = parse_expr(L"t{i1,i2;a1,a2}"); @@ -70,6 +104,7 @@ TEST_CASE("get_unique_indices", "[utilities]") { UnorderedEquals(std::vector{{L"i_1"}, {L"i_2"}})); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"a_1", L"a_2"}})); + REQUIRE(indices.aux.size() == 0); REQUIRE(indices == get_unique_indices(expression->as())); expression = parse_expr(L"t{i1,i2;a1,i1}"); @@ -81,16 +116,17 @@ TEST_CASE("get_unique_indices", "[utilities]") { REQUIRE(indices == get_unique_indices(expression->as())); } SECTION("Product") { - auto expression = parse_expr(L"t{i1;a1,a2} p{a2;i2}"); + auto expression = parse_expr(L"t{i1;a1,a2} p{a2;i2;x1}"); auto indices = get_unique_indices(expression); REQUIRE_THAT(indices.bra, UnorderedEquals(std::vector{{L"i_1"}})); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"a_1", L"i_2"}})); + REQUIRE_THAT(indices.aux, UnorderedEquals(std::vector{{L"x_1"}})); REQUIRE(indices == get_unique_indices(expression->as())); - expression = parse_expr(L"1/8 g{a3,a4;i3,i4} t{a1,a4;i1,i4}"); + expression = parse_expr(L"1/8 g{a3,a4;i3,i4;x1} t{a1,a4;i1,i4;x1}"); indices = get_unique_indices(expression); @@ -98,15 +134,17 @@ TEST_CASE("get_unique_indices", "[utilities]") { UnorderedEquals(std::vector{{L"a_3"}, {L"a_1"}})); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"i_3", L"i_1"}})); + REQUIRE(indices.aux.size() == 0); REQUIRE(indices == get_unique_indices(expression->as())); } SECTION("Sum") { - auto expression = parse_expr(L"t{i1;a2} + g{i1;a2}"); + auto expression = parse_expr(L"t{i1;a2;x1} + g{i1;a2;x1}"); auto indices = get_unique_indices(expression); REQUIRE_THAT(indices.bra, UnorderedEquals(std::vector{{L"i_1"}})); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"a_2"}})); + REQUIRE_THAT(indices.aux, UnorderedEquals(std::vector{{L"x_1"}})); REQUIRE(indices == get_unique_indices(expression->as())); expression = parse_expr(L"t{i1;a2} t{i1;a1} + t{i1;a1} g{i1;a2}"); @@ -116,6 +154,7 @@ TEST_CASE("get_unique_indices", "[utilities]") { REQUIRE(indices.bra.empty()); REQUIRE_THAT(indices.ket, UnorderedEquals(std::vector{{L"a_1"}, {L"a_2"}})); + REQUIRE(indices.aux.empty()); REQUIRE(indices == get_unique_indices(expression->as())); } } diff --git a/tests/unit/test_wick.cpp b/tests/unit/test_wick.cpp index 98c54403e..8564bcef6 100644 --- a/tests/unit/test_wick.cpp +++ b/tests/unit/test_wick.cpp @@ -19,6 +19,7 @@ #include #include +#include "catch2_sequant.hpp" #include "test_config.hpp" #include @@ -646,8 +647,7 @@ TEST_CASE("WickTheorem", "[algorithms][wick]") { ExprPtr result; REQUIRE_NOTHROW(result = wick.compute()); // std::wcout << "result = " << to_latex(result) << std::endl; - REQUIRE(to_latex(result) == - L"{{{-1}}{\\bar{g}^{{i_1}{a_2}}_{{a_3}{a_4}}}}"); + REQUIRE(to_latex(result) == L"{{-}{\\bar{g}^{{i_1}{a_2}}_{{a_3}{a_4}}}}"); } // odd number of ops -> full contraction is 0 @@ -912,12 +912,9 @@ TEST_CASE("WickTheorem", "[algorithms][wick]") { std::wcout << L"spinfree H2*T2 = " << to_latex(wick_result_2) << std::endl; - REQUIRE(to_latex(wick_result_2) == - L"{ " - L"\\bigl({{{8}}{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_1}{i_2}}_{{" - L"a_1}{a_2}}}} - " - L"{{{4}}{g^{{a_1}{a_2}}_{{i_1}{i_2}}}{t^{{i_2}{i_1}}_{{a_1}{a_" - L"2}}}}\\bigr) }"); + REQUIRE_THAT(wick_result_2, + EquivalentTo("-4 * g{i1,i2;a1,a2} t{a1,a2;i2,i1} + 8 " + "g{i1,i2;a1,a2} t{a1,a2;i1,i2}")); } });