Skip to content

Commit

Permalink
Merge branch 'fix/support_parent_rel_paths' of github.com:instadeepai…
Browse files Browse the repository at this point in the history
…/flashbax into fix/support_parent_rel_paths
  • Loading branch information
callumtilbury committed Aug 29, 2024
2 parents 1bffc73 + a72126b commit f5819f0
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ on:
jobs:
deploy:
runs-on: ubuntu-latest
container:
image: ghcr.io/catthehacker/ubuntu:runner-latest
timeout-minutes: 30
steps:
- uses: actions/checkout@v2
Expand Down
18 changes: 7 additions & 11 deletions .github/workflows/tests_linters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,22 @@ on: [ push, pull_request ]

jobs:
tests-and-linters:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
runs-on: "${{ matrix.os }}"

strategy:
matrix:
python-version: ["3.9"]
os: [self-hosted]
name: "Python 3.9 on GitHub Hosted runner"
runs-on: ubuntu-latest
container:
image: python:3.9

steps:
- name: Install dependencies for viewer test
run: sudo apt-get update && sudo apt-get install -y xvfb
run: apt-get update && apt-get install -y xvfb
- name: Checkout flashbax
uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "${{ matrix.python-version }}"
- name: Install python dependencies 🔧
run: pip install .[dev]
- name: List python packages 📦
run: pip list
- name: Update git permissions
run: git config --global --add safe.directory /__w/flashbax/flashbax
- name: Run linters 🖌️
run: pre-commit run --all-files --verbose
- name: Run tests 🧪
Expand Down
1 change: 1 addition & 0 deletions flashbax/buffers/prioritised_trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ def make_prioritised_trajectory_buffer(
if max_size is not None:
max_length_time_axis = max_size // add_batch_size

assert max_length_time_axis is not None
init_fn = functools.partial(
prioritised_init,
add_batch_size=add_batch_size,
Expand Down
1 change: 1 addition & 0 deletions flashbax/buffers/trajectory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ def make_trajectory_buffer(
if max_size is not None:
max_length_time_axis = max_size // add_batch_size

assert max_length_time_axis is not None
init_fn = functools.partial(
init,
add_batch_size=add_batch_size,
Expand Down
24 changes: 24 additions & 0 deletions flashbax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def get_timestep_count(buffer_state: chex.ArrayTree) -> int:
"""Utility to compute the total number of timesteps currently in the buffer state.
Args:
buffer_state (BufferStateTypes): the buffer state to compute the total timesteps for.
Returns:
int: the total number of timesteps in the buffer state.
"""
# Ensure the buffer state is a valid buffer state.
assert hasattr(buffer_state, "experience")
assert hasattr(buffer_state, "current_index")
assert hasattr(buffer_state, "is_full")

b_size, t_size_max = get_tree_shape_prefix(buffer_state.experience, 2)
t_size = jax.lax.cond(
buffer_state.is_full,
lambda: t_size_max,
lambda: buffer_state.current_index,
)
timestep_count: int = b_size * t_size
return timestep_count

0 comments on commit f5819f0

Please sign in to comment.