Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull request to test CI #33

Merged
merged 9 commits into from
Oct 5, 2023
31 changes: 30 additions & 1 deletion flatland/core/transition_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,38 @@ def set_transition(self, cell_id, transition_index, new_transition, remove_deade
send_infrastructure_data_change_signal_to_reset_lru_cache()
#assert len(cell_id) == 3, \
# 'GridTransitionMap.set_transition() ERROR: cell_id tuple must have length 3.'

nDir = cell_id[2]
#print(cell_id, type(nDir))
if type(nDir) == np.ndarray:
# I can't work out how to dump a complete backtrace here
#import traceback
#try:
# assert type(nDir)==int, "cell direction is not an int"
#except Exception as e:
# traceback.print_exception(e)
# traceback.print_tb(e.__traceback__)
# print(traceback.format_exc())
print("fixing nDir:", cell_id, nDir)
nDir = int(nDir[0])

if type(transition_index) is not int:
print("fixing transition_index:", cell_id, transition_index)
if type(transition_index) == np.ndarray:
transition_index = int(transition_index[0])
else:
print("transition_index type:", type(transition_index))
transition_index = int(transition_index)

if type(new_transition) is not int:
print("fixing new_transition:", cell_id, new_transition)
new_transition = int(new_transition[0])

#print("fixed:", cell_id, type(nDir), transition_index, new_transition, remove_deadends)

hagrid67 marked this conversation as resolved.
Show resolved Hide resolved
self.grid[cell_id[0]][cell_id[1]] = self.transitions.set_transition(
self.grid[cell_id[0:2]],
cell_id[2],
nDir, # cell_id[2],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nDir, # cell_id[2],
nDir,

transition_index,
new_transition,
remove_deadends)
Expand Down
30 changes: 19 additions & 11 deletions tests/test_flatland_envs_rail_env_shortest_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
from flatland.envs.persistence import RailEnvPersister

from typing import List

def test_get_shortest_paths_unreachable():
rail, rail_map, optionals = make_disconnected_simple_rail()
Expand Down Expand Up @@ -112,13 +113,21 @@ def test_get_shortest_paths():
Waypoint(position=(2, 4), direction=3),
Waypoint(position=(2, 3), direction=3),
Waypoint(position=(2, 2), direction=3),
Waypoint(position=(2, 1), direction=3)]

# Change a point to test the assertion works :)
Waypoint(position=(2, 1), direction=3)
#Waypoint(position=(2, 2), direction=3)
]
hagrid67 marked this conversation as resolved.
Show resolved Hide resolved
}

for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
for iA, lWP in expected.items():
_compare_paths(iA, actual[iA], lWP)

def _compare_paths(iAgent:int, actual:List[Waypoint], expected:List[Waypoint]):
hagrid67 marked this conversation as resolved.
Show resolved Hide resolved
assert len(actual) == len(expected), f"Lengths differ: actual={len(actual)}, expected={len(expected)}"
for iWP, (wpA, wpE) in enumerate(zip(actual, expected)):
assert wpA.position == wpE.position, f"Agent {iAgent} Waypoints at step {iWP} differ: actual={wpA.position}, expected={wpE.position}"
assert wpA.direction == wpE.direction, f"Agent {iAgent} Waypoint directions at step {iWP} differ:actual={wpA.direction}, expected={wpE.direction}"

# todo file test_002.pkl has to be generated automatically
def test_get_shortest_paths_max_depth():
Expand All @@ -138,9 +147,9 @@ def test_get_shortest_paths_max_depth():
]
}

for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
for iA, lWP in expected.items():
_compare_paths(iA, actual[iA], lWP)



# todo file Level_distance_map_shortest_path.pkl has to be generated automatically
Expand Down Expand Up @@ -228,9 +237,8 @@ def test_get_shortest_paths_agent_handle():
direction=3)
]}

for agent_handle in expected:
assert np.array_equal(actual[agent_handle], expected[agent_handle]), \
"[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle])
for iA, lWP in expected.items():
_compare_paths(iA, actual[iA], lWP)


def test_get_k_shortest_paths(rendering=False):
Expand Down Expand Up @@ -310,7 +318,7 @@ def test_get_k_shortest_paths(rendering=False):
Waypoint(position=(3, 9), direction=0))
])

assert actual == expected, "actual={},expected={}".format(actual, expected)
assert actual == expected, "Sets are different:\nactual={},\nexpected={}".format(actual, expected)

def main():
test_get_shortest_paths()
Expand Down
21 changes: 21 additions & 0 deletions tests/test_flatland_envs_sparse_rail_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,27 @@ def test_sparse_generator_changes_to_grid_mode():
grid_mode=False
), line_generator=sparse_line_generator(), number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv())

# Catch warnings and check that a warning *IS* raised
with warnings.catch_warnings(record=True) as w:
rail_env.reset(True, True, random_seed=15)
assert "[WARNING]" in str(w[-1].message)



def main():
# Make warnings into errors, to generate stack backtraces
warnings.simplefilter("error",) # category=DeprecationWarning)

# Then run selected tests.
test_sparse_rail_generator()
#test_sparse_rail_generator_deterministic()
#test_rail_env_action_required_info()
#test_rail_env_malfunction_speed_info()
#test_sparse_generator_with_too_man_cities_does_not_break_down()
#test_sparse_generator_with_illegal_params_aborts()
#test_sparse_generator_changes_to_grid_mode()

if __name__ == "__main__":
main()

5 changes: 3 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
[tox]
env_list = py{37,38,39}
env_list = py{37,38,39,310}

[gh-actions]
python =
3.7: py37
3.8: py38
3.9: py39
3.10: py310

[testenv]
set_env =
Expand All @@ -17,7 +18,7 @@ pass_env =
HTTP_PROXY
HTTPS_PROXY

[testenv:py{37,38,39}]
[testenv:py{37,38,39,310}]
platform = linux|linux2|darwin
deps =
-r requirements_dev.txt
Expand Down