Skip to content

Commit

Permalink
Merge pull request #79 from Kautenja/random_level_env
Browse files Browse the repository at this point in the history
Random level env
  • Loading branch information
Kautenja authored Jan 16, 2019
2 parents 657e5da + fc4f7da commit 142da44
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 11 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ where:
For example, to play 4-2 on the downsampled ROM, you would use the environment
id `SuperMarioBros-4-2-v1`.

### Random Stage Selection

The random stage selection environment randomly selects a stage and allows a
single attempt to clear it. Upon a death and subsequent call to `reset`, the
environment randomly selects a new stage. This is only available for the
standard Super Mario Bros. game, _not_ Lost Levels (at the moment). To use
these environments, append `RandomLevels` to the `SuperMarioBros` id. For
example, to use the standard ROM with random stage selection use
`SuperMarioBrosRandomLevels-v0`. To seed the random stage selection use the
`seed` method of the env, i.e., `env.seed(1)`, before any calls to `reset`.

## Step

Info about the rewards and info returned by the `step` method.
Expand Down
2 changes: 2 additions & 0 deletions gym_super_mario_bros/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Registration code of Gym environments in this package."""
from .smb_env import SuperMarioBrosEnv
from .smb_random_levels_env import SuperMarioBrosRandomLevelsEnv
from ._registration import make


# define the outward facing API of this package
__all__ = [
make.__name__,
SuperMarioBrosEnv.__name__,
SuperMarioBrosRandomLevelsEnv.__name__,
]
19 changes: 17 additions & 2 deletions gym_super_mario_bros/_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@
import gym


def _register_mario_env(id, **kwargs):
def _register_mario_env(id, is_random=False, **kwargs):
"""
Register a Super Mario Bros. (1/2) environment with OpenAI Gym.
Args:
id (str): id for the env to register
is_random (bool): whether to use the random levels environment
kwargs (dict): keyword arguments for the SuperMarioBrosEnv initializer
Returns:
None
"""
# if the is random flag is set
if is_random:
# set the entry point to the random level environment
entry_point = 'gym_super_mario_bros:SuperMarioBrosRandomLevelsEnv'
else:
# set the entry point to the standard Super Mario Bros. environment
entry_point = 'gym_super_mario_bros:SuperMarioBrosEnv'
# register the environment
gym.envs.registration.register(
id=id,
entry_point='gym_super_mario_bros:SuperMarioBrosEnv',
entry_point=entry_point,
max_episode_steps=9999999,
reward_threshold=9999999,
kwargs=kwargs,
Expand All @@ -32,6 +40,13 @@ def _register_mario_env(id, **kwargs):
_register_mario_env('SuperMarioBros-v3', rom_mode='rectangle')


# Super Mario Bros. Random Levels
_register_mario_env('SuperMarioBrosRandomLevels-v0', is_random=True, rom_mode='vanilla')
_register_mario_env('SuperMarioBrosRandomLevels-v1', is_random=True, rom_mode='downsample')
_register_mario_env('SuperMarioBrosRandomLevels-v2', is_random=True, rom_mode='pixel')
_register_mario_env('SuperMarioBrosRandomLevels-v3', is_random=True, rom_mode='rectangle')


# Super Mario Bros. 2 (Lost Levels)
_register_mario_env('SuperMarioBros2-v0', lost_levels=True, rom_mode='vanilla')
_register_mario_env('SuperMarioBros2-v1', lost_levels=True, rom_mode='downsample')
Expand Down
165 changes: 165 additions & 0 deletions gym_super_mario_bros/smb_random_levels_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""An OpenAI Gym Super Mario Bros. environment that randomly selects levels."""
import gym
from numpy.random import RandomState
from nes_py.nes_env import SCREEN_HEIGHT, SCREEN_WIDTH
from .smb_env import SuperMarioBrosEnv


class SuperMarioBrosRandomLevelsEnv(gym.Env):
"""A Super Mario Bros. environment that randomly selects levels."""

# relevant meta-data about the environment
metadata = SuperMarioBrosEnv.metadata

# the legal range of rewards for each step
reward_range = SuperMarioBrosEnv.reward_range

# observation space for the environment is static across all instances
observation_space = SuperMarioBrosEnv.observation_space

# action space is a bitmap of button press values for the 8 NES buttons
action_space = SuperMarioBrosEnv.action_space

def __init__(self, rom_mode='vanilla'):
"""
Initialize a new Super Mario Bros environment.
Args:
rom_mode (str): the ROM mode to use when loading ROMs from disk
Returns:
None
"""
# create a dedicated random number generator for the environment
self._rng = RandomState()
# setup the environments
self.envs = []
# iterate over the worlds in the game, i.e., {1, ..., 8}
for world in range(1, 9):
# append a new list to put this world's stages into
self.envs.append([])
# iterate over the stages in the world, i.e., {1, ..., 4}
for stage in range(1, 5):
# create the environment with the given ROM mode
env = SuperMarioBrosEnv(rom_mode=rom_mode, target=(world, stage))
# add the environment to the stage list for this world
self.envs[-1].append(env)
# create a placeholder for the current environment
self.env = self.envs[0][0]
# create a placeholder for the image viewer to render the screen
self.viewer = None

