Add template reinforcement learning project as example
1 parent 81c1833 · commit 176b9b8
Showing 5 changed files with 302 additions and 0 deletions.
.gitignore
@@ -0,0 +1 @@
snapshots/**
README.md
@@ -0,0 +1,57 @@
# RL project template

This directory provides a minimal template for reinforcement learning (RL) projects.

## How to use

Start using this template by copying the entire directory and modifying the Python files to suit your needs.
We provide three main Python files to get you started.

- environment.py
- models.py
- training.py

See the descriptions below for the usage of each file.

### environment.py

A sample implementation of an environment class.
Only the basic implementation is provided.
You will need to modify the file and add extra implementations
to make the algorithm properly solve your problem.

### models.py

Sample implementations of the DNN models to be trained by the RL algorithm.
You may also need to modify the models to get the desired results.

### training.py

The main file that runs the training with the RL algorithm.
See this file to understand the basics of how to implement the training process.

## How to run the script

Run the training.py script.
By default, it runs on the CPU.

```sh
$ python training.py
```

To run on a GPU, first install nnabla-ext-cuda as follows.

```sh
# $ pip install nnabla-ext-cuda[cuda_version]
# Example: when you have installed CUDA 11.6 on your machine.
$ pip install nnabla-ext-cuda116
```

For the installation of nnabla-ext-cuda, see also [here](https://github.com/sony/nnabla) and [here](https://github.com/sony/nnabla-ext-cuda).

Then, run the script by specifying the GPU id.

```sh
# This will run the script on GPU id 0.
$ python training.py --gpu=0
```
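
The trained parameters are saved under a snapshots/ directory next to training.py by default (see the --save-dir option and SaveSnapshotHook in training.py); the snapshots/** ignore pattern added in this commit keeps them out of version control.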
environment.py
@@ -0,0 +1,74 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
from gym.envs.registration import EnvSpec


class TemplateEnv(gym.Env):
    def __init__(self, max_episode_steps=100):
        # max_episode_steps is the maximum number of steps that the RL agent interacts with this environment.
        # You can set this value to None if there is no limit.
        # The first argument is the name of this environment, used when registering this environment to gym.
        self.spec = EnvSpec('template-v0', max_episode_steps=max_episode_steps)
        self._episode_steps = 0

        # Use gym's spaces to define the shapes and ranges of states and actions.
        # observation_space: definition of the states' shape and ranges
        # action_space: definition of the actions' shape and ranges
        observation_shape = (10, )  # Example: 10-dimensional state with range [0.0, 1.0] for each element.
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=observation_shape)
        action_shape = (1, )  # 1-dimensional continuous action with range [0.0, 1.0].
        self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=action_shape)

    def reset(self):
        # Reset the entire environment and return the initial state after the reset.
        # In this example, we just reset the steps to 0 and return a random state as the initial state.
        self._episode_steps = 0

        return self.observation_space.sample()

    def step(self, action):
        # step is the core of the environment class.
        # You will need to compute the next_state according to the received action and
        # the reward to be given to the agent.
        # You will also need to set and return a done flag if the next_state is the end of the episode.

        # Increment episode steps
        self._episode_steps += 1

        next_state = self._compute_next_state(action)
        reward = self._compute_reward(next_state, action)

        # Here we set the done flag to True once the current episode steps reach
        # the max episode steps defined in this environment.
        done = False if self.spec.max_episode_steps is None else (self.spec.max_episode_steps <= self._episode_steps)

        # info is a convenient dictionary that you can fill with
        # any additional information to return to the RL algorithm.
        # If you have no extra information, return an empty dictionary.
        info = {}
        return next_state, reward, done, info

    def _compute_next_state(self, action):
        # In this template, we just return a randomly sampled state.
        # But in a real application, you should compute the next_state according to the given action.
        return self.observation_space.sample()

    def _compute_reward(self, state, action):
        # In this template, we implemented an easy-to-understand reward function:
        # 0 on every step until the last step of the episode, which yields 1.
        # But in a real application, the design of this reward function is extremely important.
        # The None check guards against environments without a step limit.
        if self.spec.max_episode_steps is None or self._episode_steps < self.spec.max_episode_steps:
            return 0
        else:
            return 1
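
Below is a minimal sketch, not part of this commit, that drives TemplateEnv with random actions to smoke-test the environment loop; it assumes only the gym-style API shown above.

```python
from environment import TemplateEnv

env = TemplateEnv(max_episode_steps=10)
state = env.reset()
done = False
steps, total_reward = 0, 0.0
while not done:
    # Sample a random action from the action space defined in __init__.
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    steps += 1
    total_reward += reward
print(f"episode finished after {steps} steps, total reward: {total_reward}")
```

With the template's reward function, this should report a total reward of 1.0 once the step limit is reached.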
models.py
@@ -0,0 +1,71 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnabla as nn
import nnabla.functions as NF
import nnabla.parametric_functions as NPF
import nnabla_rl.distributions as D
from nnabla_rl.distributions.distribution import Distribution
from nnabla_rl.models import ContinuousQFunction, StochasticPolicy


class TemplateQFunction(ContinuousQFunction):
    # This is a sample q-function model to be used in the Soft Actor-Critic (SAC) algorithm.
    # The model to implement depends on the RL algorithm.
    def __init__(self, scope_name: str):
        super().__init__(scope_name=scope_name)

    def q(self, state: nn.Variable, action: nn.Variable) -> nn.Variable:
        # Using the input state and action, this simple model outputs the
        # q-value (1 dim) for the given inputs.
        # Modify this model so that it performs better on your environment.
        # The state's and action's shapes are both (batch size, variable's shape).
        with nn.parameter_scope(self.scope_name):
            h = NF.concatenate(state, action)
            h = NPF.affine(h, n_outmaps=256, name="linear1")
            h = NF.relu(x=h)
            h = NPF.affine(h, n_outmaps=256, name="linear2")
            h = NF.relu(x=h)
            h = NPF.affine(h, n_outmaps=1, name="linear3")
        return h


class TemplatePolicy(StochasticPolicy):
    # This is a sample stochastic policy model to be used in the Soft Actor-Critic (SAC) algorithm.
    # A stochastic policy is a type of policy that outputs the action's probability distribution
    # instead of the action itself.
    # The model to implement depends on the RL algorithm.
    def __init__(self, scope_name: str, action_dim=1):
        super().__init__(scope_name=scope_name)
        self._action_dim = action_dim

    def pi(self, state: nn.Variable) -> Distribution:
        # The state's shape is (batch size, state dimension), where the state dimension
        # is the same as the one set in the environment's implementation.

        with nn.parameter_scope(self.scope_name):
            h = NPF.affine(state, n_outmaps=256, name="linear1")
            h = NF.relu(x=h)
            h = NPF.affine(h, n_outmaps=256, name="linear2")
            h = NF.relu(x=h)
            h = NPF.affine(h, n_outmaps=self._action_dim * 2, name="linear3")
        reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim))

        # Split the output into the mean and the log variance of the Gaussian distribution.
        mean, ln_var = NF.split(reshaped, axis=1)

        # Check that the output shape is as expected.
        assert mean.shape == ln_var.shape
        assert mean.shape == (state.shape[0], self._action_dim)
        # SquashedGaussian is a distribution that applies tanh to the output of a Gaussian distribution.
        return D.SquashedGaussian(mean=mean, ln_var=ln_var)
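
To check the model definitions, here is a minimal sketch, not part of this commit, that builds both models on dummy nnabla Variables and asserts the output shapes. The batch size of 4 and the 10-dimensional state are assumptions chosen to match the template environment; Distribution.sample() is used as provided by nnabla-rl.

```python
import nnabla as nn

from models import TemplatePolicy, TemplateQFunction

batch_size, state_dim, action_dim = 4, 10, 1
state = nn.Variable((batch_size, state_dim))
action = nn.Variable((batch_size, action_dim))

# The q-function builds a graph that outputs one q-value per batch entry.
q_function = TemplateQFunction("q")
q_value = q_function.q(state, action)
assert q_value.shape == (batch_size, 1)

# The policy outputs a distribution; sampling from it yields one action per batch entry.
policy = TemplatePolicy("pi", action_dim=action_dim)
distribution = policy.pi(state)
sampled_action = distribution.sample()
assert sampled_action.shape == (batch_size, action_dim)
```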
training.py
@@ -0,0 +1,99 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pathlib

from environment import TemplateEnv
from models import TemplatePolicy, TemplateQFunction

import nnabla_rl.hooks as H
from nnabla_rl.algorithm import AlgorithmConfig
from nnabla_rl.algorithms import SAC, SACConfig
from nnabla_rl.builders import ModelBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.utils.reproductions import set_global_seed


class QFunctionBuilder(ModelBuilder):
    def build_model(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs):
        return TemplateQFunction(scope_name)


class PolicyBuilder(ModelBuilder):
    def build_model(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs):
        return TemplatePolicy(scope_name, env_info.action_dim)


def create_env(seed):
    env = TemplateEnv()
    env.seed(seed)
    return env


def build_algorithm(env, args):
    # Select an RL algorithm of your choice and build its instance.
    # In this sample, we selected Soft Actor-Critic (SAC).

    # Set the configuration of the algorithm.
    config = SACConfig(gpu_id=args.gpu)

    # The algorithm requires not only the configuration but also the target models to train,
    # which are supplied through the builders.
    algorithm = SAC(env, config, q_function_builder=QFunctionBuilder(), policy_builder=PolicyBuilder())

    return algorithm


def run_training(args):
    # It is optional, but we set the random seed to ensure reproducibility.
    set_global_seed(args.seed)

    # Create two different environments to avoid using the training environment in the evaluation.
    train_env = create_env(args.seed)
    eval_env = create_env(args.seed + 100)
    algorithm = build_algorithm(train_env, args)

    # nnabla-rl has a convenient feature that enables running additional operations
    # at each specified timing (iteration).
    # This hook evaluates the trained model every "timing" (=1000) iteration steps.
    evaluation_hook = H.EvaluationHook(eval_env, timing=1000)

    # We add this hook just to check that the training runs properly.
    # It prints the current iteration number every "timing" (=100) steps.
    iteration_num_hook = H.IterationNumHook(timing=100)

    # Save the trained parameters every "timing" (=1000) steps.
    # Without this, the parameters will not be saved.
    # We recommend saving the parameters at every evaluation timing.
    outdir = pathlib.Path(args.save_dir) / 'snapshots'
    save_snapshot_hook = H.SaveSnapshotHook(outdir=outdir, timing=1000)

    # All instantiated hooks should be set at once;
    # set_hooks overrides previously set hooks.
    algorithm.set_hooks([evaluation_hook, iteration_num_hook, save_snapshot_hook])

    algorithm.train(train_env)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--save-dir', type=str, default=str(pathlib.Path(__file__).parent))
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    run_training(args)


if __name__ == '__main__':
    main()
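
After training, a saved snapshot can be reloaded to run the trained agent. The sketch below is not part of this commit and rests on two assumptions: that SaveSnapshotHook wrote a directory such as snapshots/iteration-1000 (the name is an assumed example), and that serializers.load_snapshot in your nnabla-rl version accepts the environment and algorithm kwargs shown here; check the nnabla-rl documentation for the exact signature.

```python
import pathlib

from environment import TemplateEnv
from training import PolicyBuilder, QFunctionBuilder

from nnabla_rl.utils import serializers

# Directory written by SaveSnapshotHook; 'iteration-1000' is an assumed example name.
snapshot_dir = pathlib.Path('snapshots') / 'iteration-1000'

env = TemplateEnv()
# Rebuild the algorithm from the snapshot, reusing the builders from training.py.
algorithm = serializers.load_snapshot(
    snapshot_dir,
    env,
    algorithm_kwargs={'q_function_builder': QFunctionBuilder(), 'policy_builder': PolicyBuilder()})

# Run one evaluation episode with the trained policy.
state = env.reset()
done = False
while not done:
    action = algorithm.compute_eval_action(state)
    state, reward, done, _ = env.step(action)
```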