Skip to content

Commit

Permalink
Parse: support full symmetry specification
Browse files Browse the repository at this point in the history
  • Loading branch information
Krzmbrzl committed Jan 16, 2025
1 parent a274e52 commit 91f5a5c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 36 deletions.
24 changes: 22 additions & 2 deletions SeQuant/core/parse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,25 @@ struct ParseError : std::runtime_error {
/// '1.0/2.0 * t{i1;a1} * f{i1; a1}' same as above
/// 't{i1,i2; a1<i1,i2>, a2<i1,i2>}' a tensor having indices with proto indices.
/// a1<i1,i2> is an index with i1 and i2 as proto-indices.
/// \param tensor_sym The symmetry of all atomic tensors in the
/// Every tensor may optionally be annoted with index symmetry specifications. The general syntax is
/// <tensorSpec> [:<perm symm> [-<braket symm> [-<particle symm>]]]
/// (no whitespace is allowed at this place). Examples are
/// 't{i1;i2}:A', 't{i1;i2}:A-S', 't{i1;i2}:N-C-S'
/// Possible values for <perm symm> are
/// - 'A' for antisymmetry (sequant::Symmetry::antisymm)
/// - 'S' for symmetric (sequant::Symmetry::symm)
/// - 'N' for non-symmetric (sequant::Symmetry::nonsymm)
/// Possible values for <braket symm> are
/// - 'C' for antisymmetry (sequant::BraKetSymmetry::conjugate)
/// - 'S' for symmetric (sequant::BraKetSymmetry::symm)
/// - 'N' for non-symmetric (sequant::BraKetSymmetry::nonsymm)
/// Possible values for <particle symm> are
/// - 'S' for symmetric (sequant::ParticleSymmetry::symm)
/// - 'N' for non-symmetric (sequant::ParticleSymmetry::nonsymm)
/// \param perm_symm Default index permutation symmetry to be used if tensors don't specify a permutation
/// symmetry explicitly.
/// \param braket_symm Default BraKet symmetry to be used if tensors don't specify a BraKet symmetry explicitly.
/// \param particle_symm Default particle symmetry to be used if tensors don't specify a particle symmetry explicitly.
/// @c raw expression. Explicit tensor symmetry can
/// be annotated in the expression itself. In that case, the
/// annotated symmetry will be used.
Expand All @@ -54,7 +72,9 @@ struct ParseError : std::runtime_error {
/// \return SeQuant expression.
// clang-format on
ExprPtr parse_expr(std::wstring_view raw,
Symmetry tensor_sym = Symmetry::nonsymm);
Symmetry perm_symm = Symmetry::nonsymm,
BraKetSymmetry braket_symm = BraKetSymmetry::nonsymm,
ParticleSymmetry particle_symm = ParticleSymmetry::symm);

///
/// Get a parsable string from an expression.
Expand Down
17 changes: 13 additions & 4 deletions SeQuant/core/parse/ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#define BOOST_SPIRIT_X3_UNICODE
#include <boost/fusion/include/adapt_struct.hpp>
#include <boost/optional.hpp>
#include <boost/spirit/home/x3.hpp>
#include <boost/spirit/home/x3/support/ast/position_tagged.hpp>
#include <boost/variant.hpp>
Expand Down Expand Up @@ -66,17 +67,23 @@ struct IndexGroups : boost::spirit::x3::position_tagged {
reverse_bra_ket(reverse_bra_ket) {}
};

struct SymmetrySpec : boost::spirit::x3::position_tagged {
static constexpr char unspecified = '\0';
char perm_symm = unspecified;
char braket_symm = unspecified;
char particle_symm = unspecified;
};

struct Tensor : boost::spirit::x3::position_tagged {
static constexpr char unspecified_symmetry = '\0';
std::wstring name;
IndexGroups indices;
char symmetry;
boost::optional<SymmetrySpec> symmetry;

Tensor(std::wstring name = {}, IndexGroups indices = {},
char symmetry = unspecified_symmetry)
boost::optional<SymmetrySpec> symmetry = {})
: name(std::move(name)),
indices(std::move(indices)),
symmetry(symmetry) {}
symmetry(std::move(symmetry)) {}
};

