Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ResultExpr #216

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ set(SeQuant_src
SeQuant/core/parse.hpp
SeQuant/core/ranges.hpp
SeQuant/core/rational.hpp
SeQuant/core/result_expr.cpp
SeQuant/core/result_expr.hpp
SeQuant/core/runtime.cpp
SeQuant/core/runtime.hpp
SeQuant/core/space.hpp
Expand Down
5 changes: 5 additions & 0 deletions SeQuant/core/parse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <SeQuant/core/attr.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/result_expr.hpp>
#include <SeQuant/core/tensor.hpp>

#include <stdexcept>
Expand Down Expand Up @@ -56,6 +57,9 @@ struct ParseError : std::runtime_error {
ExprPtr parse_expr(std::wstring_view raw,
Symmetry tensor_sym = Symmetry::nonsymm);

ResultExpr parse_result_expr(std::wstring_view raw,
Symmetry tensor_sym = Symmetry::nonsymm);

///
/// Get a parsable string from an expression.
///
Expand All @@ -71,6 +75,7 @@ ExprPtr parse_expr(std::wstring_view raw,
/// \param annot_sym Whether to add sequant::Symmetry annotation
/// to each Tensor string.
/// \return wstring of the expression.
std::wstring deparse(const ResultExpr &expr, bool annot_sym = true);
std::wstring deparse(const ExprPtr &expr, bool annot_sym = true);
std::wstring deparse(const Expr &expr, bool annot_sym = true);
std::wstring deparse(const Product &product, bool annot_sym);
Expand Down
13 changes: 13 additions & 0 deletions SeQuant/core/parse/ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cstdint>
#include <string>
#include <variant>
#include <vector>

namespace sequant::parse::ast {
Expand Down Expand Up @@ -115,6 +116,17 @@ Product::Product(std::vector<NullaryValue> factors)

Sum::Sum(std::vector<Product> summands) : summands(std::move(summands)) {}

struct ResultExpr : boost::spirit::x3::position_tagged {
std::variant<Tensor, Variable> lhs;
Sum rhs;

ResultExpr(Variable variable = {}, Sum expr = {})
: lhs(std::move(variable)), rhs(std::move(expr)) {}

ResultExpr(Tensor tensor, Sum expr)
: lhs(std::move(tensor)), rhs(std::move(expr)) {}
};

} // namespace sequant::parse::ast

BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::IndexLabel, label, id);
Expand All @@ -127,5 +139,6 @@ BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Tensor, name, indices, symmetry);

BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Product, factors);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::Sum, summands);
BOOST_FUSION_ADAPT_STRUCT(sequant::parse::ast::ResultExpr, lhs, rhs);

#endif // SEQUANT_CORE_PARSE_AST_HPP
101 changes: 62 additions & 39 deletions SeQuant/core/parse/ast_conversions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/parse.hpp>
#include <SeQuant/core/parse/ast.hpp>
#include <SeQuant/core/result_expr.hpp>
#include <SeQuant/core/space.hpp>
#include <SeQuant/core/tensor.hpp>
#include <SeQuant/core/utility/string.hpp>
Expand Down Expand Up @@ -156,54 +157,53 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum,
Symmetry default_symmetry);

