RL Flappy Bird

Overview

This project is a basic application of Reinforcement Learning.

It uses Deep Java Library (DJL) to train an agent with a Deep Q-Network (DQN). The pretrained model was trained for 3 million steps on a single GPU.

An article explaining the training process is available on Towards Data Science, along with a Chinese-language version of the article.

Build the project and run

This project is built with Maven. Use the following command to build it:

mvn compile  

The following command starts training without graphics:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird"

The above command trains from scratch. To start from the pretrained weights instead, run:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p"

To test the pretrained model directly, run:

mvn exec:java -Dexec.mainClass="com.kingyu.rlbird.ai.TrainBird" -Dexec.args="-p -t"

The available arguments are:

Argument | Description
-------- | -----------
-g | Train with graphics.
-b | Batch size to use for training.
-p | Use pre-trained weights.
-t | Test the trained model.
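As a rough illustration of how these flags combine (for example, `-Dexec.args="-p -t"` loads the pretrained weights and runs in test mode), here is a minimal parsing sketch. The class `TrainArguments`, its field names, and the default batch size of 32 are hypothetical assumptions for illustration; the actual `TrainBird` class may parse its arguments differently and use other defaults.

```java
/**
 * Hypothetical sketch of parsing the -g, -b, -p and -t flags.
 * Not the project's actual implementation; defaults are assumptions.
 */
final class TrainArguments {
    boolean graphics;      // -g: train with graphics
    int batchSize = 32;    // -b <n>: batch size (default of 32 is an assumption)
    boolean preTrained;    // -p: load pre-trained weights
    boolean testing;       // -t: test the trained model instead of training

    static TrainArguments parse(String[] args) {
        TrainArguments parsed = new TrainArguments();
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "-g":
                    parsed.graphics = true;
                    break;
                case "-p":
                    parsed.preTrained = true;
                    break;
                case "-t":
                    parsed.testing = true;
                    break;
                case "-b":
                    // -b takes the next token as its integer value
                    if (i + 1 >= args.length) {
                        throw new IllegalArgumentException("-b requires a value");
                    }
                    parsed.batchSize = Integer.parseInt(args[++i]);
                    break;
                default:
                    throw new IllegalArgumentException("Unknown argument: " + args[i]);
            }
        }
        return parsed;
    }
}
```

With this sketch, `-p -t` sets both `preTrained` and `testing`, while `-g -b 64` would enable graphics and set the batch size to 64.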

Deep Q-Network Algorithm

The pseudocode for the Deep Q-Learning algorithm, as given in "Human-level Control through Deep Reinforcement Learning" (Nature), is shown below:

Initialize replay memory D to capacity N
Initialize action-value function Q with random weights
for episode = 1, M do
    Initialize state s_1
    for t = 1, T do
        With probability ϵ select a random action a_t
        otherwise select a_t = argmax_a Q(s_t, a; θ_i)
        Execute action a_t in emulator and observe reward r_t and next state s_(t+1)
        Store transition (s_t, a_t, r_t, s_(t+1)) in D
        Sample a random minibatch of transitions (s_j, a_j, r_j, s_(j+1)) from D
        Set y_j =
            r_j                                     for terminal s_(j+1)
            r_j + γ * max_a' Q(s_(j+1), a'; θ_i)    for non-terminal s_(j+1)
        Perform a gradient step on (y_j - Q(s_j, a_j; θ_i))^2 with respect to θ
    end for
end for
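The bookkeeping in the pseudocode can be sketched in plain Java: the replay memory D with capacity N, ϵ-greedy action selection, the target y_j, and the squared error that the gradient step minimizes. This is a minimal, framework-free sketch under stated assumptions: the names `Transition`, `ReplayMemory` and `DqnMath` are hypothetical and are not part of this project or the DJL API, and the neural network that produces Q-values and performs the gradient step is left out.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** One stored transition (s_j, a_j, r_j, s_(j+1)) plus a terminal flag. */
final class Transition {
    final double[] state;
    final int action;
    final double reward;
    final double[] nextState;
    final boolean terminal;

    Transition(double[] state, int action, double reward, double[] nextState, boolean terminal) {
        this.state = state;
        this.action = action;
        this.reward = reward;
        this.nextState = nextState;
        this.terminal = terminal;
    }
}

/** Replay memory D with capacity N; the oldest transition is overwritten when full. */
final class ReplayMemory {
    private final Transition[] buffer;
    private int next = 0;  // slot to write next
    private int size = 0;  // number of valid transitions stored

    ReplayMemory(int capacity) {
        buffer = new Transition[capacity];
    }

    void add(Transition t) {
        buffer[next] = t;
        next = (next + 1) % buffer.length;
        size = Math.min(size + 1, buffer.length);
    }

    int size() {
        return size;
    }

    /** Samples batchSize transitions uniformly at random (with replacement, for simplicity). */
    List<Transition> sample(int batchSize, Random rng) {
        if (size == 0) {
            throw new IllegalStateException("replay memory is empty");
        }
        List<Transition> batch = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            batch.add(buffer[rng.nextInt(size)]);
        }
        return batch;
    }
}

final class DqnMath {
    /** With probability epsilon pick a random action, otherwise argmax_a Q(s, a). */
    static int selectAction(double[] qValues, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) {
            return rng.nextInt(qValues.length);
        }
        return argmax(qValues);
    }

    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** y_j = r_j if s_(j+1) is terminal, else r_j + gamma * max_a' Q(s_(j+1), a'). */
    static double target(double reward, boolean terminal, double gamma, double[] nextQValues) {
        if (terminal) {
            return reward;
        }
        return reward + gamma * nextQValues[argmax(nextQValues)];
    }

    /** Squared error (y_j - Q(s_j, a_j))^2 minimized by the gradient step. */
    static double squaredError(double target, double predictedQ) {
        double diff = target - predictedQ;
        return diff * diff;
    }
}
```

In a complete training loop, the Q-values passed to `selectAction` and `target` would come from the network (built with DJL in this project); the sketch only covers the logic around it.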

Notes

Trained Model

  • Training a bird to a perfect state can take more than 10 hours. A model trained for three million steps is included in the project resource folder: src/main/resources/model/dqn-trained-0000-params

Troubleshooting

This work is based on the following repos:

License

MIT © Kingyu Luk