From 91f5a5c024995d1272883b1db78fe14fd55bee4e Mon Sep 17 00:00:00 2001 From: Robert Adam Date: Thu, 16 Jan 2025 18:50:03 +0100 Subject: [PATCH] Parse: support full symmetry specification --- SeQuant/core/parse.hpp | 24 +++++- SeQuant/core/parse/ast.hpp | 17 +++- SeQuant/core/parse/ast_conversions.hpp | 115 ++++++++++++++++++++----- SeQuant/core/parse/parse.cpp | 20 +++-- tests/unit/test_parse.cpp | 20 +++-- 5 files changed, 160 insertions(+), 36 deletions(-) diff --git a/SeQuant/core/parse.hpp b/SeQuant/core/parse.hpp index df9dc9f79..8804abfb7 100644 --- a/SeQuant/core/parse.hpp +++ b/SeQuant/core/parse.hpp @@ -44,7 +44,25 @@ struct ParseError : std::runtime_error { /// '1.0/2.0 * t{i1;a1} * f{i1; a1}' same as above /// 't{i1,i2; a1, a2}' a tensor having indices with proto indices. /// a1 is an index with i1 and i2 as proto-indices. -/// \param tensor_sym The symmetry of all atomic tensors in the +/// Every tensor may optionally be annoted with index symmetry specifications. The general syntax is +/// [: [- [-]]] +/// (no whitespace is allowed at this place). Examples are +/// 't{i1;i2}:A', 't{i1;i2}:A-S', 't{i1;i2}:N-C-S' +/// Possible values for are +/// - 'A' for antisymmetry (sequant::Symmetry::antisymm) +/// - 'S' for symmetric (sequant::Symmetry::symm) +/// - 'N' for non-symmetric (sequant::Symmetry::nonsymm) +/// Possible values for are +/// - 'C' for antisymmetry (sequant::BraKetSymmetry::conjugate) +/// - 'S' for symmetric (sequant::BraKetSymmetry::symm) +/// - 'N' for non-symmetric (sequant::BraKetSymmetry::nonsymm) +/// Possible values for are +/// - 'S' for symmetric (sequant::ParticleSymmetry::symm) +/// - 'N' for non-symmetric (sequant::ParticleSymmetry::nonsymm) +/// \param perm_symm Default index permutation symmetry to be used if tensors don't specify a permutation +/// symmetry explicitly. +/// \param braket_symm Default BraKet symmetry to be used if tensors don't specify a BraKet symmetry explicitly. +/// \param particle_symm Default particle symmetry to be used if tensors don't specify a particle symmetry explicitly. /// @c raw expression. Explicit tensor symmetry can /// be annotated in the expression itself. In that case, the /// annotated symmetry will be used. @@ -54,7 +72,9 @@ struct ParseError : std::runtime_error { /// \return SeQuant expression. // clang-format on ExprPtr parse_expr(std::wstring_view raw, - Symmetry tensor_sym = Symmetry::nonsymm); + Symmetry perm_symm = Symmetry::nonsymm, + BraKetSymmetry braket_symm = BraKetSymmetry::nonsymm, + ParticleSymmetry particle_symm = ParticleSymmetry::symm); /// /// Get a parsable string from an expression. diff --git a/SeQuant/core/parse/ast.hpp b/SeQuant/core/parse/ast.hpp index 2bf749278..521289652 100644 --- a/SeQuant/core/parse/ast.hpp +++ b/SeQuant/core/parse/ast.hpp @@ -7,6 +7,7 @@ #define BOOST_SPIRIT_X3_UNICODE #include +#include #include #include #include @@ -66,17 +67,23 @@ struct IndexGroups : boost::spirit::x3::position_tagged { reverse_bra_ket(reverse_bra_ket) {} }; +struct SymmetrySpec : boost::spirit::x3::position_tagged { + static constexpr char unspecified = '\0'; + char perm_symm = unspecified; + char braket_symm = unspecified; + char particle_symm = unspecified; +}; + struct Tensor : boost::spirit::x3::position_tagged { - static constexpr char unspecified_symmetry = '\0'; std::wstring name; IndexGroups indices; - char symmetry; + boost::optional symmetry; Tensor(std::wstring name = {}, IndexGroups indices = {}, - char symmetry = unspecified_symmetry) + boost::optional symmetry = {}) : name(std::move(name)), indices(std::move(indices)), - symmetry(symmetry) {} + symmetry(std::move(symmetry)) {} }; struct Product; @@ -125,6 +132,8 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Number, numerator, denominator); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Variable, name, conjugated); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::IndexGroups, bra, ket, auxiliaries, reverse_bra_ket); +BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::SymmetrySpec, perm_symm, + braket_symm, particle_symm); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Tensor, name, indices, symmetry); BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Product, factors); diff --git a/SeQuant/core/parse/ast_conversions.hpp b/SeQuant/core/parse/ast_conversions.hpp index 6e1025d7e..f47580d7b 100644 --- a/SeQuant/core/parse/ast_conversions.hpp +++ b/SeQuant/core/parse/ast_conversions.hpp @@ -22,6 +22,9 @@ namespace sequant::parse::transform { +using DefaultSymmetries = + std::tuple; + template std::tuple get_pos(const AST &ast, const PositionCache &cache, @@ -114,9 +117,9 @@ make_indices(const parse::ast::IndexGroups &groups, } template -Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin, - Symmetry default_symmetry) { - if (c == parse::ast::Tensor::unspecified_symmetry) { +Symmetry to_perm_symmetry(char c, std::size_t offset, const Iterator &begin, + Symmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { return default_symmetry; } @@ -136,6 +139,52 @@ Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin, std::string("Invalid symmetry specifier '") + c + "'"); } +template +BraKetSymmetry to_braket_symmetry(char c, std::size_t offset, + const Iterator &begin, + BraKetSymmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { + return default_symmetry; + } + + switch (c) { + case 'C': + case 'c': + return BraKetSymmetry::conjugate; + case 'S': + case 's': + return BraKetSymmetry::symm; + case 'N': + case 'n': + return BraKetSymmetry::nonsymm; + } + + throw ParseError( + offset, 1, std::string("Invalid BraKet symmetry specifier '") + c + "'"); +} + +template +ParticleSymmetry to_particle_symmetry(char c, std::size_t offset, + const Iterator &begin, + ParticleSymmetry default_symmetry) { + if (c == parse::ast::SymmetrySpec::unspecified) { + return default_symmetry; + } + + switch (c) { + case 'S': + case 's': + return ParticleSymmetry::symm; + case 'N': + case 'n': + return ParticleSymmetry::nonsymm; + } + + throw ParseError( + offset, 1, + std::string("Invalid particle symmetry specifier '") + c + "'"); +} + template Constant to_constant(const parse::ast::Number &number, const PositionCache &position_cache, @@ -152,45 +201,71 @@ Constant to_constant(const parse::ast::Number &number, } } +template +std::tuple to_symmetries( + const boost::optional &symm_spec, + const DefaultSymmetries &default_symms, const PositionCache &cache, + const Iterator &begin) { + if (!symm_spec.has_value()) { + return {std::get<0>(default_symms), std::get<1>(default_symms), + std::get<2>(default_symms)}; + } + + const ast::SymmetrySpec &spec = symm_spec.get(); + + auto [offset, length] = get_pos(spec, cache, begin); + + // Note: symmetry specifications are a separator (colon or dash) followed by + // an uppercase letter each (no whitespace allowed in-between) + Symmetry perm_symm = to_perm_symmetry(spec.perm_symm, offset + 1, begin, + std::get<0>(default_symms)); + BraKetSymmetry braket_symm = to_braket_symmetry( + spec.braket_symm, offset + 3, begin, std::get<1>(default_symms)); + ParticleSymmetry particle_symm = to_particle_symmetry( + spec.particle_symm, offset + 5, begin, std::get<2>(default_symms)); + + return {perm_symm, braket_symm, particle_symm}; +} + template ExprPtr ast_to_expr(const parse::ast::Product &product, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry); + const DefaultSymmetries &default_symms); template ExprPtr ast_to_expr(const parse::ast::Sum &sum, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry); + const DefaultSymmetries &default_symms); template ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { struct Transformer { std::reference_wrapper position_cache; std::reference_wrapper begin; - std::reference_wrapper default_symmetry; + std::reference_wrapper default_symms; ExprPtr operator()(const parse::ast::Product &product) const { return ast_to_expr(product, position_cache.get(), - begin.get(), default_symmetry); + begin.get(), default_symms.get()); } ExprPtr operator()(const parse::ast::Sum &sum) const { return ast_to_expr(sum, position_cache.get(), begin.get(), - default_symmetry); + default_symms.get()); } ExprPtr operator()(const parse::ast::Tensor &tensor) const { auto [braIndices, ketIndices, auxiliaries] = make_indices(tensor.indices, position_cache.get(), begin.get()); - auto [offset, length] = - get_pos(tensor, position_cache.get(), begin.get()); + auto [perm_symm, braket_symm, particle_symm] = + to_symmetries(tensor.symmetry, default_symms.get(), + position_cache.get(), begin.get()); return ex(tensor.name, bra(std::move(braIndices)), ket(std::move(ketIndices)), aux(std::move(auxiliaries)), - to_symmetry(tensor.symmetry, offset + length - 1, - begin.get(), default_symmetry)); + perm_symm, braket_symm, particle_symm); } ExprPtr operator()(const parse::ast::Variable &variable) const { @@ -211,7 +286,7 @@ ExprPtr ast_to_expr(const parse::ast::NullaryValue &value, return boost::apply_visitor( Transformer{std::ref(position_cache), std::ref(begin), - std::ref(default_symmetry)}, + std::ref(default_symms)}, value); } @@ -223,7 +298,7 @@ bool holds_alternative(const boost::variant &v) noexcept { template ExprPtr ast_to_expr(const parse::ast::Product &product, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { if (product.factors.empty()) { // This shouldn't happen assert(false); @@ -233,7 +308,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, if (product.factors.size() == 1) { return ast_to_expr(product.factors.front(), position_cache, begin, - default_symmetry); + default_symms); } std::vector factors; @@ -247,7 +322,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, position_cache, begin); } else { factors.push_back( - ast_to_expr(value, position_cache, begin, default_symmetry)); + ast_to_expr(value, position_cache, begin, default_symms)); } } @@ -267,13 +342,13 @@ ExprPtr ast_to_expr(const parse::ast::Product &product, template ExprPtr ast_to_expr(const parse::ast::Sum &sum, const PositionCache &position_cache, const Iterator &begin, - Symmetry default_symmetry) { + const DefaultSymmetries &default_symms) { if (sum.summands.empty()) { return {}; } if (sum.summands.size() == 1) { return ast_to_expr(sum.summands.front(), position_cache, begin, - default_symmetry); + default_symms); } std::vector summands; @@ -281,7 +356,7 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum, std::transform( sum.summands.begin(), sum.summands.end(), std::back_inserter(summands), [&](const parse::ast::Product &product) { - return ast_to_expr(product, position_cache, begin, default_symmetry); + return ast_to_expr(product, position_cache, begin, default_symms); }); return ex(std::move(summands)); diff --git a/SeQuant/core/parse/parse.cpp b/SeQuant/core/parse/parse.cpp index 7b6f5bed9..0a4c4e23e 100644 --- a/SeQuant/core/parse/parse.cpp +++ b/SeQuant/core/parse/parse.cpp @@ -47,6 +47,7 @@ struct ExprRule; struct IndexLabelRule; struct IndexRule; struct IndexGroupRule; +struct SymmetrySpecRule; // Types x3::rule number{"Number"}; @@ -63,6 +64,7 @@ x3::rule name{"Name"}; x3::rule index_label{"IndexLabel"}; x3::rule index{"Index"}; x3::rule index_groups{"IndexGroups"}; +x3::rule symmetry_spec{"SymmetrySpec"}; auto to_char_type = [](auto c) { return static_cast(c); @@ -105,8 +107,12 @@ auto index_groups_def = L"_{" > -(index % ',') > L"}^{" > -(index % ',') > L" | L"^{" > -(index % ',') > L"}_{" > -(index % ',') > L"}" >> x3::attr(noIndices) >> x3::attr(true) | '{' > -(index % ',') > -( ';' > -(index % ',')) > -(';' > -(index % ',')) > '}' >> x3::attr(false); +auto symmetry_spec_def= x3::lexeme[ + ':' >> x3::upper >> -('-' >> x3::upper) >> -('-' >> x3::upper) + ]; + auto tensor_def = x3::lexeme[ - name >> x3::skip[index_groups] >> -(':' >> x3::upper) + name >> x3::skip[index_groups] >> -(symmetry_spec) ]; auto nullary = number | tensor | variable; @@ -125,7 +131,7 @@ auto expr_def = -sum > x3::eoi; // clang-format on BOOST_SPIRIT_DEFINE(name, number, variable, index_label, index, index_groups, - tensor, product, sum, expr); + tensor, product, sum, expr, symmetry_spec); struct position_cache_tag; struct error_handler_tag; @@ -163,6 +169,7 @@ struct ExprRule : helpers::annotate_position, helpers::error_handler {}; struct IndexLabelRule : helpers::annotate_position, helpers::error_handler {}; struct IndexRule : helpers::annotate_position, helpers::error_handler {}; struct IndexGroupRule : helpers::annotate_position, helpers::error_handler {}; +struct SymmetrySpecRule : helpers::annotate_position, helpers::error_handler {}; } // namespace parse @@ -180,7 +187,8 @@ struct ErrorHandler { } }; -ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { +ExprPtr parse_expr(std::wstring_view input, Symmetry perm_symm, + BraKetSymmetry braket_symm, ParticleSymmetry particle_symm) { using iterator_type = decltype(input)::iterator; x3::position_cache> positions(input.begin(), input.end()); @@ -217,8 +225,10 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) { throw; } - return parse::transform::ast_to_expr(ast, positions, input.begin(), - default_symmetry); + return parse::transform::ast_to_expr( + ast, positions, input.begin(), + parse::transform::DefaultSymmetries{perm_symm, braket_symm, + particle_symm}); } } // namespace sequant diff --git a/tests/unit/test_parse.cpp b/tests/unit/test_parse.cpp index 574d9f1fb..12e884c53 100644 --- a/tests/unit/test_parse.cpp +++ b/tests/unit/test_parse.cpp @@ -170,11 +170,21 @@ TEST_CASE("parse_expr", "[parse]") { SECTION("Tensor with symmetry annotation") { auto expr1 = parse_expr(L"t{a1;i1}:A"); - auto expr2 = parse_expr(L"t{a1;i1}:S"); - auto expr3 = parse_expr(L"t{a1;i1}:N"); - REQUIRE(expr1->as().symmetry() == sequant::Symmetry::antisymm); - REQUIRE(expr2->as().symmetry() == sequant::Symmetry::symm); - REQUIRE(expr3->as().symmetry() == sequant::Symmetry::nonsymm); + auto expr2 = parse_expr(L"t{a1;i1}:S-C"); + auto expr3 = parse_expr(L"t{a1;i1}:N-S-N"); + + const Tensor& t1 = expr1->as(); + const Tensor& t2 = expr2->as(); + const Tensor& t3 = expr3->as(); + + REQUIRE(t1.symmetry() == Symmetry::antisymm); + + REQUIRE(t2.symmetry() == Symmetry::symm); + REQUIRE(t2.braket_symmetry() == BraKetSymmetry::conjugate); + + REQUIRE(t3.symmetry() == Symmetry::nonsymm); + REQUIRE(t3.braket_symmetry() == BraKetSymmetry::symm); + REQUIRE(t3.particle_symmetry() == ParticleSymmetry::nonsymm); } SECTION("Constant") {