Skip to content

Commit

Permalink
Split out trigger_service, implement automerge setting
Browse files Browse the repository at this point in the history
  • Loading branch information
irgolic committed Oct 16, 2023
1 parent 8fddadb commit 3267121
Show file tree
Hide file tree
Showing 14 changed files with 1,112 additions and 125 deletions.
1 change: 1 addition & 0 deletions .autopr/triggers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ triggers:
run: summarize_pr
- branch_name: main
run: generate_readme_summaries
automerge: true
parameters:
FILE_SUMMARY_PROMPT: |
Write an executive summary of this file, intended for someone seeing it for the first time.
Expand Down
10 changes: 7 additions & 3 deletions autopr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .services.commit_service import CommitService
from .services.platform_service import PlatformService
from .services.publish_service import PublishService

from .services.trigger_service import TriggerService

from .services.workflow_service import WorkflowService
from .triggers import get_all_triggers
Expand Down Expand Up @@ -75,15 +75,19 @@ def __init__(self):
repo_path=self.get_repo_path(),
)
self.workflow_service = WorkflowService(
triggers=triggers,
workflows=workflows,
action_service=action_service,
publish_service=self.publish_service,
)
self.trigger_service = TriggerService(
triggers=triggers,
publish_service=self.publish_service,
workflow_service=self.workflow_service,
)

async def run(self):
# Run the triggers
return await self.workflow_service.trigger_event(self.event)
return await self.trigger_service.trigger_event(self.event)

def get_repo_path(self):
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions autopr/models/config/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def get_all_executable_ids():
class TriggerModel(ContextModel):
type: str
run: StrictExecutable = Field() # pyright: ignore[reportGeneralTypeIssues]
automerge: bool = False
parameters: Optional[dict[str, Any]] = Field(default=None)

