-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from Kautenja/random_level_env
Random level env
- Loading branch information
Showing
6 changed files
with
218 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,12 @@ | ||
"""Registration code of Gym environments in this package.""" | ||
from .smb_env import SuperMarioBrosEnv | ||
from .smb_random_levels_env import SuperMarioBrosRandomLevelsEnv | ||
from ._registration import make | ||
|
||
|
||
# define the outward facing API of this package | ||
__all__ = [ | ||
make.__name__, | ||
SuperMarioBrosEnv.__name__, | ||
SuperMarioBrosRandomLevelsEnv.__name__, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
"""An OpenAI Gym Super Mario Bros. environment that randomly selects levels.""" | ||
import gym | ||
from numpy.random import RandomState | ||
from nes_py.nes_env import SCREEN_HEIGHT, SCREEN_WIDTH | ||
from .smb_env import SuperMarioBrosEnv | ||
|
||
|
||
class SuperMarioBrosRandomLevelsEnv(gym.Env): | ||
"""A Super Mario Bros. environment that randomly selects levels.""" | ||
|
||
# relevant meta-data about the environment | ||
metadata = SuperMarioBrosEnv.metadata | ||
|
||
# the legal range of rewards for each step | ||
reward_range = SuperMarioBrosEnv.reward_range | ||
|
||
# observation space for the environment is static across all instances | ||
observation_space = SuperMarioBrosEnv.observation_space | ||
|
||
# action space is a bitmap of button press values for the 8 NES buttons | ||
action_space = SuperMarioBrosEnv.action_space | ||
|
||
def __init__(self, rom_mode='vanilla'): | ||
""" | ||
Initialize a new Super Mario Bros environment. | ||
Args: | ||
rom_mode (str): the ROM mode to use when loading ROMs from disk | ||
Returns: | ||
None | ||
""" | ||
# create a dedicated random number generator for the environment | ||
self._rng = RandomState() | ||
# setup the environments | ||
self.envs = [] | ||
# iterate over the worlds in the game, i.e., {1, ..., 8} | ||
for world in range(1, 9): | ||
# append a new list to put this world's stages into | ||
self.envs.append([]) | ||
# iterate over the stages in the world, i.e., {1, ..., 4} | ||
for stage in range(1, 5): | ||
# create the environment with the given ROM mode | ||
env = SuperMarioBrosEnv(rom_mode=rom_mode, target=(world, stage)) | ||
# add the environment to the stage list for this world | ||
self.envs[-1].append(env) | ||
# create a placeholder for the current environment | ||
self.env = self.envs[0][0] | ||
# create a placeholder for the image viewer to render the screen | ||
self.viewer = None | ||
|
||
def _select_random_level(self): | ||
"""Select a random level to use.""" | ||
self.env = self.envs[self._rng.randint(1, 9) - 1][self._rng.randint(1, 5) - 1] | ||
|
||
def seed(self, seed=None): | ||
""" | ||
Set the seed for this environment's random number generator. | ||
Returns: | ||
list<bigint>: Returns the list of seeds used in this env's random | ||
number generators. The first value in the list should be the | ||
"main" seed, or the value which a reproducer should pass to | ||
'seed'. Often, the main seed equals the provided 'seed', but | ||
this won't be true if seed=None, for example. | ||
""" | ||
# if there is no seed, return an empty list | ||
if seed is None: | ||
return [] | ||
# set the random number seed for the NumPy random number generator | ||
self._rng.seed(seed) | ||
# return the list of seeds used by RNG(s) in the environment | ||
return [seed] | ||
|
||
def reset(self): | ||
""" | ||
Reset the state of the environment and returns an initial observation. | ||
Returns: | ||
state (np.ndarray): next frame as a result of the given action | ||
""" | ||
# select a new level | ||
self._select_random_level() | ||
# reset the environment | ||
return self.env.reset() | ||
|
||
def step(self, action): | ||
""" | ||
Run one frame of the NES and return the relevant observation data. | ||
Args: | ||
action (byte): the bitmap determining which buttons to press | ||
Returns: | ||
a tuple of: | ||
- state (np.ndarray): next frame as a result of the given action | ||
- reward (float) : amount of reward returned after given action | ||
- done (boolean): whether the episode has ended | ||
- info (dict): contains auxiliary diagnostic information | ||
""" | ||
return self.env.step(action) | ||
|
||
def close(self): | ||
"""Close the environment.""" | ||
# make sure the environment hasn't already been closed | ||
if self.env is None: | ||
raise ValueError('env has already been closed.') | ||
# iterate over each list of stages | ||
for stage_lists in self.envs: | ||
# iterate over each stage | ||
for stage in stage_lists: | ||
# close the environment | ||
stage.close() | ||
# close the environment permanently | ||
self.env = None | ||
# if there is an image viewer open, delete it | ||
if self.viewer is not None: | ||
self.viewer.close() | ||
|
||
def render(self, mode='human'): | ||
""" | ||
Render the environment. | ||
Args: | ||
mode (str): the mode to render with: | ||
- human: render to the current display | ||
- rgb_array: Return an numpy.ndarray with shape (x, y, 3), | ||
representing RGB values for an x-by-y pixel image | ||
Returns: | ||
a numpy array if mode is 'rgb_array', None otherwise | ||
""" | ||
if mode == 'human': | ||
# if the viewer isn't setup, import it and create one | ||
if self.viewer is None: | ||
from nes_py._image_viewer import ImageViewer | ||
# get the caption for the ImageViewer | ||
# create the ImageViewer to display frames | ||
self.viewer = ImageViewer( | ||
caption=self.__class__.__name__, | ||
height=SCREEN_HEIGHT, | ||
width=SCREEN_WIDTH, | ||
) | ||
# show the screen on the image viewer | ||
self.viewer.show(self.env.screen) | ||
elif mode == 'rgb_array': | ||
return self.env.screen | ||
else: | ||
# unpack the modes as comma delineated strings ('a', 'b', ...) | ||
render_modes = [repr(x) for x in self.metadata['render.modes']] | ||
msg = 'valid render modes are: {}'.format(', '.join(render_modes)) | ||
raise NotImplementedError(msg) | ||
|
||
def get_keys_to_action(self): | ||
"""Return the dictionary of keyboard keys to actions.""" | ||
return self.env.get_keys_to_action() | ||
|
||
|
||
# explicitly define the outward facing API of this module | ||
__all__ = [SuperMarioBrosRandomLevelsEnv.__name__] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,17 +2,16 @@ | |
from setuptools import setup, find_packages | ||
|
||
|
||
def README(): | ||
"""Return the contents of the read me file for this project.""" | ||
with open('README.md') as readme: | ||
return readme.read() | ||
# read the contents of the README | ||
with open('README.md') as README_md: | ||
README = README_md.read() | ||
|
||
|
||
setup( | ||
name='gym_super_mario_bros', | ||
version='7.0.1', | ||
version='7.1.0', | ||
description='Super Mario Bros. for OpenAI Gym', | ||
long_description=README(), | ||
long_description=README, | ||
long_description_content_type='text/markdown', | ||
keywords=' '.join([ | ||
'OpenAI-Gym', | ||
|
@@ -40,9 +39,7 @@ def README(): | |
author_email='[email protected]', | ||
license='Proprietary', | ||
packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']), | ||
package_data={ | ||
'gym_super_mario_bros': ['_roms/*.nes'] | ||
}, | ||
package_data={ 'gym_super_mario_bros': ['_roms/*.nes'] }, | ||
install_requires=[ | ||
'matplotlib>=2.0.2', | ||
'nes-py>=5.0.0', | ||
|