Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed dependency on gym.utils #56

Merged
merged 2 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flatland/action_plan/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from warnings import warn

warn('The action_plan is deprecated', DeprecationWarning, stacklevel=2)
3 changes: 3 additions & 0 deletions flatland/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from warnings import warn

warn('The contrib is deprecated', DeprecationWarning, stacklevel=2)
2 changes: 1 addition & 1 deletion flatland/envs/rail_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional, Dict, Tuple

import numpy as np
from gym.utils import seeding
from flatland.utils import seeding

from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
Expand Down
91 changes: 91 additions & 0 deletions flatland/utils/seeding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import hashlib
import os
import struct

import numpy as np

def np_random(seed=None):
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise Exception('Seed must be a non-negative integer or omitted, not {}'.format(seed))

seed = create_seed(seed)

rng = np.random.RandomState()
rng.seed(_int_list_from_bigint(hash_seed(seed)))
return rng, seed


def hash_seed(seed=None, max_bytes=8):
"""Any given evaluation is likely to have many PRNG's active at
once. (Most commonly, because the environment is running in
multiple processes.) There's literature indicating that having
linear correlations between seeds of multiple PRNG's can correlate
the outputs:

http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
http://dl.acm.org/citation.cfm?id=1276928

Thus, for sanity we hash the seeds before using them. (This scheme
is likely not crypto-strength, but it should be good enough to get
rid of simple correlations.)

Args:
seed (Optional[int]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the hashed seed.
"""
if seed is None:
seed = create_seed(max_bytes=max_bytes)
hash = hashlib.sha512(str(seed).encode('utf8')).digest()
return _bigint_from_bytes(hash[:max_bytes])


def create_seed(a=None, max_bytes=8):
"""Create a strong random seed. Otherwise, Python 2 would seed using
the system time, which might be non-robust especially in the
presence of concurrency.

Args:
a (Optional[int, str]): None seeds from an operating system specific randomness source.
max_bytes: Maximum number of bytes to use in the seed.
"""
# Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
if a is None:
a = _bigint_from_bytes(os.urandom(max_bytes))
elif isinstance(a, str):
a = a.encode('utf8')
a += hashlib.sha512(a).digest()
a = _bigint_from_bytes(a[:max_bytes])
elif isinstance(a, int):
a = a % 2 ** (8 * max_bytes)
else:
raise Exception('Invalid type for seed: {} ({})'.format(type(a), a))

return a


# TODO: don't hardcode sizeof_int here
def _bigint_from_bytes(bytes):
sizeof_int = 4
padding = sizeof_int - len(bytes) % sizeof_int
bytes += b'\0' * padding
int_count = int(len(bytes) / sizeof_int)
unpacked = struct.unpack("{}I".format(int_count), bytes)
accum = 0
for i, val in enumerate(unpacked):
accum += 2 ** (sizeof_int * 8 * i) * val
return accum


def _int_list_from_bigint(bigint):
# Special case 0
if bigint < 0:
raise Exception('Seed must be non-negative, not {}'.format(bigint))
elif bigint == 0:
return [0]

ints = []
while bigint > 0:
bigint, mod = divmod(bigint, 2 ** 32)
ints.append(mod)
return ints
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ dependencies = [
"crowdai_api",
"dataclasses",
"graphviz",
"gym==0.14.0",
"importlib_resources<2.0.0",
"ipycanvas",
"ipyevents",
Expand Down