From ec30982ae6a5b8178a8ef204727f391371e8221a Mon Sep 17 00:00:00 2001
From: adrian_egli
Date: Sat, 28 Oct 2023 22:59:10 +0200
Subject: [PATCH 1/2] Removed dependency on gym.utils

---
Reviewer notes (ignored by `git am`; not part of the commit):
* `raise print(...)` in `_np_random` raises a TypeError, because print()
  returns None; the intended seed-validation error is never raised.
* The hunk still adds `from gym import error` inside the class body, so
  rail_env.py keeps importing gym after this commit.
* Both issues disappear in PATCH 2/2, which removes this block again.

 flatland/envs/rail_env.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 428483dd..71563f37 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -5,7 +5,6 @@
 from typing import List, Optional, Dict, Tuple
 
 import numpy as np
-from gym.utils import seeding
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
@@ -203,8 +202,38 @@ def __init__(self,
 
         self.motionCheck = ac.MotionCheck()
 
+    """Set of random number generator functions: seeding, generator, hashing seeds."""
+    from typing import Any, Optional, Tuple
+
+    import numpy as np
+
+    from gym import error
+
+    @staticmethod
+    def _np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]:
+        """Generates a random number generator from the seed and returns the Generator and seed.
+
+        Args:
+            seed: The seed used to create the generator
+
+        Returns:
+            The generator and resulting seed
+
+        Raises:
+            Error: Seed must be a non-negative integer or omitted
+        """
+        if seed is not None and not (isinstance(seed, int) and 0 <= seed):
+            raise print(f"Seed must be a non-negative integer or omitted, not {seed}")
+
+        seed_seq = np.random.SeedSequence(seed)
+        np_seed = seed_seq.entropy
+        rng = np.random.Generator(np.random.PCG64(seed_seq))
+        return rng, np_seed
+
+    RNG = RandomNumberGenerator = np.random.Generator
+
     def _seed(self, seed):
-        self.np_random, seed = seeding.np_random(seed)
+        self.np_random, seed = RailEnv._np_random(seed)
         random.seed(seed)
         self.random_seed = seed
 
From f2efc1729c94590e9c3f7b8316534febea6e3645 Mon Sep 17 00:00:00 2001
From: adrian_egli
Date: Sat, 28 Oct 2023 23:25:26 +0200
Subject: [PATCH 2/2] Removed dependency on gym.utils

---
Reviewer notes (ignored by `git am`; not part of the commit):
* Seeding moves from the in-class `RailEnv._np_random` to the new module
  flatland/utils/seeding.py; `np_random` there returns a
  `np.random.RandomState`, whereas PATCH 1/2 returned a
  `np.random.Generator`.
* `_bigint_from_bytes` unpacks with the native-order "I" struct format and
  a hard-coded `sizeof_int = 4` (see its TODO), so the derived seed
  presumably depends on the platform's byte order and int size.
* `hash_seed` and `_bigint_from_bytes` use local names `hash` and `bytes`,
  which shadow the builtins.
* Importing flatland.action_plan or flatland.contrib now emits a
  DeprecationWarning.

 flatland/action_plan/__init__.py |  3 ++
 flatland/contrib/__init__.py     |  3 ++
 flatland/envs/rail_env.py        | 33 +-----------
 flatland/utils/seeding.py        | 91 ++++++++++++++++++++++++++++++++
 pyproject.toml                   |  1 -
 5 files changed, 99 insertions(+), 32 deletions(-)
 create mode 100644 flatland/contrib/__init__.py
 create mode 100644 flatland/utils/seeding.py

diff --git a/flatland/action_plan/__init__.py b/flatland/action_plan/__init__.py
index e69de29b..53dc195e 100644
--- a/flatland/action_plan/__init__.py
+++ b/flatland/action_plan/__init__.py
@@ -0,0 +1,3 @@
+from warnings import warn
+
+warn('The action_plan is deprecated', DeprecationWarning, stacklevel=2)
diff --git a/flatland/contrib/__init__.py b/flatland/contrib/__init__.py
new file mode 100644
index 00000000..4a851297
--- /dev/null
+++ b/flatland/contrib/__init__.py
@@ -0,0 +1,3 @@
+from warnings import warn
+
+warn('The contrib is deprecated', DeprecationWarning, stacklevel=2)
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 71563f37..32b1a28e 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -5,6 +5,7 @@
 from typing import List, Optional, Dict, Tuple
 
 import numpy as np
