From 09fb9f9e5d48a031acb02c71609e3f261dac54b5 Mon Sep 17 00:00:00 2001 From: Robert Adam Date: Wed, 3 Jul 2024 12:09:13 +0200 Subject: [PATCH] Introduce ResultExpr This is a wrapper type that allows to encode the LHS (the result) as well as the LHS (the expression) of an equation. Explicitly tracking the LHS is required in order to access information about which indices are paired together during spin integration (the approach abusing the symmetrizer operators breaks down for results that have mixed index spaces in the bra and/or ket). In the future, this class may solve as a nucleation point for a framework to store arbitrary metadata about a given expression because it is not nestable (contrary to Expr objects). --- CMakeLists.txt | 2 + SeQuant/core/parse.hpp | 5 ++ SeQuant/core/parse/ast.hpp | 13 ++++ SeQuant/core/parse/ast_conversions.hpp | 101 +++++++++++++++---------- SeQuant/core/parse/deparse.cpp | 24 ++++++ SeQuant/core/parse/parse.cpp | 65 ++++++++++++---- SeQuant/core/result_expr.cpp | 55 ++++++++++++++ SeQuant/core/result_expr.hpp | 86 +++++++++++++++++++++ SeQuant/domain/mbpt/spin.cpp | 22 ++++++ SeQuant/domain/mbpt/spin.hpp | 5 ++ tests/unit/test_parse.cpp | 58 ++++++++++++++ tests/unit/test_spin.cpp | 32 ++++++++ 12 files changed, 413 insertions(+), 55 deletions(-) create mode 100644 SeQuant/core/result_expr.cpp create mode 100644 SeQuant/core/result_expr.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ed46b8fc..042c6f26e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -230,6 +230,8 @@ set(SeQuant_src SeQuant/core/parse.hpp SeQuant/core/ranges.hpp SeQuant/core/rational.hpp + SeQuant/core/result_expr.cpp + SeQuant/core/result_expr.hpp SeQuant/core/runtime.cpp SeQuant/core/runtime.hpp SeQuant/core/space.hpp diff --git a/SeQuant/core/parse.hpp b/SeQuant/core/parse.hpp index df9dc9f79..309dc47b9 100644 --- a/SeQuant/core/parse.hpp +++ b/SeQuant/core/parse.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -56,6 +57,9 @@ struct ParseError : std::runtime_error { ExprPtr parse_expr(std::wstring_view raw, Symmetry tensor_sym = Symmetry::nonsymm); +ResultExpr parse_result_expr(std::wstring_view raw, + Symmetry tensor_sym = Symmetry::nonsymm); + /// /// Get a parsable string from an expression. /// @@ -71,6 +75,7 @@ ExprPtr parse_expr(std::wstring_view raw, /// \param annot_sym Whether to add sequant::Symmetry annotation /// to each Tensor string. /// \return wstring of the expression. +std::wstring deparse(const ResultExpr &expr, bool annot_sym = true); std::wstring deparse(const ExprPtr &expr, bool annot_sym = true); std::wstring deparse(const Expr &expr, bool annot_sym = true); std::wstring deparse(const Product &product, bool annot_sym); diff --git a/SeQuant/core/parse/ast.hpp b/SeQuant/core/parse/ast.hpp index ad148b8d6..9a829618f 100644 --- a/SeQuant/core/parse/ast.hpp +++ b/SeQuant/core/parse/ast.hpp @@ -13,6 +13,7 @@ #include #include +#include #include namespace sequant::parse::ast { @@ -115,6 +116,17 @@ Product::Product(std::vector factors) Sum::Sum(std::vector summands) : summands(std::move(summands)) {} +struct ResultExpr : boost::spirit::x3::position_tagged { + std::variant lhs; + Sum rhs; + + ResultExpr(Variable variable = {}, Sum expr = {}) + : lhs(std::move(variable)), rhs(std::move(expr)) {} + + ResultExpr(Tensor tensor, Sum expr) + : lhs(std::move(tensor)), rhs(std::move(expr)) {} +}; + } // namespace sequant::parse::ast BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::IndexLabel, label, id); @@ -127,5 +139,6 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Tensor, name, indices, symmetry); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Product, factors); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Sum, summands); +BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::ResultExpr, lhs, rhs); #endif // SEQUANT_CORE_PARSE_AST_HPP diff --git a/SeQuant/core/parse/ast_conversions.hpp b/SeQuant/core/parse/ast_conversions.hpp index e108b6cd1..8c1c17209 100644 --- a/SeQuant/core/parse/ast_conversions.hpp +++ b/SeQuant/core/parse/ast_conversions.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -156,54 +157,53 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum, Symmetry default_symmetry); template -ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, - const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { - struct Transformer { - std::reference_wrapper position_cache; - std::reference_wrapper begin; - std::reference_wrapper default_symmetry; - - ExprPtr operator()(const parse::ast::Product &product) const { - return ast_to_expr(product, position_cache.get(), - begin.get(), default_symmetry); - } +struct Transformer { + std::reference_wrapper position_cache; + std::reference_wrapper begin; + std::reference_wrapper default_symmetry; + + ExprPtr operator()(const parse::ast::Product &product) const { + return ast_to_expr(product, position_cache.get(), + begin.get(), default_symmetry); + } - ExprPtr operator()(const parse::ast::Sum &sum) const { - return ast_to_expr(sum, position_cache.get(), begin.get(), - default_symmetry); - } + ExprPtr operator()(const parse::ast::Sum &sum) const { + return ast_to_expr(sum, position_cache.get(), begin.get(), + default_symmetry); + } - ExprPtr operator()(const parse::ast::Tensor &tensor) const { - auto [braIndices, ketIndices] = - make_indices(tensor.indices, position_cache.get(), begin.get()); + ExprPtr operator()(const parse::ast::Tensor &tensor) const { + auto [braIndices, ketIndices] = + make_indices(tensor.indices, position_cache.get(), begin.get()); - auto [offset, length] = - get_pos(tensor, position_cache.get(), begin.get()); + auto [offset, length] = get_pos(tensor, position_cache.get(), begin.get()); - return ex(tensor.name, std::move(braIndices), - std::move(ketIndices), - to_symmetry(tensor.symmetry, offset + length - 1, - begin.get(), default_symmetry)); - } + return ex(tensor.name, std::move(braIndices), std::move(ketIndices), + to_symmetry(tensor.symmetry, offset + length - 1, + begin.get(), default_symmetry)); + } - ExprPtr operator()(const parse::ast::Variable &variable) const { - if (variable.conjugated) { - return ex(variable.name + L"^*"); - } else { - return ex(variable.name); - } + ExprPtr operator()(const parse::ast::Variable &variable) const { + if (variable.conjugated) { + return ex(variable.name + L"^*"); + } else { + return ex(variable.name); } + } - ExprPtr operator()(const parse::ast::Number &number) const { - return ex( - to_constant(number, position_cache.get(), begin.get())); - } - }; + ExprPtr operator()(const parse::ast::Number &number) const { + return ex(to_constant(number, position_cache.get(), begin.get())); + } +}; +template +ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, + const PositionCache &position_cache, const Iterator &begin, + Symmetry default_symmetry) { return boost::apply_visitor( - Transformer{std::ref(position_cache), std::ref(begin), - std::ref(default_symmetry)}, + Transformer{std::ref(position_cache), + std::ref(begin), + std::ref(default_symmetry)}, value); } @@ -279,6 +279,29 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum, return ex(std::move(summands)); } +template +ResultExpr ast_to_result(const parse::ast::ResultExpr &result, + const PositionCache &position_cache, + const Iterator &begin, Symmetry default_symmetry) { + ExprPtr lhs = std::visit( + Transformer{std::ref(position_cache), + std::ref(begin), + std::ref(default_symmetry)}, + result.lhs); + ExprPtr rhs = + ast_to_expr(result.rhs, position_cache, begin, default_symmetry); + + if (lhs.is()) { + return {std::move(lhs.as()), std::move(rhs)}; + } else if (lhs.is()) { + return {std::move(lhs.as()), std::move(rhs)}; + } else { + auto [offset, length] = get_pos(result.lhs, position_cache, begin); + throw ParseError(offset, length, + "LHS of a ResultExpr must be a Tensor or a Variable"); + } +} + } // namespace sequant::parse::transform #endif // SEQUANT_CORE_PARSE_AST_CONVERSIONS_HPP diff --git a/SeQuant/core/parse/deparse.cpp b/SeQuant/core/parse/deparse.cpp index 46afebb29..f0fd183d6 100644 --- a/SeQuant/core/parse/deparse.cpp +++ b/SeQuant/core/parse/deparse.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -111,6 +112,29 @@ std::wstring deparse(const ExprPtr& expr, bool annot_sym) { throw std::runtime_error("Unsupported expr type for deparse!"); } +std::wstring deparse(const ResultExpr& result, bool annot_sym) { + std::wstring deparsed; + if (result.has_label()) { + deparsed += result.label(); + } else { + deparsed += L"?"; + } + + if (!result.bra().empty() || !result.ket().empty()) { + deparsed += L"{"; + deparsed += details::deparse_indices(result.bra()); + deparsed += L";"; + deparsed += details::deparse_indices(result.ket()); + deparsed += L"}"; + + if (annot_sym) { + deparsed += L":" + details::deparse_sym(result.symmetry()); + } + } + + return deparsed + L" = " + deparse(result.expression(), annot_sym); +} + std::wstring deparse(const Index& index) { std::wstring deparsed(index.label()); diff --git a/SeQuant/core/parse/parse.cpp b/SeQuant/core/parse/parse.cpp index 5aa528fdb..a806fa711 100644 --- a/SeQuant/core/parse/parse.cpp +++ b/SeQuant/core/parse/parse.cpp @@ -2,15 +2,15 @@ // Created by Robert Adam on 2023-09-20 // -#include -#include -#include -#include - #include #include #include #include +#include +#include +#include +#include +#include #include #include @@ -22,11 +22,18 @@ #include #include +#include #include #include #include +#include #include +#include +#include +#include #include +#include +#include namespace sequant { @@ -44,6 +51,7 @@ struct TensorRule; struct ProductRule; struct SumRule; struct ExprRule; +struct ResultExprRule; struct IndexLabelRule; struct IndexRule; struct IndexGroupRule; @@ -57,6 +65,7 @@ x3::rule tensor{"Tensor"}; x3::rule product{"Product"}; x3::rule sum{"Sum"}; x3::rule expr{"Expression"}; +x3::rule resultExpr{"ResultExpr"}; // Auxiliaries x3::rule name{"Name"}; @@ -121,10 +130,12 @@ auto addend = (('+' >> x3::attr(1) | '-' >> x3::attr(-1)) > product)[a auto sum_def = first_addend >> *addend; auto expr_def = -sum > x3::eoi; + +auto resultExpr_def = (tensor | variable) > (L'=' | x3::lit(L"->")) >> expr; // clang-format on BOOST_SPIRIT_DEFINE(name, number, variable, index_label, index, index_groups, - tensor, product, sum, expr); + tensor, product, sum, expr, resultExpr); struct position_cache_tag; struct error_handler_tag; @@ -159,6 +170,7 @@ struct TensorRule : helpers::annotate_position, helpers::error_handler {}; struct ProductRule : helpers::annotate_position, helpers::error_handler {}; struct SumRule : helpers::annotate_position, helpers::error_handler {}; struct ExprRule : helpers::annotate_position, helpers::error_handler {}; +struct ResultRule : helpers::annotate_position, helpers::error_handler {}; struct IndexLabelRule : helpers::annotate_position, helpers::error_handler {}; struct IndexRule : helpers::annotate_position, helpers::error_handler {}; struct IndexGroupRule : helpers::annotate_position, helpers::error_handler {}; @@ -179,23 +191,23 @@ struct ErrorHandler { } }; -ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { +template +AST do_parse(const StartRule &start, std::wstring_view input, + PositionCache &positions) { using iterator_type = decltype(input)::iterator; - x3::position_cache> positions(input.begin(), - input.end()); ErrorHandler error_handler(input.begin()); - parse::ast::Sum ast; + AST ast; const auto parser = x3::with( std::ref(error_handler))[x3::with( - std::ref(positions))[parse::expr]]; + std::ref(positions))[start]]; - auto start = input.begin(); + auto begin = input.begin(); try { bool success = - x3::phrase_parse(start, input.end(), parser, x3::unicode::space, ast); + x3::phrase_parse(begin, input.end(), parser, x3::unicode::space, ast); if (!success) { // Normally, this shouldn't happen as any error should itself throw a @@ -203,10 +215,10 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { throw ParseError(0, input.size(), "Parsing was unsuccessful for an unknown reason"); } - if (start != input.end()) { + if (begin != input.end()) { // This should also not happen as the parser requires matching EOI - throw ParseError(std::distance(input.begin(), start), - std::distance(start, input.end()), + throw ParseError(std::distance(input.begin(), begin), + std::distance(begin, input.end()), "Couldn't parse the entire input"); } } catch (const boost::spirit::x3::expectation_failure &e) { @@ -216,6 +228,27 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { throw; } + return ast; +} + +ResultExpr parse_result_expr(std::wstring_view input, + Symmetry default_symmetry) { + using iterator_type = decltype(input)::iterator; + x3::position_cache> positions(input.begin(), + input.end()); + auto ast = + do_parse(parse::resultExpr, input, positions); + + return parse::transform::ast_to_result(ast, positions, input.begin(), + default_symmetry); +} + +ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { + using iterator_type = decltype(input)::iterator; + x3::position_cache> positions(input.begin(), + input.end()); + auto ast = do_parse(parse::expr, input, positions); + return parse::transform::ast_to_expr(ast, positions, input.begin(), default_symmetry); } diff --git a/SeQuant/core/result_expr.cpp b/SeQuant/core/result_expr.cpp new file mode 100644 index 000000000..d775f9173 --- /dev/null +++ b/SeQuant/core/result_expr.cpp @@ -0,0 +1,55 @@ +#include +#include +#include + +namespace sequant { + +ResultExpr::ResultExpr(const Tensor &tensor, ExprPtr expression) + : m_expr(std::move(expression)), + m_symm(tensor.symmetry()), + m_bksymm(tensor.braket_symmetry()), + m_psymm(tensor.particle_symmetry()), + m_braIndices(tensor.bra().begin(), tensor.bra().end()), + m_ketIndices(tensor.ket().begin(), tensor.ket().end()), + m_label(tensor.label()) {} + +ResultExpr::ResultExpr(const Variable &variable, ExprPtr expression) + : m_expr(std::move(expression)), m_label(variable.label()) {} + +ResultExpr &ResultExpr::operator=(ExprPtr expression) { + m_expr = std::move(expression); + + return *this; +} + +bool ResultExpr::has_label() const { return m_label.has_value(); } + +const std::wstring &ResultExpr::label() const { return m_label.value(); } + +Symmetry ResultExpr::symmetry() const { return m_symm; } + +void ResultExpr::set_symmetry(Symmetry symm) { m_symm = symm; } + +BraKetSymmetry ResultExpr::braket_symmetry() const { return m_bksymm; } + +void ResultExpr::set_braket_symmetry(BraKetSymmetry symm) { m_bksymm = symm; } + +ParticleSymmetry ResultExpr::particle_symmetry() const { return m_psymm; } + +void ResultExpr::set_particle_symmetry(ParticleSymmetry symm) { + m_psymm = symm; +} + +const ResultExpr::IndexContainer &ResultExpr::bra() const { + return m_braIndices; +} + +const ResultExpr::IndexContainer &ResultExpr::ket() const { + return m_ketIndices; +} + +const ExprPtr &ResultExpr::expression() const { return m_expr; } + +ExprPtr &ResultExpr::expression() { return m_expr; } + +} // namespace sequant diff --git a/SeQuant/core/result_expr.hpp b/SeQuant/core/result_expr.hpp new file mode 100644 index 000000000..c09202e34 --- /dev/null +++ b/SeQuant/core/result_expr.hpp @@ -0,0 +1,86 @@ +#ifndef SEQUANT_RESULT_EXPR_HPP +#define SEQUANT_RESULT_EXPR_HPP + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace sequant { + +class Tensor; +class Variable; + +class ResultExpr { + public: + using IndexContainer = container::svector; + + ResultExpr(const Tensor &tensor, ExprPtr expression); + ResultExpr(const Variable &variable, ExprPtr expression); + + ResultExpr(const ResultExpr &other) = default; + ResultExpr(ResultExpr &&other) = default; + + ResultExpr &operator=(const ResultExpr &other) = default; + ResultExpr &operator=(ResultExpr &&other) = default; + + /// Assigns a new expression to this result + ResultExpr &operator=(ExprPtr expression); + + bool has_label() const; + const std::wstring &label() const; + + Symmetry symmetry() const; + void set_symmetry(Symmetry symm); + + BraKetSymmetry braket_symmetry() const; + void set_braket_symmetry(BraKetSymmetry symm); + + ParticleSymmetry particle_symmetry() const; + void set_particle_symmetry(ParticleSymmetry symm); + + const IndexContainer &bra() const; + const IndexContainer &ket() const; + + const ExprPtr &expression() const; + ExprPtr &expression(); + + template + container::svector index_particle_grouping() const { + container::svector groups; + + assert(m_braIndices.size() == m_ketIndices.size() && + "Not yet generalized to particle non-conserving results"); + + groups.reserve(m_braIndices.size()); + + // Note that the assumption is that indices are sorted + // based on the particle they belong to and that bra and + // ket indices are assigned to the same set of particles. + for (std::size_t i = 0; i < m_braIndices.size(); ++i) { + groups.emplace_back( + std::initializer_list{m_braIndices.at(i), m_ketIndices.at(i)}); + } + + return groups; + } + + private: + ExprPtr m_expr; + + Symmetry m_symm = Symmetry::nonsymm; + BraKetSymmetry m_bksymm = BraKetSymmetry::nonsymm; + ParticleSymmetry m_psymm = ParticleSymmetry::nonsymm; + IndexContainer m_braIndices; + IndexContainer m_ketIndices; + std::optional m_label; +}; + +} // namespace sequant + +#endif // SEQUANT_RESULT_EXPR_HPP diff --git a/SeQuant/domain/mbpt/spin.cpp b/SeQuant/domain/mbpt/spin.cpp index 1369d652a..b21aa5111 100644 --- a/SeQuant/domain/mbpt/spin.cpp +++ b/SeQuant/domain/mbpt/spin.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -1007,6 +1008,16 @@ ExprPtr closed_shell_CC_spintrace(ExprPtr const& expr) { return st_expr; } +ResultExpr closed_shell_spintrace(ResultExpr expr) { + expr.expression() = closed_shell_spintrace( + expr.expression(), + expr.index_particle_grouping>()); + + expr.set_symmetry(Symmetry::nonsymm); + + return expr; +} + ExprPtr closed_shell_CC_spintrace_rigorous(ExprPtr const& expr) { assert(expr->is()); using ranges::views::transform; @@ -1665,6 +1676,17 @@ ExprPtr spintrace( return result; } // ExprPtr spintrace +ResultExpr spintrace(ResultExpr expr, bool spinfree_index_spaces) { + expr.expression() = + spintrace(expr.expression(), + expr.index_particle_grouping>(), + spinfree_index_spaces); + + expr.set_symmetry(Symmetry::nonsymm); + + return expr; +} + ExprPtr factorize_S(const ExprPtr& expression, std::initializer_list ext_index_groups, const bool fast_method) { diff --git a/SeQuant/domain/mbpt/spin.hpp b/SeQuant/domain/mbpt/spin.hpp index 334ac0aee..b9c015819 100644 --- a/SeQuant/domain/mbpt/spin.hpp +++ b/SeQuant/domain/mbpt/spin.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -305,6 +306,8 @@ ExprPtr closed_shell_spintrace( const ExprPtr& expr, const container::svector>& ext_index_groups = {}); +ResultExpr closed_shell_spintrace(ResultExpr expr); + /// /// \brief Given a OpType::A or OpType::S tensor, generates a list of external /// indices. @@ -388,6 +391,8 @@ ExprPtr spintrace( container::svector> ext_index_groups = {}, bool spinfree_index_spaces = true); +ResultExpr spintrace(ResultExpr expr, bool spinfree_index_spaces = true); + /// @brief Factorize S out of terms /// @details Given an expression, permute indices and check if a given product /// @param expression Expression pointer diff --git a/tests/unit/test_parse.cpp b/tests/unit/test_parse.cpp index 8d34bdd6b..4ab9b9f57 100644 --- a/tests/unit/test_parse.cpp +++ b/tests/unit/test_parse.cpp @@ -320,6 +320,50 @@ TEST_CASE("parse_expr", "[parse]") { } } +TEST_CASE("parse_result", "[parse]") { + using namespace sequant; + + SECTION("constant") { + ResultExpr result = parse_result_expr(L"A = 3"); + + REQUIRE(result.has_label()); + REQUIRE(result.label() == L"A"); + REQUIRE(result.bra().empty()); + REQUIRE(result.ket().empty()); + REQUIRE(result.symmetry() == Symmetry::nonsymm); + REQUIRE(result.braket_symmetry() == BraKetSymmetry::nonsymm); + REQUIRE(result.particle_symmetry() == ParticleSymmetry::nonsymm); + + REQUIRE(result.expression().is()); + REQUIRE(result.expression().as().value() == 3); + } + SECTION("contraction") { + ResultExpr result = + parse_result_expr(L"R{i1,i2;e1,e2}:A = f{e2;e3} t{e1,e3;i1,i2}"); + + REQUIRE(result.has_label()); + REQUIRE(result.label() == L"R"); + REQUIRE(result.bra().size() == 2); + REQUIRE(result.bra()[0].full_label() == L"i_1"); + REQUIRE(result.bra()[1].full_label() == L"i_2"); + REQUIRE(result.ket().size() == 2); + REQUIRE(result.ket()[0].full_label() == L"e_1"); + REQUIRE(result.ket()[1].full_label() == L"e_2"); + REQUIRE(result.symmetry() == Symmetry::antisymm); + REQUIRE(result.braket_symmetry() == + get_default_context().braket_symmetry()); + REQUIRE(result.particle_symmetry() == ParticleSymmetry::symm); + + REQUIRE(result.expression().is()); + const Product& prod = result.expression().as(); + REQUIRE(prod.size() == 2); + REQUIRE(prod.factor(0).is()); + REQUIRE(prod.factor(0).as().label() == L"f"); + REQUIRE(prod.factor(1).is()); + REQUIRE(prod.factor(1).as().label() == L"t"); + } +} + TEST_CASE("deparse", "[parse]") { using namespace sequant; @@ -337,4 +381,18 @@ TEST_CASE("deparse", "[parse]") { REQUIRE(deparse(expression, true) == current); } + SECTION("result_expressions") { + std::vector expressions = { + L"A = 5", + L"A = g{i_1,i_2;e_1,e_2}:S * t{e_1,e_2;i_1,i_2}:N", + L"R{i_1,i_2;e_1,e_2}:A = f{e_2;e_3}:A * t{e_1,e_3;i_1,i_2}:A + " + L"g{i_1,i_2;e_1,e_2}:A", + }; + + for (const std::wstring& current : expressions) { + ResultExpr result = parse_result_expr(current); + + REQUIRE(deparse(result, true) == current); + } + } } diff --git a/tests/unit/test_spin.cpp b/tests/unit/test_spin.cpp index 101469265..0070a8582 100644 --- a/tests/unit/test_spin.cpp +++ b/tests/unit/test_spin.cpp @@ -1538,4 +1538,36 @@ SECTION("Open-shell spin-tracing") { REQUIRE(result2[1]->size() == 24); } } + +SECTION("ResultExpr") { + SECTION("closed_shell") { + ResultExpr result = parse_result_expr( + L"R{i1,i2,i3;a1,a2,a3}:A = 1/12 * A{i1,i2,i3;a1,a2,a3}:A f{i4;i1} " + L"t{a1,a2,a3;i2,i3,i4}:A"); + + const ExprPtr expected = closed_shell_spintrace( + result.expression().clone(), + {{L"i_1", L"a_1"}, {L"i_2", L"a_2"}, {L"i_3", L"a_3"}}); + + ResultExpr traced = closed_shell_spintrace(result); + + REQUIRE(traced.expression() == expected); + REQUIRE(traced.symmetry() == Symmetry::nonsymm); + REQUIRE(traced.particle_symmetry() == ParticleSymmetry::symm); + } + SECTION("rigorous") { + ResultExpr result = parse_result_expr( + L"R{i1,i2;e1,e2}:A = 1/4 A{i1,i2;e1,e2}:A g{i3,i4;e3,e4}:A " + L"t{e3,e4;i2,i3}:A t{e1,e2;i1,i4}:A"); + + const ExprPtr expected = spintrace(result.expression().clone(), + {{L"i_1", L"e_1"}, {L"i_2", L"e_2"}}); + + ResultExpr traced = spintrace(result); + + REQUIRE(traced.expression() == expected); + REQUIRE(traced.symmetry() == Symmetry::nonsymm); + REQUIRE(traced.particle_symmetry() == ParticleSymmetry::symm); + } +} }