def _select_random_level(self):
"""Select a random level to use."""
self.env = self.envs[self._rng.randint(1, 9) - 1][self._rng.randint(1, 5) - 1]

def seed(self, seed=None):
"""
Set the seed for this environment's random number generator.
Returns:
list<bigint>: Returns the list of seeds used in this env's random
number generators. The first value in the list should be the
"main" seed, or the value which a reproducer should pass to
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true if seed=None, for example.
"""
# if there is no seed, return an empty list
if seed is None:
return []
# set the random number seed for the NumPy random number generator
self._rng.seed(seed)
# return the list of seeds used by RNG(s) in the environment
return [seed]

def reset(self):
"""
Reset the state of the environment and returns an initial observation.
Returns:
state (np.ndarray): next frame as a result of the given action
"""
# select a new level
self._select_random_level()
# reset the environment
return self.env.reset()

def step(self, action):
"""
Run one frame of the NES and return the relevant observation data.
Args:
action (byte): the bitmap determining which buttons to press
Returns:
a tuple of:
- state (np.ndarray): next frame as a result of the given action
- reward (float) : amount of reward returned after given action
- done (boolean): whether the episode has ended
- info (dict): contains auxiliary diagnostic information
"""
return self.env.step(action)

def close(self):
"""Close the environment."""
# make sure the environment hasn't already been closed
if self.env is None:
raise ValueError('env has already been closed.')
# iterate over each list of stages
for stage_lists in self.envs:
# iterate over each stage
for stage in stage_lists:
# close the environment
stage.close()
# close the environment permanently
self.env = None
# if there is an image viewer open, delete it
if self.viewer is not None:
self.viewer.close()

def render(self, mode='human'):
"""
Render the environment.
Args:
mode (str): the mode to render with:
- human: render to the current display
- rgb_array: Return an numpy.ndarray with shape (x, y, 3),
representing RGB values for an x-by-y pixel image
Returns:
a numpy array if mode is 'rgb_array', None otherwise
"""
if mode == 'human':
# if the viewer isn't setup, import it and create one
if self.viewer is None:
from nes_py._image_viewer import ImageViewer
# get the caption for the ImageViewer
# create the ImageViewer to display frames
self.viewer = ImageViewer(
caption=self.__class__.__name__,
height=SCREEN_HEIGHT,
width=SCREEN_WIDTH,
)
# show the screen on the image viewer
self.viewer.show(self.env.screen)
elif mode == 'rgb_array':
return self.env.screen
else:
# unpack the modes as comma delineated strings ('a', 'b', ...)
render_modes = [repr(x) for x in self.metadata['render.modes']]
msg = 'valid render modes are: {}'.format(', '.join(render_modes))
raise NotImplementedError(msg)

def get_keys_to_action(self):
"""Return the dictionary of keyboard keys to actions."""
return self.env.get_keys_to_action()


# explicitly define the outward facing API of this module
__all__ = [SuperMarioBrosRandomLevelsEnv.__name__]
17 changes: 17 additions & 0 deletions gym_super_mario_bros/tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ class ShouldMakeEnv:
x_pos = 40
# the environments ID
env_id = None
# the random seed to apply
seed = None

def _test_env(self, env_id):
env = make(env_id)
if self.seed is not None:
env.seed(self.seed)
env.reset()
s, r, d, i = env.step(0)
self.assertEqual(self.coins, i['coins'])
Expand All @@ -51,6 +55,19 @@ class ShouldMakeSuperMarioBros(ShouldMakeEnv, TestCase):
env_id = ['SuperMarioBros-v{}'.format(v) for v in range(4)]


class ShouldMakeSuperMarioBrosRandomLevels(ShouldMakeEnv, TestCase):
# the random number seed for this environment
seed = 1
# the amount of time left
time = 300
# the current world
world = 6
# the current stage
stage = 4
# the environments ID for all versions of Super Mario Bros
env_id = ['SuperMarioBrosRandomLevels-v{}'.format(v) for v in range(4)]


class ShouldMakeSuperMarioBrosLostLevels(ShouldMakeEnv, TestCase):
# the amount of time left
time = 400
Expand Down
15 changes: 6 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
from setuptools import setup, find_packages


def README():
"""Return the contents of the read me file for this project."""
with open('README.md') as readme:
return readme.read()
# read the contents of the README
with open('README.md') as README_md:
README = README_md.read()


setup(
name='gym_super_mario_bros',
version='7.0.1',
version='7.1.0',
description='Super Mario Bros. for OpenAI Gym',
long_description=README(),
long_description=README,
long_description_content_type='text/markdown',
keywords=' '.join([
'OpenAI-Gym',
Expand Down Expand Up @@ -40,9 +39,7 @@ def README():
author_email='[email protected]',
license='Proprietary',
packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']),
package_data={
'gym_super_mario_bros': ['_roms/*.nes']
},
package_data={ 'gym_super_mario_bros': ['_roms/*.nes'] },
install_requires=[
'matplotlib>=2.0.2',
'nes-py>=5.0.0',
Expand Down

0 comments on commit 142da44

Please sign in to comment.