+from flatland.utils import seeding
 
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
@@ -202,38 +203,8 @@ def __init__(self,
 
         self.motionCheck = ac.MotionCheck()
 
-    """Set of random number generator functions: seeding, generator, hashing seeds."""
-    from typing import Any, Optional, Tuple
-
-    import numpy as np
-
-    from gym import error
-
-    @staticmethod
-    def _np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]:
-        """Generates a random number generator from the seed and returns the Generator and seed.
-
-        Args:
-            seed: The seed used to create the generator
-
-        Returns:
-            The generator and resulting seed
-
-        Raises:
-            Error: Seed must be a non-negative integer or omitted
-        """
-        if seed is not None and not (isinstance(seed, int) and 0 <= seed):
-            raise print(f"Seed must be a non-negative integer or omitted, not {seed}")
-
-        seed_seq = np.random.SeedSequence(seed)
-        np_seed = seed_seq.entropy
-        rng = np.random.Generator(np.random.PCG64(seed_seq))
-        return rng, np_seed
-
-    RNG = RandomNumberGenerator = np.random.Generator
-
     def _seed(self, seed):
-        self.np_random, seed = RailEnv._np_random(seed)
+        self.np_random, seed = seeding.np_random(seed)
         random.seed(seed)
         self.random_seed = seed
 
diff --git a/flatland/utils/seeding.py b/flatland/utils/seeding.py
new file mode 100644
index 00000000..af517515
--- /dev/null
+++ b/flatland/utils/seeding.py
@@ -0,0 +1,91 @@
+import hashlib
+import os
+import struct
+
+import numpy as np
+
+def np_random(seed=None):
+    if seed is not None and not (isinstance(seed, int) and 0 <= seed):
+        raise Exception('Seed must be a non-negative integer or omitted, not {}'.format(seed))
+
+    seed = create_seed(seed)
+
+    rng = np.random.RandomState()
+    rng.seed(_int_list_from_bigint(hash_seed(seed)))
+    return rng, seed
+
+
+def hash_seed(seed=None, max_bytes=8):
+    """Any given evaluation is likely to have many PRNG's active at
+    once. (Most commonly, because the environment is running in
+    multiple processes.) There's literature indicating that having
+    linear correlations between seeds of multiple PRNG's can correlate
+    the outputs:
+
+    http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/
+    http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
+    http://dl.acm.org/citation.cfm?id=1276928
+
+    Thus, for sanity we hash the seeds before using them. (This scheme
+    is likely not crypto-strength, but it should be good enough to get
+    rid of simple correlations.)
+
+    Args:
+        seed (Optional[int]): None seeds from an operating system specific randomness source.
+        max_bytes: Maximum number of bytes to use in the hashed seed.
+    """
+    if seed is None:
+        seed = create_seed(max_bytes=max_bytes)
+    hash = hashlib.sha512(str(seed).encode('utf8')).digest()
+    return _bigint_from_bytes(hash[:max_bytes])
+
+
+def create_seed(a=None, max_bytes=8):
+    """Create a strong random seed. Otherwise, Python 2 would seed using
+    the system time, which might be non-robust especially in the
+    presence of concurrency.
+
+    Args:
+        a (Optional[int, str]): None seeds from an operating system specific randomness source.
+        max_bytes: Maximum number of bytes to use in the seed.
+    """
+    # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py
+    if a is None:
+        a = _bigint_from_bytes(os.urandom(max_bytes))
+    elif isinstance(a, str):
+        a = a.encode('utf8')
+        a += hashlib.sha512(a).digest()
+        a = _bigint_from_bytes(a[:max_bytes])
+    elif isinstance(a, int):
+        a = a % 2 ** (8 * max_bytes)
+    else:
+        raise Exception('Invalid type for seed: {} ({})'.format(type(a), a))
+
+    return a
+
+
+# TODO: don't hardcode sizeof_int here
+def _bigint_from_bytes(bytes):
+    sizeof_int = 4
+    padding = sizeof_int - len(bytes) % sizeof_int
+    bytes += b'\0' * padding
+    int_count = int(len(bytes) / sizeof_int)
+    unpacked = struct.unpack("{}I".format(int_count), bytes)
+    accum = 0
+    for i, val in enumerate(unpacked):
+        accum += 2 ** (sizeof_int * 8 * i) * val
+    return accum
+
+
+def _int_list_from_bigint(bigint):
+    # Special case 0
+    if bigint < 0:
+        raise Exception('Seed must be non-negative, not {}'.format(bigint))
+    elif bigint == 0:
+        return [0]
+
+    ints = []
+    while bigint > 0:
+        bigint, mod = divmod(bigint, 2 ** 32)
+        ints.append(mod)
+    return ints
diff --git a/pyproject.toml b/pyproject.toml
index 922620c6..1ad66b7b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,6 @@ dependencies = [
     "crowdai_api",
     "dataclasses",
     "graphviz",
-    "gym==0.14.0",
     "importlib_resources<2.0.0",
     "ipycanvas",
     "ipyevents",