struct Product;
Expand Down Expand Up @@ -125,6 +132,8 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Number, numerator, denominator);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Variable, name, conjugated);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::IndexGroups, bra, ket,
auxiliaries, reverse_bra_ket);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::SymmetrySpec, perm_symm,
braket_symm, particle_symm);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Tensor, name, indices, symmetry);

BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Product, factors);
Expand Down
115 changes: 95 additions & 20 deletions SeQuant/core/parse/ast_conversions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@

namespace sequant::parse::transform {

using DefaultSymmetries =
std::tuple<Symmetry, BraKetSymmetry, ParticleSymmetry>;

template <typename AST, typename PositionCache, typename Iterator>
std::tuple<std::size_t, std::size_t> get_pos(const AST &ast,
const PositionCache &cache,
Expand Down Expand Up @@ -114,9 +117,9 @@ make_indices(const parse::ast::IndexGroups &groups,
}

template <typename Iterator>
Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin,
Symmetry default_symmetry) {
if (c == parse::ast::Tensor::unspecified_symmetry) {
Symmetry to_perm_symmetry(char c, std::size_t offset, const Iterator &begin,
Symmetry default_symmetry) {
if (c == parse::ast::SymmetrySpec::unspecified) {
return default_symmetry;
}

Expand All @@ -136,6 +139,52 @@ Symmetry to_symmetry(char c, std::size_t offset, const Iterator &begin,
std::string("Invalid symmetry specifier '") + c + "'");
}

template <typename Iterator>
BraKetSymmetry to_braket_symmetry(char c, std::size_t offset,
const Iterator &begin,
BraKetSymmetry default_symmetry) {
if (c == parse::ast::SymmetrySpec::unspecified) {
return default_symmetry;
}

switch (c) {
case 'C':
case 'c':
return BraKetSymmetry::conjugate;
case 'S':
case 's':
return BraKetSymmetry::symm;
case 'N':
case 'n':
return BraKetSymmetry::nonsymm;
}

throw ParseError(
offset, 1, std::string("Invalid BraKet symmetry specifier '") + c + "'");
}

template <typename Iterator>
ParticleSymmetry to_particle_symmetry(char c, std::size_t offset,
const Iterator &begin,
ParticleSymmetry default_symmetry) {
if (c == parse::ast::SymmetrySpec::unspecified) {
return default_symmetry;
}

switch (c) {
case 'S':
case 's':
return ParticleSymmetry::symm;
case 'N':
case 'n':
return ParticleSymmetry::nonsymm;
}

throw ParseError(
offset, 1,
std::string("Invalid particle symmetry specifier '") + c + "'");
}

template <typename PositionCache, typename Iterator>
Constant to_constant(const parse::ast::Number &number,
const PositionCache &position_cache,
Expand All @@ -152,45 +201,71 @@ Constant to_constant(const parse::ast::Number &number,
}
}

template <typename PositionCache, typename Iterator>
std::tuple<Symmetry, BraKetSymmetry, ParticleSymmetry> to_symmetries(
const boost::optional<parse::ast::SymmetrySpec> &symm_spec,
const DefaultSymmetries &default_symms, const PositionCache &cache,
const Iterator &begin) {
if (!symm_spec.has_value()) {
return {std::get<0>(default_symms), std::get<1>(default_symms),
std::get<2>(default_symms)};
}

const ast::SymmetrySpec &spec = symm_spec.get();

auto [offset, length] = get_pos(spec, cache, begin);

// Note: symmetry specifications are a separator (colon or dash) followed by
// an uppercase letter each (no whitespace allowed in-between)
Symmetry perm_symm = to_perm_symmetry(spec.perm_symm, offset + 1, begin,
std::get<0>(default_symms));
BraKetSymmetry braket_symm = to_braket_symmetry(
spec.braket_symm, offset + 3, begin, std::get<1>(default_symms));
ParticleSymmetry particle_symm = to_particle_symmetry(
spec.particle_symm, offset + 5, begin, std::get<2>(default_symms));

return {perm_symm, braket_symm, particle_symm};
}

template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::Product &product,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry);
const DefaultSymmetries &default_symms);
template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::Sum &sum,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry);
const DefaultSymmetries &default_symms);

