diff --git a/planning/unified/plugin/up_aries/solver.py b/planning/unified/plugin/up_aries/solver.py index c4b4e17e..b9fc561d 100644 --- a/planning/unified/plugin/up_aries/solver.py +++ b/planning/unified/plugin/up_aries/solver.py @@ -8,7 +8,7 @@ import time from fractions import Fraction from pathlib import Path -from typing import IO, Callable, Optional, Iterator +from typing import IO, Callable, Optional, Iterator, Tuple import grpc import unified_planning as up @@ -292,7 +292,7 @@ class Aries(AriesEngine, mixins.OneshotPlannerMixin, mixins.AnytimePlannerMixin) def name(self) -> str: return "aries" - def _solve( + def _prepare_solving( self, problem: "up.model.AbstractProblem", heuristic: Optional[ @@ -300,7 +300,7 @@ def _solve( ] = None, timeout: Optional[float] = None, output_stream: Optional[IO[str]] = None, - ) -> "up.engines.results.PlanGenerationResult": + ) -> Tuple["_Server", proto.PlanRequest]: # Assert that the problem is a valid problem assert isinstance(problem, up.model.AbstractProblem) if heuristic is not None: @@ -312,9 +312,15 @@ def _solve( # Note: when the `server` object is garbage collected, the process will be killed server = _Server(self._executable, output_stream=output_stream) proto_problem = self._writer.convert(problem) - req = proto.PlanRequest(problem=proto_problem, timeout=timeout) - response = server.planner.planOneShot(req) + + return server, req + + def _process_response( + self, + response: proto.PlanGenerationResult, + problem: "up.model.AbstractProblem", + ) -> "up.engines.results.PlanGenerationResult": response = self._reader.convert(response, problem) # if we have a time triggered plan and a recent version of the UP that support setting epsilon-separation, @@ -330,24 +336,29 @@ def _solve( return response + def _solve( + self, + problem: "up.model.AbstractProblem", + heuristic: Optional[ + Callable[["up.model.state.ROState"], Optional[float]] + ] = None, + timeout: Optional[float] = None, + output_stream: Optional[IO[str]] = None, + ) -> "up.engines.results.PlanGenerationResult": + server, req = self._prepare_solving(problem, heuristic, timeout, output_stream) + response = server.planner.planOneShot(req) + return self._process_response(response, problem) + def _get_solutions( self, problem: "up.model.AbstractProblem", timeout: Optional[float] = None, output_stream: Optional[IO[str]] = None, ) -> Iterator["up.engines.results.PlanGenerationResult"]: - # Assert that the problem is a valid problem - assert isinstance(problem, up.model.Problem) - - # start a gRPC server in its own process - # Note: when the `server` object is garbage collected, the process will be killed - server = _Server(self._executable, output_stream=output_stream) - proto_problem = self._writer.convert(problem) - - req = proto.PlanRequest(problem=proto_problem, timeout=timeout) + server, req = self._prepare_solving(problem, None, timeout, output_stream) stream = server.planner.planAnytime(req) for response in stream: - response = self._reader.convert(response, problem) + response = self._process_response(response, problem) yield response # The parallel solver implementation in aries are such that intermediate answer might arrive late if response.status != PlanGenerationResultStatus.INTERMEDIATE: