-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_kill_enemy.py
98 lines (82 loc) · 2.68 KB
/
train_kill_enemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import argparse
import logging
import os
import time
from typing import Any, Optional
import common.logging_options as logging_options
from envs.predefined_envs import create_kill_enemy
from trainer.trainer import Trainer
# Apply the project's default logging options at import time, so the setup is
# in place before main()/evaluate() below emit any log records.
# NOTE(review): the concrete effect (level, format, handlers) is defined in
# common.logging_options, which is not visible here — confirm there.
logging_options.set_default()
def get_configs() -> dict[str, Any]:
    """Build the experiment configuration for the kill-enemy training task.

    A brand-new nested dict is constructed on every call, so callers may
    mutate the result without affecting later calls.

    Returns:
        A dict with the sections ``ppo_params``, ``grid_params``,
        ``objective_params``, ``learn_params`` and ``actions_params``.
        NOTE(review): how each section is consumed is decided by
        ``create_kill_enemy`` / ``Trainer`` (not visible here); ``ppo_params``
        presumably holds keyword arguments for the PPO model — confirm there.
    """
    configs: dict[str, Any] = {
        'ppo_params': {
            'policy': 'MultiInputPolicy',
            'n_steps': 1024,
            'batch_size': 64,
            'learning_rate': 2e-5,
            'policy_kwargs': {
                # presumably two hidden layers of 128 units (SB3-style
                # net_arch) — confirm against the trainer.
                'net_arch': [128, 128],
            },
        },
        'grid_params': {
            'sight': 0,
            'add_grid': False,
        },
        'objective_params': {
            'enemy_type': 'slime',
            'enemy_count': 2,
            'min_distance': 50,
            'max_distance': 120,
            'bounty': 5,
            # Kept as a product to mirror the original expression (360 total);
            # the unit of this length is defined by the environment.
            'episode_max_len': 60 * 6,
        },
        'learn_params': {},
        'actions_params': {
            'can_shoot': False,
            'can_dash': True,
        },
    }
    return configs
def parse_load_from(load_from: str) -> tuple[str, str, Optional[str]]:
    """Split a ``project/name[/model]`` reference into its components.

    Only the last three ``/``-separated segments are considered, so any
    leading path (e.g. ``runs/proj/trial/model``) is ignored.

    Args:
        load_from: Reference of the form ``project/name`` or
            ``project/name/model`` (optionally with a leading path).

    Returns:
        ``(project, name, None)`` for two segments, or
        ``(project, name, model)`` for three or more. NOTE: a trailing ``/``
        yields an empty-string model (``''``), not ``None``.

    Raises:
        ValueError: if *load_from* contains no ``/`` at all.
    """
    segments = load_from.split('/')[-3:]
    if len(segments) == 2:
        project, name = segments
        return project, name, None
    if len(segments) == 3:
        project, name, model = segments
        return project, name, model
    # str.split always returns at least one element, so only a slash-free
    # (single-segment) reference reaches this point.
    raise ValueError(
        f"Can't evaluate with load_from: {load_from} "
        "(expected 'project/name' or 'project/name/model')"
    )
def evaluate(load_from: str):
    """Evaluate one saved model, or every model of a trial, on the kill-enemy env.

    *load_from* is parsed with ``parse_load_from``. When it names a model
    (``project/name/model``), only that model is evaluated; otherwise
    (``project/name``, or an empty model segment) all models of the trial are.

    NOTE(review): the integer passed to the Trainer (15 for a single model,
    5 for all models) is forwarded unchanged; it is presumably an episode
    count — confirm against Trainer.evaluate_model / evaluate_all_models.
    """
    project_name, trial_name, model_name = parse_load_from(load_from)
    evaluator = Trainer(export_wandb=True)

    if not model_name:
        logging.info('Evaluating all models')
        evaluator.evaluate_all_models(create_kill_enemy, 5, project_name, trial_name)
        return

    evaluator.evaluate_model(create_kill_enemy, 15, project_name, trial_name, model_name)
def main(total_steps: int, load_from: Optional[str], eval: bool):
    """Train or fork-train an agent on the kill-enemy env, or evaluate models.

    Args:
        total_steps: Training budget forwarded to ``Trainer.train`` /
            ``Trainer.fork_training``.
        load_from: ``project/name[/model]`` reference (see
            ``parse_load_from``). Required when evaluating; when training, a
            non-empty value forks training from the referenced model.
        eval: If true, evaluate the model(s) in *load_from* instead of training.
            (The name shadows the builtin but is kept because it is the keyword
            ``main(**vars(args))`` passes from the ``--eval`` CLI flag.)

    Raises:
        ValueError: if evaluating without *load_from*, or if forking from a
            reference that does not name a model.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let invalid CLI input through.
    if eval:
        if load_from is None:
            raise ValueError('Must specify load_from when evaluating')
        evaluate(load_from)
        return

    # Validate the fork source before creating the trainer/environment so a
    # bad --load_from fails fast, without side effects.
    fork_source = None
    if load_from:
        fork_source = parse_load_from(load_from)
        if fork_source[2] is None:
            raise ValueError('Must specify model when loading')

    trainer = Trainer(export_wandb=True)
    configs = get_configs()
    # time_ns() // 1e8 is a timestamp in tenths of a second, used to give
    # each run its own project name.
    project_name = f'study_{time.time_ns() // 100_000_000}'
    trial_name = '1'
    record_path = os.path.join(trainer.get_trial_path(project_name, trial_name), 'replay')
    env = create_kill_enemy(configs, record_path=record_path)

    if fork_source is None:
        trainer.train(env, total_steps, configs, project_name, trial_name)
    else:
        load_project, load_trial, load_model = fork_source
        trainer.fork_training(
            env, total_steps, configs, project_name, trial_name,
            load_project, load_trial, load_model,
        )
def str_to_bool(text: str) -> bool:
    """Convert a CLI string such as 'true', 'False', '1' or 'no' to a bool.

    argparse's ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (including 'False'), so it cannot parse boolean flags.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a recognised boolean.
    """
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser whose options map onto ``main``'s parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--total_steps', type=int, default=1000)
    # A bare `--eval` enables evaluation; `--eval True` / `--eval False` remain
    # accepted for backward compatibility and are now parsed correctly.
    parser.add_argument('--eval', type=str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--load_from', type=str, default=None)
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main(**vars(args))