template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::NullaryValue &value,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry) {
const DefaultSymmetries &default_symms) {
struct Transformer {
std::reference_wrapper<const PositionCache> position_cache;
std::reference_wrapper<const Iterator> begin;
std::reference_wrapper<Symmetry> default_symmetry;
std::reference_wrapper<const DefaultSymmetries> default_symms;

ExprPtr operator()(const parse::ast::Product &product) const {
return ast_to_expr<PositionCache>(product, position_cache.get(),
begin.get(), default_symmetry);
begin.get(), default_symms.get());
}

ExprPtr operator()(const parse::ast::Sum &sum) const {
return ast_to_expr<PositionCache>(sum, position_cache.get(), begin.get(),
default_symmetry);
default_symms.get());
}

ExprPtr operator()(const parse::ast::Tensor &tensor) const {
auto [braIndices, ketIndices, auxiliaries] =
make_indices(tensor.indices, position_cache.get(), begin.get());

auto [offset, length] =
get_pos(tensor, position_cache.get(), begin.get());
auto [perm_symm, braket_symm, particle_symm] =
to_symmetries(tensor.symmetry, default_symms.get(),
position_cache.get(), begin.get());

return ex<Tensor>(tensor.name, bra(std::move(braIndices)),
ket(std::move(ketIndices)), aux(std::move(auxiliaries)),
to_symmetry(tensor.symmetry, offset + length - 1,
begin.get(), default_symmetry));
perm_symm, braket_symm, particle_symm);
}

ExprPtr operator()(const parse::ast::Variable &variable) const {
Expand All @@ -211,7 +286,7 @@ ExprPtr ast_to_expr(const parse::ast::NullaryValue &value,

return boost::apply_visitor(
Transformer{std::ref(position_cache), std::ref(begin),
std::ref(default_symmetry)},
std::ref(default_symms)},
value);
}

Expand All @@ -223,7 +298,7 @@ bool holds_alternative(const boost::variant<Ts...> &v) noexcept {
template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::Product &product,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry) {
const DefaultSymmetries &default_symms) {
if (product.factors.empty()) {
// This shouldn't happen
assert(false);
Expand All @@ -233,7 +308,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product,

if (product.factors.size() == 1) {
return ast_to_expr(product.factors.front(), position_cache, begin,
default_symmetry);
default_symms);
}

std::vector<ExprPtr> factors;
Expand All @@ -247,7 +322,7 @@ ExprPtr ast_to_expr(const parse::ast::Product &product,
position_cache, begin);
} else {
factors.push_back(
ast_to_expr(value, position_cache, begin, default_symmetry));
ast_to_expr(value, position_cache, begin, default_symms));
}
}

Expand All @@ -267,21 +342,21 @@ ExprPtr ast_to_expr(const parse::ast::Product &product,
template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::Sum &sum,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry) {
const DefaultSymmetries &default_symms) {
if (sum.summands.empty()) {
return {};
}
if (sum.summands.size() == 1) {
return ast_to_expr(sum.summands.front(), position_cache, begin,
default_symmetry);
default_symms);
}

std::vector<ExprPtr> summands;
summands.reserve(sum.summands.size());
std::transform(
sum.summands.begin(), sum.summands.end(), std::back_inserter(summands),
[&](const parse::ast::Product &product) {
return ast_to_expr(product, position_cache, begin, default_symmetry);
return ast_to_expr(product, position_cache, begin, default_symms);
});

return ex<Sum>(std::move(summands));
Expand Down
20 changes: 15 additions & 5 deletions SeQuant/core/parse/parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct ExprRule;
struct IndexLabelRule;
struct IndexRule;
struct IndexGroupRule;
struct SymmetrySpecRule;

// Types
x3::rule<NumberRule, ast::Number> number{"Number"};
Expand All @@ -63,6 +64,7 @@ x3::rule<struct NameRule, std::wstring> name{"Name"};
x3::rule<IndexLabelRule, ast::IndexLabel> index_label{"IndexLabel"};
x3::rule<IndexRule, ast::Index> index{"Index"};
x3::rule<IndexGroupRule, ast::IndexGroups> index_groups{"IndexGroups"};
x3::rule<SymmetrySpecRule, ast::SymmetrySpec> symmetry_spec{"SymmetrySpec"};

auto to_char_type = [](auto c) {
return static_cast<x3::unicode::char_type::char_type>(c);
Expand Down Expand Up @@ -105,8 +107,12 @@ auto index_groups_def = L"_{" > -(index % ',') > L"}^{" > -(index % ',') > L"
| L"^{" > -(index % ',') > L"}_{" > -(index % ',') > L"}" >> x3::attr(noIndices) >> x3::attr(true)
| '{' > -(index % ',') > -( ';' > -(index % ',')) > -(';' > -(index % ',')) > '}' >> x3::attr(false);
auto symmetry_spec_def= x3::lexeme[
':' >> x3::upper >> -('-' >> x3::upper) >> -('-' >> x3::upper)
];
auto tensor_def = x3::lexeme[
name >> x3::skip[index_groups] >> -(':' >> x3::upper)
name >> x3::skip[index_groups] >> -(symmetry_spec)
];
auto nullary = number | tensor | variable;
Expand All @@ -125,7 +131,7 @@ auto expr_def = -sum > x3::eoi;
// clang-format on
BOOST_SPIRIT_DEFINE(name, number, variable, index_label, index, index_groups,
tensor, product, sum, expr);
tensor, product, sum, expr, symmetry_spec);
struct position_cache_tag;
struct error_handler_tag;
Expand Down Expand Up @@ -163,6 +169,7 @@ struct ExprRule : helpers::annotate_position, helpers::error_handler {};
struct IndexLabelRule : helpers::annotate_position, helpers::error_handler {};
struct IndexRule : helpers::annotate_position, helpers::error_handler {};
struct IndexGroupRule : helpers::annotate_position, helpers::error_handler {};
struct SymmetrySpecRule : helpers::annotate_position, helpers::error_handler {};
} // namespace parse
Expand All @@ -180,7 +187,8 @@ struct ErrorHandler {
}
};
ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) {
ExprPtr parse_expr(std::wstring_view input, Symmetry perm_symm,
BraKetSymmetry braket_symm, ParticleSymmetry particle_symm) {
using iterator_type = decltype(input)::iterator;
x3::position_cache<std::vector<iterator_type>> positions(input.begin(),
input.end());
Expand Down Expand Up @@ -217,8 +225,10 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) {
throw;
}
return parse::transform::ast_to_expr(ast, positions, input.begin(),
default_symmetry);
return parse::transform::ast_to_expr(
ast, positions, input.begin(),
parse::transform::DefaultSymmetries{perm_symm, braket_symm,
particle_symm});
}
} // namespace sequant
20 changes: 15 additions & 5 deletions tests/unit/test_parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,21 @@ TEST_CASE("parse_expr", "[parse]") {

SECTION("Tensor with symmetry annotation") {
auto expr1 = parse_expr(L"t{a1;i1}:A");
auto expr2 = parse_expr(L"t{a1;i1}:S");
auto expr3 = parse_expr(L"t{a1;i1}:N");
REQUIRE(expr1->as<Tensor>().symmetry() == sequant::Symmetry::antisymm);
REQUIRE(expr2->as<Tensor>().symmetry() == sequant::Symmetry::symm);
REQUIRE(expr3->as<Tensor>().symmetry() == sequant::Symmetry::nonsymm);
auto expr2 = parse_expr(L"t{a1;i1}:S-C");
auto expr3 = parse_expr(L"t{a1;i1}:N-S-N");

const Tensor& t1 = expr1->as<Tensor>();
const Tensor& t2 = expr2->as<Tensor>();
const Tensor& t3 = expr3->as<Tensor>();

REQUIRE(t1.symmetry() == Symmetry::antisymm);

REQUIRE(t2.symmetry() == Symmetry::symm);
REQUIRE(t2.braket_symmetry() == BraKetSymmetry::conjugate);

REQUIRE(t3.symmetry() == Symmetry::nonsymm);
REQUIRE(t3.braket_symmetry() == BraKetSymmetry::symm);
REQUIRE(t3.particle_symmetry() == ParticleSymmetry::nonsymm);
}

SECTION("Constant") {
Expand Down

0 comments on commit 91f5a5c

Please sign in to comment.