Skip to content

Commit

Permalink
Fixing testing code errors
Browse files Browse the repository at this point in the history
Fix np.int deprecation errors
Fix non terminal's shape on test
  • Loading branch information
ishihara-y committed Jul 21, 2023
1 parent a09f7ad commit 966ef43
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def array_and_dtype(mp_arrays_item):
def _compute_action(self, s, *, begin_of_episode=False):
action, info = self._exploration_actor(s, begin_of_episode=begin_of_episode)
if self._env_info.is_discrete_action_env():
return np.int(action), info
return np.int32(action), info
else:
return action, info

Expand Down
2 changes: 1 addition & 1 deletion nnabla_rl/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def _compute_action(self, s, *, begin_of_episode=False):
info = {}
info['log_prob'] = log_prob
if self._env_info.is_discrete_action_env():
return np.int(action), info
return np.int32(action), info
else:
return action, info

Expand Down
2 changes: 1 addition & 1 deletion nnabla_rl/distributions/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, z):
self._distribution = NF.concatenate(self._p, 1 - self._p)
self._log_distribution = NF.concatenate(self._log_p, self._log_1_minus_p)

labels = np.array([1, 0], dtype=np.int)
labels = np.array([1, 0], dtype=np.int32)
labels = nn.Variable.from_numpy_array(labels)
self._labels = labels
for size in reversed(z.shape[0:-1]):
Expand Down
2 changes: 1 addition & 1 deletion nnabla_rl/distributions/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, z):
self._num_class = z.shape[-1]

labels = np.array(
[label for label in range(self._num_class)], dtype=np.int)
[label for label in range(self._num_class)], dtype=np.int32)
self._labels = nn.Variable.from_numpy_array(labels)
self._actions = self._labels
for size in reversed(z.shape[0:-1]):
Expand Down
4 changes: 2 additions & 2 deletions tests/algorithms/test_common_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021,2022 Sony Group Corporation.
# Copyright 2021,2022,2023 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,7 +55,7 @@ def _collect_dummy_experience(self, num_episodes=1, episode_length=3, tupled_sta
r = np.ones(1, )
non_terminal = np.ones(1, )
if i == episode_length-1:
non_terminal = 0
non_terminal = np.zeros(1, )
experience.append((s_current, a, r, non_terminal, s_next))
return experience

Expand Down

0 comments on commit 966ef43

Please sign in to comment.