template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::NullaryValue &value,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry) {
struct Transformer {
std::reference_wrapper<const PositionCache> position_cache;
std::reference_wrapper<const Iterator> begin;
std::reference_wrapper<Symmetry> default_symmetry;

ExprPtr operator()(const parse::ast::Product &product) const {
return ast_to_expr<PositionCache>(product, position_cache.get(),
begin.get(), default_symmetry);
}
struct Transformer {
std::reference_wrapper<const PositionCache> position_cache;
std::reference_wrapper<const Iterator> begin;
std::reference_wrapper<Symmetry> default_symmetry;

ExprPtr operator()(const parse::ast::Product &product) const {
return ast_to_expr<PositionCache>(product, position_cache.get(),
begin.get(), default_symmetry);
}

ExprPtr operator()(const parse::ast::Sum &sum) const {
return ast_to_expr<PositionCache>(sum, position_cache.get(), begin.get(),
default_symmetry);
}
ExprPtr operator()(const parse::ast::Sum &sum) const {
return ast_to_expr<PositionCache>(sum, position_cache.get(), begin.get(),
default_symmetry);
}

ExprPtr operator()(const parse::ast::Tensor &tensor) const {
auto [braIndices, ketIndices] =
make_indices(tensor.indices, position_cache.get(), begin.get());
ExprPtr operator()(const parse::ast::Tensor &tensor) const {
auto [braIndices, ketIndices] =
make_indices(tensor.indices, position_cache.get(), begin.get());

auto [offset, length] =
get_pos(tensor, position_cache.get(), begin.get());
auto [offset, length] = get_pos(tensor, position_cache.get(), begin.get());

return ex<Tensor>(tensor.name, std::move(braIndices),
std::move(ketIndices),
to_symmetry(tensor.symmetry, offset + length - 1,
begin.get(), default_symmetry));
}
return ex<Tensor>(tensor.name, std::move(braIndices), std::move(ketIndices),
to_symmetry(tensor.symmetry, offset + length - 1,
begin.get(), default_symmetry));
}

ExprPtr operator()(const parse::ast::Variable &variable) const {
if (variable.conjugated) {
return ex<Variable>(variable.name + L"^*");
} else {
return ex<Variable>(variable.name);
}
ExprPtr operator()(const parse::ast::Variable &variable) const {
if (variable.conjugated) {
return ex<Variable>(variable.name + L"^*");
} else {
return ex<Variable>(variable.name);
}
}

ExprPtr operator()(const parse::ast::Number &number) const {
return ex<Constant>(
to_constant(number, position_cache.get(), begin.get()));
}
};
ExprPtr operator()(const parse::ast::Number &number) const {
return ex<Constant>(to_constant(number, position_cache.get(), begin.get()));
}
};

template <typename PositionCache, typename Iterator>
ExprPtr ast_to_expr(const parse::ast::NullaryValue &value,
const PositionCache &position_cache, const Iterator &begin,
Symmetry default_symmetry) {
return boost::apply_visitor(
Transformer{std::ref(position_cache), std::ref(begin),
std::ref(default_symmetry)},
Transformer<PositionCache, Iterator>{std::ref(position_cache),
std::ref(begin),
std::ref(default_symmetry)},
value);
}

Expand Down Expand Up @@ -279,6 +279,29 @@ ExprPtr ast_to_expr(const parse::ast::Sum &sum,
return ex<Sum>(std::move(summands));
}

template <typename PositionCache, typename Iterator>
ResultExpr ast_to_result(const parse::ast::ResultExpr &result,
const PositionCache &position_cache,
const Iterator &begin, Symmetry default_symmetry) {
ExprPtr lhs = std::visit(
Transformer<PositionCache, Iterator>{std::ref(position_cache),
std::ref(begin),
std::ref(default_symmetry)},
result.lhs);
ExprPtr rhs =
ast_to_expr(result.rhs, position_cache, begin, default_symmetry);

if (lhs.is<Tensor>()) {
return {std::move(lhs.as<Tensor>()), std::move(rhs)};
} else if (lhs.is<Variable>()) {
return {std::move(lhs.as<Variable>()), std::move(rhs)};
} else {
auto [offset, length] = get_pos(result.lhs, position_cache, begin);
throw ParseError(offset, length,
"LHS of a ResultExpr must be a Tensor or a Variable");
}
}

} // namespace sequant::parse::transform

#endif // SEQUANT_CORE_PARSE_AST_CONVERSIONS_HPP
24 changes: 24 additions & 0 deletions SeQuant/core/parse/deparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <SeQuant/core/complex.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/result_expr.hpp>
#include <SeQuant/core/tensor.hpp>

#include <range/v3/all.hpp>
Expand Down Expand Up @@ -111,6 +112,29 @@ std::wstring deparse(const ExprPtr& expr, bool annot_sym) {
throw std::runtime_error("Unsupported expr type for deparse!");
}

std::wstring deparse(const ResultExpr& result, bool annot_sym) {
std::wstring deparsed;
if (result.has_label()) {
deparsed += result.label();
} else {
deparsed += L"?";
}

if (!result.bra().empty() || !result.ket().empty()) {
deparsed += L"{";
deparsed += details::deparse_indices(result.bra());
deparsed += L";";
deparsed += details::deparse_indices(result.ket());
deparsed += L"}";

if (annot_sym) {
deparsed += L":" + details::deparse_sym(result.symmetry());
}
}

return deparsed + L" = " + deparse(result.expression(), annot_sym);
}

std::wstring deparse(const Index& index) {
std::wstring deparsed(index.label());

Expand Down
65 changes: 49 additions & 16 deletions SeQuant/core/parse/parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
// Created by Robert Adam on 2023-09-20
//

#include <SeQuant/core/parse.hpp>
#include <SeQuant/core/parse/ast.hpp>
#include <SeQuant/core/parse/ast_conversions.hpp>
#include <SeQuant/core/parse/semantic_actions.hpp>

#include <SeQuant/core/attr.hpp>
#include <SeQuant/core/container.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/parse.hpp>
#include <SeQuant/core/parse/ast.hpp>
#include <SeQuant/core/parse/ast_conversions.hpp>
#include <SeQuant/core/parse/semantic_actions.hpp>
#include <SeQuant/core/result_expr.hpp>
#include <SeQuant/core/space.hpp>
#include <SeQuant/core/tensor.hpp>

Expand All @@ -22,11 +22,18 @@

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

namespace sequant {

Expand All @@ -44,6 +51,7 @@ struct TensorRule;
struct ProductRule;
struct SumRule;
struct ExprRule;
struct ResultExprRule;
struct IndexLabelRule;
struct IndexRule;
struct IndexGroupRule;
Expand All @@ -57,6 +65,7 @@ x3::rule<TensorRule, ast::Tensor> tensor{"Tensor"};
x3::rule<ProductRule, ast::Product> product{"Product"};
x3::rule<SumRule, ast::Sum> sum{"Sum"};
x3::rule<ExprRule, ast::Sum> expr{"Expression"};
x3::rule<ResultExprRule, ast::ResultExpr> resultExpr{"ResultExpr"};

// Auxiliaries
x3::rule<struct NameRule, std::wstring> name{"Name"};
Expand Down Expand Up @@ -121,10 +130,12 @@ auto addend = (('+' >> x3::attr(1) | '-' >> x3::attr(-1)) > product)[a
auto sum_def = first_addend >> *addend;
auto expr_def = -sum > x3::eoi;
auto resultExpr_def = (tensor | variable) > (L'=' | x3::lit(L"->")) >> expr;
// clang-format on
BOOST_SPIRIT_DEFINE(name, number, variable, index_label, index, index_groups,
tensor, product, sum, expr);
tensor, product, sum, expr, resultExpr);
struct position_cache_tag;
struct error_handler_tag;
Expand Down Expand Up @@ -159,6 +170,7 @@ struct TensorRule : helpers::annotate_position, helpers::error_handler {};
struct ProductRule : helpers::annotate_position, helpers::error_handler {};
struct SumRule : helpers::annotate_position, helpers::error_handler {};
struct ExprRule : helpers::annotate_position, helpers::error_handler {};
struct ResultRule : helpers::annotate_position, helpers::error_handler {};
struct IndexLabelRule : helpers::annotate_position, helpers::error_handler {};
struct IndexRule : helpers::annotate_position, helpers::error_handler {};
struct IndexGroupRule : helpers::annotate_position, helpers::error_handler {};
Expand All @@ -179,34 +191,34 @@ struct ErrorHandler {
}
};
ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) {
template <typename AST, typename StartRule, typename PositionCache>
AST do_parse(const StartRule &start, std::wstring_view input,
PositionCache &positions) {
using iterator_type = decltype(input)::iterator;
x3::position_cache<std::vector<iterator_type>> positions(input.begin(),
input.end());
ErrorHandler<iterator_type> error_handler(input.begin());
parse::ast::Sum ast;
AST ast;
const auto parser = x3::with<parse::error_handler_tag>(
std::ref(error_handler))[x3::with<parse::position_cache_tag>(
std::ref(positions))[parse::expr]];
std::ref(positions))[start]];
auto start = input.begin();
auto begin = input.begin();
try {
bool success =
x3::phrase_parse(start, input.end(), parser, x3::unicode::space, ast);
x3::phrase_parse(begin, input.end(), parser, x3::unicode::space, ast);
if (!success) {
// Normally, this shouldn't happen as any error should itself throw a
// ParseError already
throw ParseError(0, input.size(),
"Parsing was unsuccessful for an unknown reason");
}
if (start != input.end()) {
if (begin != input.end()) {
// This should also not happen as the parser requires matching EOI
throw ParseError(std::distance(input.begin(), start),
std::distance(start, input.end()),
throw ParseError(std::distance(input.begin(), begin),
std::distance(begin, input.end()),
"Couldn't parse the entire input");
}
} catch (const boost::spirit::x3::expectation_failure<iterator_type> &e) {
Expand All @@ -216,6 +228,27 @@ ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) {
throw;
}
return ast;
}
ResultExpr parse_result_expr(std::wstring_view input,
Symmetry default_symmetry) {
using iterator_type = decltype(input)::iterator;
x3::position_cache<std::vector<iterator_type>> positions(input.begin(),
input.end());
auto ast =
do_parse<parse::ast::ResultExpr>(parse::resultExpr, input, positions);
return parse::transform::ast_to_result(ast, positions, input.begin(),
default_symmetry);
}
ExprPtr parse_expr(std::wstring_view input, Symmetry default_symmetry) {
using iterator_type = decltype(input)::iterator;
x3::position_cache<std::vector<iterator_type>> positions(input.begin(),
input.end());
auto ast = do_parse<parse::ast::Sum>(parse::expr, input, positions);
return parse::transform::ast_to_expr(ast, positions, input.begin(),
default_symmetry);
}
Expand Down
Loading
Loading