def get_context_for_event(self, event: EventUnion) -> Optional[ContextDict]:
Expand Down
47 changes: 47 additions & 0 deletions autopr/services/platform_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,21 @@ async def create_pr(
"""
raise NotImplementedError

async def merge_pr(
self,
pr_number: int,
merge_method: str = "squash",
):
"""
Merge the pull request.
Parameters
----------
pr_number: int
The PR number
"""
raise NotImplementedError

async def update_pr_body(self, pr_number: int, body: str):
"""
Update the body of the pull request.
Expand Down Expand Up @@ -379,6 +394,31 @@ async def create_pr(

return pr_number, comment_ids

async def merge_pr(
self,
pr_number: int,
merge_method: str = "squash",
):
url = f'https://api.github.com/repos/{self.owner}/{self.repo_name}/pulls/{pr_number}/merge'
headers = self._get_headers()
data = {
'commit_message': 'Merged automatically by AutoPR',
}

async with ClientSession() as session:
async with session.put(url, json=data, headers=headers) as response:
if response.status != 200:
await self._log_failed_request(
'Failed to merge pull request',
request_url=url,
request_headers=headers,
request_body=data,
response=response,
)
raise RuntimeError('Failed to merge pull request')

self.log.debug('Pull request merged successfully')

async def _patch_pr(self, pr_number: int, data: dict[str, Any]):
url = f'https://api.github.com/repos/{self.owner}/{self.repo_name}/pulls/{pr_number}'
headers = self._get_headers()
Expand Down Expand Up @@ -724,6 +764,13 @@ async def create_pr(
) -> tuple[Optional[int], list[Union[str, Type[PlatformService.PRBodySentinel]]]]:
return 1, [PlatformService.PRBodySentinel]

async def merge_pr(
self,
pr_number: int,
merge_method: str = "squash",
):
pass

async def update_pr_title(self, pr_number: int, title: str):
pass

Expand Down
11 changes: 11 additions & 0 deletions autopr/services/publish_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,17 @@ async def end_section(

await self.update()

async def merge(self):
"""
Merge the pull request.
"""
if self.root_publish_service is not None:
return await self.root_publish_service.merge()
if self.pr_number is None:
self.log.warning("PR merge requested, but does not exist")
return
return await self.platform_service.merge_pr(self.pr_number)

def _contains_last_code_block(self, parent: UpdateSection) -> bool:
for section in reversed(parent.updates):
if isinstance(section, CodeBlock):
Expand Down
150 changes: 150 additions & 0 deletions autopr/services/trigger_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import asyncio
from typing import Coroutine, Any

from autopr.log_config import get_logger
from autopr.models.config.elements import ActionConfig, WorkflowInvocation, IterableWorkflowInvocation, ContextAction
from autopr.models.config.entrypoints import Trigger
from autopr.models.events import EventUnion
from autopr.models.executable import Executable, ContextDict
from autopr.services.platform_service import PlatformService
from autopr.services.publish_service import PublishService
from autopr.services.utils import truncate_strings, format_for_publishing
from autopr.services.workflow_service import WorkflowService


class TriggerService:
def __init__(
self,
triggers: list[Trigger],
publish_service: PublishService,
workflow_service: WorkflowService,
):
self.triggers = triggers
self.publish_service = publish_service
self.workflow_service = workflow_service

print("Loaded triggers:")
for t in self.triggers:
print(t.json(indent=2))

self.log = get_logger(service="trigger")

def _get_name_for_executable(self, executable: Executable) -> str:
if isinstance(executable, str):
return executable
if isinstance(executable, ActionConfig):
return executable.action
if isinstance(executable, WorkflowInvocation) or isinstance(executable, IterableWorkflowInvocation):
return executable.workflow
if isinstance(executable, ContextAction):
raise RuntimeError("Meaningless trigger! Whatchu tryina do :)")
raise ValueError(f"Unknown executable type {executable}")

def _get_triggers_and_contexts_for_event(self, event: EventUnion) -> list[tuple[Trigger, ContextDict]]:
# Gather all triggers that match the event
triggers_and_context: list[tuple[Trigger, ContextDict]] = []
for trigger in self.triggers:
context = trigger.get_context_for_event(event)
if context is None:
continue
triggers_and_context.append((trigger, context))
return triggers_and_context

async def _get_trigger_coros_for_event(
self,
triggers_and_context: list[tuple[Trigger, ContextDict]],
) -> list[Coroutine[Any, Any, ContextDict]]:
# Build coroutines for each trigger
if not triggers_and_context:
return []
if len(triggers_and_context) == 1:
self.publish_service.title = f"AutoPR: {self._get_name_for_executable(triggers_and_context[0][0].run)}"
return [
self.handle_trigger(
trigger,
context,
publish_service=self.publish_service,
)
for trigger, context in triggers_and_context
]
trigger_titles = [self._get_name_for_executable(trigger.run) for trigger, context in triggers_and_context]
self.publish_service.title = f"AutoPR: {', '.join(truncate_strings(trigger_titles))}"
return [
self.handle_trigger(
trigger,
context,
publish_service=(await self.publish_service.create_child(title=title)),
)
for i, ((trigger, context), title) in enumerate(zip(triggers_and_context, trigger_titles))
]

async def trigger_event(
self,
event: EventUnion,
):
triggers_and_contexts = self._get_triggers_and_contexts_for_event(event)
trigger_coros = await self._get_trigger_coros_for_event(triggers_and_contexts)
if not trigger_coros:
print(event)
self.log.debug(f"No triggers for event")
return

results = await asyncio.gather(*trigger_coros)

exceptions = []
for r in results:
if isinstance(r, Exception):
self.log.error("Error in trigger", exc_info=r)
exceptions.append(r)

if exceptions:
await self.publish_service.finalize(False, exceptions)
else:
await self.publish_service.finalize(True)
# TODO split out multiple triggered workflows into separate PRs,
# so that automerge can be evaluated separately for each
if any(trigger.automerge for trigger, _ in triggers_and_contexts):
await self.publish_service.merge()

return results

async def handle_trigger(
self,
trigger: Trigger,
context: ContextDict,
publish_service: PublishService,
) -> ContextDict:
await publish_service.publish_code_block(
heading="📣 Trigger",
code=format_for_publishing(trigger),
language="json",
)
await publish_service.publish_code_block(
heading="🎬 Starting context",
code=format_for_publishing(context),
language="json",
)

executable = trigger.run

# Add params
if trigger.parameters:
context["__params__"] = trigger.parameters

try:
context = await self.workflow_service.execute(
executable,
context,
publish_service=publish_service,
)
except Exception as e:
self.log.error("Error while executing", executable=executable, exc_info=e)
raise

await publish_service.publish_code_block(
heading="🏁 Final context",
code=format_for_publishing(context),
language="json",
)

return context
Loading

0 comments on commit 3267121

Please sign in to comment.