Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DQN for flow #964

Open
wants to merge 45 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ef50b5b
DQN for flow
Jun 11, 2020
c5c5ed2
correct typo
pengyuan-zhou Jun 11, 2020
74cbd10
rm trailing space
pengyuan-zhou Jun 11, 2020
0829e9e
add reference for dqn parameters setup
pengyuan-zhou Jun 11, 2020
69236de
Update train.py
pengyuan-zhou Jun 11, 2020
b59728e
Update train.py
pengyuan-zhou Jun 11, 2020
d66f1ce
Update train.py
pengyuan-zhou Jun 11, 2020
b2628e2
fix import
pengyuan-zhou Jun 11, 2020
72f9fca
add rllib back to avoid test error
pengyuan-zhou Jun 11, 2020
6751b56
change default test to dqn
pengyuan-zhou Jun 11, 2020
4c3d9d4
add TestDQNExamples for traffic light grid examples
pengyuan-zhou Jun 11, 2020
3bfb9e6
rm light grid test for PPO
Jun 11, 2020
c6b2fd1
Update test_examples.py
pengyuan-zhou Jun 11, 2020
3100841
trial
pengyuan-zhou Jun 11, 2020
f74eb50
trial2
pengyuan-zhou Jun 12, 2020
1cd579d
pass ignore_reinit_error=True
pengyuan-zhou Jun 12, 2020
81a203e
pass ray.shutdown() before ray.init
pengyuan-zhou Jun 12, 2020
e773582
Update traffic_light_grid.py
pengyuan-zhou Jun 12, 2020
eb12328
update
Jun 22, 2020
97f75bb
update
Jun 22, 2020
e7c7ea0
update
Jun 22, 2020
bdb9ccc
update
Jun 22, 2020
c36d223
typo
Jun 22, 2020
720bcad
update
Jun 22, 2020
91a0a0b
Update singleagent_traffic_light_grid.py
pengyuan-zhou Jun 22, 2020
50be0ae
Update multiagent_traffic_light_grid.py
pengyuan-zhou Jun 22, 2020
6303be2
Update train.py
pengyuan-zhou Jun 23, 2020
658b9cb
Update multiagent_traffic_light_grid.py
pengyuan-zhou Jun 23, 2020
4984ffc
rm flow-project from travis.yml
Jun 23, 2020
3f175a4
Merge remote-tracking branch 'upstream/master' into DQN
pengyuan-zhou Jul 24, 2020
2535b0c
Update .travis.yml
pengyuan-zhou Sep 8, 2020
95d9182
Update multiagent_traffic_light_grid.py
pengyuan-zhou Sep 8, 2020
b387646
Update train.py
pengyuan-zhou Sep 8, 2020
c0d41cf
Update README.md
pengyuan-zhou Dec 18, 2020
00c5c6c
Update README.md
pengyuan-zhou Dec 18, 2020
85a84e4
Set theme jekyll-theme-midnight
pengyuan-zhou Dec 19, 2020
ef777f7
Set theme jekyll-theme-midnight
pengyuan-zhou Dec 19, 2020
04631ad
Update README.md
pengyuan-zhou Dec 19, 2020
bdf80f2
Set theme jekyll-theme-cayman
pengyuan-zhou Dec 19, 2020
3b998de
Update README.md
pengyuan-zhou Dec 19, 2020
c9a7377
Update README.md
pengyuan-zhou Dec 19, 2020
d0ab1b0
Update README.md
pengyuan-zhou Dec 19, 2020
1172cb3
Update .travis.yml
pengyuan-zhou Dec 20, 2020
c01258b
Delete _config.yml
pengyuan-zhou Dec 20, 2020
1282c67
Update multiagent_traffic_light_grid.py
pengyuan-zhou Dec 20, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"target_velocity": 50,
"switch_time": 3,
"num_observed": 2,
"discrete": False,
"discrete": False, # set True for DQN
"tl_type": "actuated",
"num_local_edges": 4,
"num_local_lights": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def get_non_flow_params(enter_speed, add_net_params):
'target_velocity': 50,
'switch_time': 3.0,
'num_observed': 2,
'discrete': False,
'tl_type': 'controlled'
'discrete': False, # set True for DQN
'tl_type': 'actuated'
}

