Add additional concurrency helpers `run_tasks` and `run_tasks_in_background`

PiperOrigin-RevId: 672486510
Change-Id: I9e096db377d071bdafb2287e119f76f75a57dab0
jagapiou authored and copybara-github committed Sep 9, 2024
1 parent 3dd54ca commit 3ea3ca3
Showing 6 changed files with 270 additions and 37 deletions.
@@ -16,6 +16,7 @@
 
 from collections.abc import Callable, Sequence
 import datetime
+import functools
 
 from concordia.components.agent import action_spec_ignored
 from concordia.components.agent import memory_component
@@ -125,9 +126,12 @@ def _query_memory(self, query: str) -> str:
 
   def _make_pre_act_value(self) -> str:
     agent_name = self.get_entity().name
-    results = concurrency.run_parallel(self._query_memory, self._queries)
+    results = concurrency.run_tasks({
+        query: functools.partial(self._query_memory, query)
+        for query in self._queries
+    })
     results_str = '\n'.join(
-        [f'{query}: {result}' for query, result in zip(self._queries, results)]
+        [f'{query}: {result}' for query, result in results.items()]
     )
     if self._summarization_question is not None:
       prompt = self._summarization_question.format(
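The new pattern keys each task by its query string, so results come back as a mapping rather than a list that must be zipped against the input order. A minimal illustrative sketch of the idea (not part of the commit; `len` stands in for the component's `_query_memory`, and the query strings are hypothetical):

import functools

from concordia.utils import concurrency

queries = ['friends', 'rivals']  # hypothetical query strings

# Each key maps to a no-argument callable; run_tasks returns a dict keyed
# the same way, so no zip() against the input order is needed.
results = concurrency.run_tasks({
    query: functools.partial(len, query)  # stand-in for self._query_memory
    for query in queries
})
assert results == {'friends': 7, 'rivals': 6}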
concordia/components/agent/relationships.py (6 additions, 4 deletions)
@@ -15,6 +15,7 @@
 """Agent relationships with others component."""
 
 from collections.abc import Sequence
+import functools
 
 from concordia.components.agent import action_spec_ignored
 from concordia.components.agent import memory_component
@@ -93,12 +94,13 @@ def _query_memory(self, query: str) -> str:
     return result
 
   def _make_pre_act_value(self) -> str:
-    results = concurrency.run_parallel(
-        self._query_memory, self._related_agents_names
-    )
+    results = concurrency.run_tasks({
+        query: functools.partial(self._query_memory, query)
+        for query in self._related_agents_names
+    })
     output = '\n'.join([
         f'{query}: {result}'
-        for query, result in zip(self._related_agents_names, results)
+        for query, result in results.items()
     ])
 
     self._logging_channel({'Key': self.get_pre_act_key(), 'Value': output})
concordia/environment/game_master.py (10 additions, 12 deletions)
@@ -17,6 +17,7 @@
 from collections.abc import Callable, Mapping, Sequence
 import dataclasses
 import datetime
+import functools
 import random
 from typing import Any
 
@@ -211,7 +212,7 @@ def get_player_names(self):
   def update_from_player(self, player_name: str, action_attempt: str):
     prompt = interactive_document.InteractiveDocument(self._model)
 
-    concurrency.run_parallel(
+    concurrency.map_parallel(
        lambda construct: construct.update_before_event(
            f'{player_name}: {action_attempt}'
        ),
@@ -270,7 +271,7 @@ def get_externality(externality):
       return externality.update_after_event(event_statement)
 
     if self._concurrent_externalities:
-      concurrency.run_parallel(get_externality, self._components.values())
+      concurrency.map_parallel(get_externality, self._components.values())
     else:
       for externality in self._components.values():
         externality.update_after_event(event_statement)
@@ -303,17 +304,14 @@ def view_for_player(self, player_name):
     return
 
   def update_components(self) -> None:
-    futures = []
-    with concurrency.executor() as pool:
-      for comp in self._components.values():
-        future = pool.submit(
-            helper_functions.apply_recursively,
-            parent_component=comp,
-            function_name='update'
-        )
-        futures.append(future)
-      for future in futures:
-        future.result()
+    concurrency.run_tasks({
+        f'{component.name}.update': functools.partial(
+            helper_functions.apply_recursively,
+            parent_component=component,
+            function_name='update',
+        )
+        for component in self._components.values()
+    })
 
   def _step_player(
       self,
@@ -374,7 +372,7 @@ def step(
     random.shuffle(players)
 
     if self._concurrent_action:
-      concurrency.run_parallel(step_player_fn, players)
+      concurrency.map_parallel(step_player_fn, players)
     else:
       for player in players:
         step_player_fn(player)
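Here the fire-and-forget fan-outs keep the positional style under the new name `map_parallel`, while `update_components` trades manual executor and future bookkeeping for a single keyed `run_tasks` call. A small illustrative sketch of `map_parallel` (not from the commit):

from concordia.utils import concurrency

# map_parallel applies fn to each element, like map(), but concurrently;
# results come back in input order.
doubled = concurrency.map_parallel(lambda x: 2 * x, [1, 2, 3])
assert doubled == [2, 4, 6]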
concordia/utils/concurrency.py (123 additions, 7 deletions)
@@ -14,11 +14,14 @@
 
 """Concurrency helpers."""
 
-from collections.abc import Collection, Iterator, Sequence
+from collections.abc import Collection, Iterator, Mapping, Sequence
 from concurrent import futures
 import contextlib
+import functools
 from typing import Any, Callable, TypeVar
 
+from absl import logging
+
 _T = TypeVar('_T')


@@ -51,7 +54,119 @@ def executor(**kwargs) -> Iterator[futures.ThreadPoolExecutor]:
     thread_executor.shutdown()
 
 
-def run_parallel(
+def _run_task(key: str, fn: Callable[[], _T]) -> _T:
+  """Returns fn(), logging any error before re-raising it."""
+  try:
+    return fn()
+  except:
+    logging.exception('Error in task %s', key)
+    raise


+def _as_completed(
+    tasks: Mapping[str, Callable[[], _T]],
+    *,
+    timeout: float | None = None,
+    max_workers: int | None = None,
+) -> Iterator[tuple[str, futures.Future[_T]]]:
+  """Runs the tasks in parallel, yielding (key, future) as each completes.
+
+  IMPORTANT: Passed callables must be threadsafe.
+
+  Args:
+    tasks: callables to execute (MUST BE THREADSAFE)
+    timeout: the maximum number of seconds to wait for all tasks to complete.
+    max_workers: the maximum number of parallel jobs. If None, uses as many
+      workers as there are tasks.
+
+  Yields:
+    (key, future) as tasks complete.
+
+  Raises:
+    TimeoutError: If all the results are not generated before the timeout.
+  """
+  if max_workers is None:
+    max_workers = len(tasks)
+  with executor(max_workers=max_workers) as executor_:
+    key_by_future = {
+        executor_.submit(_run_task, key, task): key
+        for key, task in tasks.items()
+    }
+    for future in futures.as_completed(key_by_future, timeout=timeout):
+      yield key_by_future[future], future


+def run_tasks(
+    tasks: Mapping[str, Callable[[], _T]],
+    *,
+    timeout: float | None = None,
+    max_workers: int | None = None,
+) -> Mapping[str, _T]:
+  """Runs the callables in parallel, blocking until all complete or one fails.
+
+  IMPORTANT: Passed callables must be threadsafe.
+
+  Args:
+    tasks: callables to execute (MUST BE THREADSAFE)
+    timeout: the maximum number of seconds to wait.
+    max_workers: the maximum number of parallel jobs. If None, uses as many
+      workers as there are tasks.
+
+  Returns:
+    A mapping from task key to the result of the corresponding callable.
+
+  Raises:
+    TimeoutError: If all the results are not generated before the timeout.
+    Exception: If any task raises an exception.
+  """
+  return {
+      key: future.result()
+      for key, future in _as_completed(
+          tasks, timeout=timeout, max_workers=max_workers
+      )
+  }


+def run_tasks_in_background(
+    tasks: Mapping[str, Callable[[], _T]],
+    *,
+    timeout: float | None = None,
+    max_workers: int | None = None,
+) -> tuple[Mapping[str, _T], Mapping[str, BaseException]]:
+  """Runs the callables in parallel, blocking until all complete.
+
+  IMPORTANT: Passed callables must be threadsafe.
+
+  Args:
+    tasks: callables to execute (MUST BE THREADSAFE)
+    timeout: the maximum number of seconds to wait.
+    max_workers: the maximum number of parallel jobs. If None, uses as many
+      workers as there are tasks.
+
+  Returns:
+    (results, errors): two mappings from task key to, respectively, the result
+    of the callable and the exception it raised. If no task raised an error,
+    errors will be empty.
+  """
+  results = {}
+  errors = {}
+  try:
+    for key, future in _as_completed(
+        tasks, timeout=timeout, max_workers=max_workers
+    ):
+      error = future.exception()
+      if error is not None:
+        errors[key] = error
+      else:
+        results[key] = future.result()
+  except TimeoutError as error:
+    unfinished = tasks.keys() - results.keys() - errors.keys()
+    for key in unfinished:
+      errors[key] = error
+  return results, errors


+def map_parallel(
     fn: Callable[..., _T],
     *args: Collection[Any],
     timeout: float | None = None,
@@ -76,8 +191,9 @@ def run_parallel(
     TimeoutError: If all the results are not generated before the timeout.
     Exception: If fn(*args) raises for any of the values.
   """
-  if max_workers is None:
-    max_workers = min(len(arg) for arg in args)
-  with executor(max_workers=max_workers) as executor_:
-    results = executor_.map(fn, *args, timeout=timeout)
-    return list(results)  # Consume iterator to surface any errors.
+  tasks = {
+      str(n): functools.partial(fn, *arg)
+      for n, arg in enumerate(zip(*args, strict=True))
+  }
+  results = run_tasks(tasks, timeout=timeout, max_workers=max_workers)
+  return [results[key] for key in tasks]
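A hypothetical usage sketch contrasting the two new helpers, based on the docstrings above: `run_tasks` raises on the first failure, while `run_tasks_in_background` reports per-task outcomes, labelling tasks that miss the deadline with the `TimeoutError`.

import functools
import time

from concordia.utils import concurrency

tasks = {
    'fast': functools.partial(time.sleep, 0.1),  # finishes within the timeout
    'slow': functools.partial(time.sleep, 10),   # exceeds the timeout
}
results, errors = concurrency.run_tasks_in_background(tasks, timeout=1)
assert results == {'fast': None}  # time.sleep returns None
assert isinstance(errors['slow'], TimeoutError)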
concordia/utils/concurrency_test.py (115 additions, 0 deletions)
@@ -0,0 +1,115 @@
# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import time

from absl.testing import absltest
from concordia.utils import concurrency


class ExpectedError(Exception):
  pass


def wait_for(seconds):
  time.sleep(seconds)


def error_after(seconds):
  time.sleep(seconds)
  raise ExpectedError()


def return_after(seconds, value):
  time.sleep(seconds)
  return value


class ConcurrencyTest(absltest.TestCase):

  def test_executor_fails_fast(self):
    start_time = time.time()
    try:
      with concurrency.executor() as executor:
        executor.submit(wait_for, 5)
        raise ExpectedError()
    except ExpectedError:
      pass
    end_time = time.time()
    self.assertLess(end_time - start_time, 2)

  def test_run_tasks_fails_fast(self):
    tasks = {
        'wait': functools.partial(wait_for, 5),
        'error': functools.partial(error_after, 1),
    }
    start_time = time.time()
    try:
      concurrency.run_tasks(tasks)
    except ExpectedError:
      pass
    end_time = time.time()
    self.assertLess(end_time - start_time, 2)

  def test_run_tasks_error(self):
    tasks = {
        'wait': functools.partial(wait_for, 5),
        'error': functools.partial(error_after, 1),
    }
    with self.assertRaises(ExpectedError):
      concurrency.run_tasks(tasks)

  def test_run_tasks_timeout(self):
    tasks = {
        'wait': functools.partial(wait_for, 5),
    }
    with self.assertRaises(TimeoutError):
      concurrency.run_tasks(tasks, timeout=1)

  def test_run_tasks_success(self):
    tasks = {
        'a': functools.partial(return_after, 1, 'a'),
        'b': functools.partial(return_after, 0.1, 'b'),
        'c': functools.partial(return_after, 0.1, 'c'),
    }
    results = concurrency.run_tasks(tasks)
    self.assertEqual(results, {'a': 'a', 'b': 'b', 'c': 'c'})

  def test_run_tasks_in_background(self):
    tasks = {
        'a': functools.partial(return_after, 1, 'a'),
        'b': functools.partial(return_after, 0.1, 'b'),
        'c': functools.partial(return_after, 0.1, 'c'),
        'error': functools.partial(error_after, 1),
        'wait': functools.partial(wait_for, 5),
    }
    results, errors = concurrency.run_tasks_in_background(tasks, timeout=2)
    with self.subTest('results'):
      self.assertEqual(results, {'a': 'a', 'b': 'b', 'c': 'c'})
    with self.subTest('errors'):
      self.assertEqual(
          {key: type(error) for key, error in errors.items()},
          {'error': ExpectedError, 'wait': TimeoutError}
      )

  def test_map_parallel(self):
    results = concurrency.map_parallel(
        return_after, [1, 0.5, 0.1], ['a', 'b', 'c']
    )
    self.assertEqual(results, ['a', 'b', 'c'])


if __name__ == '__main__':
  absltest.main()
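The fail-fast tests above rely on the `executor` context manager not waiting for in-flight work when its body raises. The exact shutdown arguments are outside this diff; below is a minimal sketch of one plausible mechanism, using only the standard library (the name `fail_fast_executor` is hypothetical).

import contextlib
import time
from concurrent import futures


@contextlib.contextmanager
def fail_fast_executor():
  # Assumed behavior: on error, return without waiting for running tasks
  # and drop any queued ones; on success, wait for everything to finish.
  pool = futures.ThreadPoolExecutor()
  try:
    yield pool
  except BaseException:
    pool.shutdown(wait=False, cancel_futures=True)
    raise
  else:
    pool.shutdown()


try:
  with fail_fast_executor() as pool:
    pool.submit(time.sleep, 5)
    raise RuntimeError('boom')  # propagates in well under 5 seconds
except RuntimeError:
  pass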