Skip to content

Commit

Permalink
Add template reinforcement learning project as example
Browse files Browse the repository at this point in the history
  • Loading branch information
ishihara-y committed Aug 21, 2023
1 parent 81c1833 commit 176b9b8
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/rl_project_template/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
snapshots/**
57 changes: 57 additions & 0 deletions examples/rl_project_template/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# RL project template

This directory provides a minimum template for reinforcement learning (RL) projects.

## How to use

Start using this template by copying and pasting the entire directory and modifying the python files depending on your needs.
We created 3 main python files to get started.

- environment.py
- models.py
- training.py

See below descriptions for the usages of each file.

### environment.py

Sample implementation of an environment class.
Only the basic implementation is provided.
You will need to modify and add extra implementations to the file to
properly to make the algorithm solve your problem.

### models.py

Sample implementation of DNN models to be learned by the RL algorithm.
You may also need to modify the models to get desired result.

### training.py

Main file that runs the training with RL algorithm.
See this file to understand the basics of how to implement the training process.

## How to run the script

Run the training.py script.
By default, it runs on cpu.

```sh
$ python training.py
```

To run on gpu, first, install nnabla-ext-cuda as follows.

```sh
# $ pip install nnabla-ext-cuda[cuda_version]
# Example: when you have installed CUDA-11.6 on your machine.
$ pip install nnabla-ext-cuda116
```

For the installation of nnabla-ext-cuda see also [here](https://github.com/sony/nnabla) and [here](https://github.com/sony/nnabla-ext-cuda).

Then, run the script by specifying the gpu id.

```sh
# This will run the script on gpu id 0.
$ python training.py --gpu=0
```
74 changes: 74 additions & 0 deletions examples/rl_project_template/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
from gym.envs.registration import EnvSpec


class TemplateEnv(gym.Env):
def __init__(self, max_episode_steps=100):
# max_episode_steps is the maximum possible steps that the rl agent interacts with this environment.
# You can set this value to None if the is no limits.
# The first argument is the name of this environment used when registering this environment to gym.
self.spec = EnvSpec('template-v0', max_episode_steps=max_episode_steps)
self._episode_steps = 0

# Use gym's spaces to define the shapes and ranges of states and actions.
# observation_space: definition of states's shape and its ranges
# action_space: definition of actions's shape and its ranges
observation_shape = (10, ) # Example 10 dimensional state with range of [0.0, 1.0] each.
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=observation_shape)
action_shape = (1, ) # 1 dimensional continuous action with range of [0.0, 1.0].
self.action_space = gym.spaces.Box(low=0.0, high=1.0, shape=action_shape)

def reset(self):
# Reset the entire environment and return the initial state after the reset.
# In this example, we just reset the steps to 0 and return a random state as initial state.
self._episode_steps = 0

return self.observation_space.sample()

def step(self, action):
# step is the core of the environment class.
# You will need to compute the next_state according to the action received and
# the reward to be received.
# You will also need to set and return a done flag if the next_state is the end of the episode.

# Increment episode steps
self._episode_steps += 1

next_state = self._compute_next_state(action)
reward = self._compute_reward(next_state, action)

# Here we set done flag to false if current episode steps exceeds
# the max episode steps defined in this environment.
done = False if self.spec.max_episode_steps is None else (self.spec.max_episode_steps <= self._episode_steps)

# info is a convenient dictionary that you can fill
# any additional information to return back to the RL algorithm.
# If you have no extra information, return a empty dictionary.
info = {}
return next_state, reward, done, info

def _compute_next_state(self, action):
# In this template, we just return a randomly sampled state.
# But in real application, you should compute the next_state according to the given action.
return self.observation_space.sample()

def _compute_reward(self, state, action):
# In this template, we implemented a easy to understand reward function.
# But in real an application, the design of this reward function is extremely important.
if self._episode_steps < self.spec.max_episode_steps:
return 0
else:
return 1
71 changes: 71 additions & 0 deletions examples/rl_project_template/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnabla as nn
import nnabla.functions as NF
import nnabla.parametric_functions as NPF
import nnabla_rl.distributions as D
from nnabla_rl.distributions.distribution import Distribution
from nnabla_rl.models import ContinuousQFunction, StochasticPolicy


class TemplateQFunction(ContinuousQFunction):
# This is a sample QFunction model to be used in Soft-Actor Critic(SAC) algorithm.
# Model to implement depends on the RL algorithm.
def __init__(self, scope_name: str):
super().__init__(scope_name=scope_name)

def q(self, state: nn.Variable, action: nn.Variable) -> nn.Variable:
# Using the input state and action, here we implemented a simple model that outputs
# q-value (1 dim) using these inputs.
# Modify this model to better perform on your environment.
# The state's and action's shape are both (batch size, variable's shape)
with nn.parameter_scope(self.scope_name):
h = NF.concatenate(state, action)
h = NPF.affine(h, n_outmaps=256, name="linear1")
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=256, name="linear2")
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=1, name="linear3")
return h


class TemplatePolicy(StochasticPolicy):
# This is a sample stochstic policy model to be used in Soft-Actor Critic(SAC) algorithm.
# Stochastic policy is a type of policy that outputs the action's probability distribution
# instead of action itself.
# Model to implement depends on the RL algorithm.
def __init__(self, scope_name: str, action_dim=1):
super().__init__(scope_name=scope_name)
self._action_dim = action_dim

def pi(self, state: nn.Variable) -> Distribution:
# The state's shape is (batch size, state's shape)
# state's shape is same as the one set in the environment's implementation.

with nn.parameter_scope(self.scope_name):
h = NPF.affine(state, n_outmaps=256, name="linear1")
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=256, name="linear2")
h = NF.relu(x=h)
h = NPF.affine(h, n_outmaps=self._action_dim*2, name="linear3")
reshaped = NF.reshape(h, shape=(-1, 2, self._action_dim))

# Split the output into mean and variance of the Gaussian distribution.
mean, ln_var = NF.split(reshaped, axis=1)

# Check that output shape is as expected
assert mean.shape == ln_var.shape
assert mean.shape == (state.shape[0], self._action_dim)
# SquashedGaussian is a distribution that applies tanh to the output of Gaussian distribution.
return D.SquashedGaussian(mean=mean, ln_var=ln_var)
99 changes: 99 additions & 0 deletions examples/rl_project_template/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import pathlib

from environment import TemplateEnv
from models import TemplatePolicy, TemplateQFunction

import nnabla_rl.hooks as H
from nnabla_rl.algorithm import AlgorithmConfig
from nnabla_rl.algorithms import SAC, SACConfig
from nnabla_rl.builders import ModelBuilder
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.utils.reproductions import set_global_seed


class QFunctionBuilder(ModelBuilder):
def build_model(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs):
return TemplateQFunction(scope_name)


class PolicyBuilder(ModelBuilder):
def build_model(self, scope_name: str, env_info: EnvironmentInfo, algorithm_config: AlgorithmConfig, **kwargs):
return TemplatePolicy(scope_name, env_info.action_dim)


def create_env(seed):
env = TemplateEnv()
env.seed(seed)
return env


def build_algorithm(env, args):
# Select a RL algorithm of your choice and build its instance.
# In this sample, we selected Soft-Actor Critic(SAC).

# Setting the configuration of the algorithm.
config = SACConfig(gpu_id=args.gpu)

# Algorithm requires not only the configuration but also the target models to train.
algorithm = SAC(env, config, q_function_builder=QFunctionBuilder(), policy_builder=PolicyBuilder())

return algorithm


def run_training(args):
# It is optional but we set the random seed to ensure reproducibility.
set_global_seed(args.seed)

# Create two different environments to avoid using training environment in the evaluation.
train_env = create_env(args.seed)
eval_env = create_env(args.seed + 100)
algorithm = build_algorithm(train_env, args)

# nnabla-rl has a convenient feature that enables running additional operations
# in each specified timing (iteration).
# This hook evaluates the training model every "timing" iteration steps.
evaluation_hook = H.EvaluationHook(eval_env, timing=1000)

# Adding this hook to just check that the training runs properly.
# This hook prints current iteration number every "timing" (=100) steps.
iteration_num_hook = H.IterationNumHook(timing=100)

# Save trained parameters every "timing" steps.
# Without this, the parameters will not be saved.
# We recommend saving parameters at every evaluation timing.
outdir = pathlib.Path(args.save_dir) / 'snapshots'
save_snapshot_hook = H.SaveSnapshotHook(outdir=outdir, timing=1000)

# All instantiated hooks should be set at once.
# set_hooks will override previously set hooks.
algorithm.set_hooks([evaluation_hook, iteration_num_hook, save_snapshot_hook])

algorithm.train(train_env)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--save-dir', type=str, default=str(pathlib.Path(__file__).parent))
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

run_training(args)


if __name__ == '__main__':
main()

0 comments on commit 176b9b8

Please sign in to comment.