additional_net_params = {
Expand Down
43 changes: 30 additions & 13 deletions examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def parse_args(args):
parser.add_argument(
'--rl_trainer', type=str, default="rllib",
help='the RL trainer to use. either rllib or Stable-Baselines')

parser.add_argument(
'--algorithm', type=str, default="PPO",
help='RL algorithm to use. Options are PPO and DQN right now.')
parser.add_argument(
'--num_cpus', type=int, default=1,
help='How many CPUs to use')
Expand Down Expand Up @@ -101,6 +103,7 @@ def run_model_stablebaseline(flow_params,
def setup_exps_rllib(flow_params,
n_cpus,
n_rollouts,
flags,
policy_graphs=None,
policy_mapping_fn=None,
policies_to_train=None):
Expand All @@ -114,6 +117,8 @@ def setup_exps_rllib(flow_params,
number of CPUs to run the experiment over
n_rollouts : int
number of rollouts per training iteration
flags:
custom arguments
policy_graphs : dict, optional
TODO
policy_mapping_fn : function, optional
Expand All @@ -139,19 +144,31 @@ def setup_exps_rllib(flow_params,

horizon = flow_params['env'].horizon

alg_run = "PPO"

agent_cls = get_agent_class(alg_run)
config = deepcopy(agent_cls._default_config)
alg_run = flags.algorithm.upper()

if alg_run == "PPO":
agent_cls = get_agent_class(alg_run)
config = deepcopy(agent_cls._default_config)
config["gamma"] = 0.999 # discount rate
config["model"].update({"fcnet_hiddens": [32, 32, 32]})
config["use_gae"] = True
config["lambda"] = 0.97
config["kl_target"] = 0.02
config["num_sgd_iter"] = 10
elif alg_run == "DQN":
agent_cls = get_agent_class(alg_run)
config = deepcopy(agent_cls._default_config)
config['clip_actions'] = False
config["timesteps_per_iteration"] = horizon * n_rollouts
# https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-dist-dqn.yaml
config["hiddens"] = [512]
config["lr"] = 0.0000625
config["schedule_max_timesteps"] = 2000000
config["buffer_size"] = 1000000
config["target_network_update_freq"] = 8000

config["num_workers"] = n_cpus
config["train_batch_size"] = horizon * n_rollouts
config["gamma"] = 0.999 # discount rate
config["model"].update({"fcnet_hiddens": [32, 32, 32]})
config["use_gae"] = True
config["lambda"] = 0.97
config["kl_target"] = 0.02
config["num_sgd_iter"] = 10
config["horizon"] = horizon

# save the flow params for replay
Expand Down Expand Up @@ -190,10 +207,10 @@ def train_rllib(submodule, flags):
policies_to_train = getattr(submodule, "policies_to_train", None)

alg_run, gym_name, config = setup_exps_rllib(
flow_params, n_cpus, n_rollouts,
flow_params, n_cpus, n_rollouts, flags,
policy_graphs, policy_mapping_fn, policies_to_train)

ray.init(num_cpus=n_cpus + 1, object_store_memory=200 * 1024 * 1024)
ray.init(num_cpus=n_cpus + 1, ignore_reinit_error=True, object_store_memory=200 * 1024 * 1024)
exp_config = {
"run": alg_run,
"env": gym_name,
Expand Down
2 changes: 1 addition & 1 deletion flow/envs/multiagent/traffic_light_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _apply_rl_actions(self, rl_actions):
for rl_id, rl_action in rl_actions.items():
i = int(rl_id.split("center")[ID_IDX])
if self.discrete:
raise NotImplementedError
action = rl_action
else:
# convert values less than 0.0 to zero and above to 1. 0's
# indicate that we should not switch the direction
Expand Down
2 changes: 1 addition & 1 deletion flow/envs/traffic_light_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"switch_time": 2.0,
# whether the traffic lights should be actuated by sumo or RL
# options are "controlled" and "actuated"
"tl_type": "controlled",
"tl_type": "actuated",
# determines whether the action space is meant to be discrete or continuous
"discrete": False,
}
Expand Down