-
Notifications
You must be signed in to change notification settings - Fork 0
/
predicates.py
76 lines (56 loc) · 1.7 KB
/
predicates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import dataclasses
from dataclasses import dataclass
from enum import IntEnum
import numpy as np
class AtomColor(IntEnum):
    """Colour of an object in a predicate.

    Each value is a 3-bit mask read as (red, green, blue) from the most to the
    least significant bit, which is what ``to_rgba`` decodes.
    """

    NO_COLOR = 0
    BLUE = 1
    GREEN = 2
    CYAN = 3
    RED = 4
    PURPLE = 5
    YELLOW = 6
    WHITE = 7

    def nice_name(self):
        """Return the lowercase member name, or "" for NO_COLOR."""
        # NO_COLOR has value 0, so it is the only falsy member.
        if not self:
            return ""
        return self.name.lower()

    def to_rgba(self):
        """Return an (r, g, b, a) tuple of floats.

        Each colour channel is 1.0 or 0.0 depending on one bit of the value
        (bit 2 -> r, bit 1 -> g, bit 0 -> b); alpha is always 1.0.
        """
        code = int(self)
        channels = [float((code >> bit) & 1) for bit in (2, 1, 0)]
        channels.append(1.0)
        return tuple(channels)
class AtomObject(IntEnum):
    """Object kinds that can fill the ``obj`` / ``subj`` slots of a predicate."""

    ACTOR = 0
    TABLE = 1
    CUBE = 2
    SPHERE = 3
    PYRAMID = 4

    def nice_name(self):
        """Return the member name in lowercase, e.g. "cube"."""
        label = self.name
        return label.lower()
class AtomRelation(IntEnum):
    """Relation kinds that can fill the ``relation`` slot of a predicate."""

    ON = 0
    NEAR = 1
    ON_LEFT_SIDE_OF = 2
    ON_RIGHT_SIDE_OF = 3
    ON_NEAR_SIDE_OF = 4
    ON_FAR_SIDE_OF = 5
    IN_CENTER_OF = 6

    def nice_name(self):
        """Return the member name in lowercase, e.g. "on_left_side_of"."""
        label = self.name
        return label.lower()
def to_one_hot(e: IntEnum):
    """Return a float one-hot vector for enum member *e*.

    The vector has one slot per entry in ``type(e).__members__``; the slot at
    index ``e`` (the member's integer value) is 1.0 and the rest are 0.0.
    """
    width = len(e.__class__.__members__)
    # Allocate just the row we need instead of an identity matrix.
    encoded = np.zeros(width)
    encoded[e] = 1.0
    return encoded
def from_one_hot(arr: np.ndarray) -> int:
    """Return the hot index of a one-hot vector.

    For a 1-D *arr* this is the position of its first non-zero entry, returned
    as a plain Python ``int`` so it can be passed straight to an enum
    constructor or serialized.

    Raises:
        IndexError: if *arr* has no non-zero entry.
    """
    hot_positions = np.nonzero(arr)[0]
    if hot_positions.size == 0:
        # Keep IndexError (what the bare ``[0][0]`` raised) but say why.
        raise IndexError("from_one_hot: array has no non-zero entry")
    return int(hot_positions[0])
@dataclass(frozen=True)
class AtomPredicate:
    """An immutable relation between two coloured objects, e.g. "(red cube on table)".

    Fields are declared in the order used by the one-hot encoding:
    ``to_one_hot`` and ``from_one_hot`` both walk the dataclass fields in
    declaration order.
    """

    relation: AtomRelation
    obj: AtomObject
    obj_color: AtomColor
    subj: AtomObject
    subj_color: AtomColor

    def __repr__(self):
        """Return a readable form such as "(red cube on table)"."""
        words = (
            self.obj_color.nice_name(),
            self.obj.nice_name(),
            self.relation.nice_name(),
            self.subj_color.nice_name(),
            self.subj.nice_name(),
        )
        # AtomColor.NO_COLOR.nice_name() is "", so skip empty words rather than
        # emitting a stray space after "(" or a double space before the subject.
        return "(" + " ".join(word for word in words if word) + ")"

    def to_one_hot(self):
        """Encode the predicate as the concatenation of per-field one-hot vectors.

        Returns:
            A 1-D float array with one one-hot segment per field, in field
            declaration order.
        """
        # Inside a method body the bare name resolves to the module-level
        # helper, not to this method.
        return np.concatenate([to_one_hot(e) for e in dataclasses.astuple(self)])

    @staticmethod
    def from_one_hot(arr: np.ndarray):
        """Decode a vector laid out like ``to_one_hot``'s output into an AtomPredicate.

        Args:
            arr: 1-D array made of one one-hot segment per field, in field
                declaration order; each segment is as long as that field's enum.

        Returns:
            The decoded AtomPredicate.

        Raises:
            IndexError: if a segment has no non-zero entry (raised by the
                module-level ``from_one_hot`` helper).
        """
        # f.type is the enum class itself because the annotations are evaluated
        # at class creation (no ``from __future__ import annotations`` here).
        enum_types = [f.type for f in dataclasses.fields(AtomPredicate)]
        # Cumulative segment ends. np.split also yields a trailing piece after
        # the last end (empty for a well-formed vector); zip() discards it, so
        # any extra trailing entries are ignored.
        boundaries = np.cumsum([len(E.__members__) for E in enum_types])
        segments = np.split(arr, boundaries)
        return AtomPredicate(
            *(E(from_one_hot(segment)) for segment, E in zip(segments, enum_types))
        )