Skip to content

Commit

Permalink
Decision tree training.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Nov 9, 2022
1 parent 9033656 commit cd25c2e
Show file tree
Hide file tree
Showing 187 changed files with 2,356 additions and 328 deletions.
3 changes: 3 additions & 0 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ class Phase
template <class T>
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
template <class T>
static void andrsvec(T& processor, const vector<int>& args)
{ processor.andrsvec(args); }
template <class T>
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
template <class T>
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
Expand Down
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.3.4 (Nov 9, 2022)

- Decision tree learning
- Optimized oblivious shuffle in Rep3
- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC
- Optimized element-vector AND in SemiBin
- Optimized input protocol in Shamir-based protocols
- Square-root ORAM (@Quitlox)
- Improved ORAM in binary circuits
- UTF-8 outputs

## 0.3.3 (Aug 25, 2022)

- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate
Expand Down
4 changes: 3 additions & 1 deletion CONFIG
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ endif
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5

LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto

CFLAGS += -I./local/include

ifeq ($(USE_NTL),1)
CFLAGS += -DUSE_NTL
LDLIBS := -lntl $(LDLIBS)
Expand Down Expand Up @@ -100,5 +103,4 @@ ifeq ($(USE_KOS),1)
CFLAGS += -DUSE_KOS
else
CFLAGS += -std=c++17
LDLIBS += -llibOTe -lcryptoTools
endif
49 changes: 49 additions & 0 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import Compiler.tools as tools
import collections
import itertools
import math

class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
Expand Down Expand Up @@ -50,6 +51,7 @@ class ClearBitsAF(base.RegisterArgFormat):
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
ANDRSVEC = 0x24a,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
Expand Down Expand Up @@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction):

def add_usage(self, req_node):
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
req_node.increment(('bit', 'mixed'),
sum(int(math.ceil(x / 64)) for x in self.args[::4]))

class andrsvec(base.VarArgsInstruction, base.Mergeable,
base.DynFormatInstruction):
""" Constant-vector AND of secret bit registers (vectorized version).
:param: total number of arguments to follow (int)
:param: number of arguments to follow for one operation /
operation vector size plus three (int)
:param: vector size (int)
:param: result vector (sbit)
:param: (repeat)...
:param: constant operand (sbits)
:param: vector operand
:param: (repeat)...
:param: (repeat from number of arguments to follow for one operation)...
"""
code = opcodes['ANDRSVEC']

def __init__(self, *args, **kwargs):
super(andrsvec, self).__init__(*args, **kwargs)
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
for x in self.args[i + 2:i + n]:
assert x.n == size

@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
n_args = (n - 3) // 2
assert n_args > 0
for j in range(n_args):
yield 'sbw'
for j in range(n_args + 1):
yield 'sb'
yield 'int'

def add_usage(self, req_node):
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)

class ands(BinaryVectorInstruction):
""" Bitwise AND of secret bit register vector.
Expand Down Expand Up @@ -605,6 +653,7 @@ def dynamic_arg_format(cls, args):
for i, n in cls.bases(args):
yield 'int'
yield 'p'
assert n > 3
for j in range(n - 3):
yield 'sbw'
yield 'int'
Expand Down
59 changes: 49 additions & 10 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ class sbitvec(_vec, _bit):
You can access the rows by member :py:obj:`v` and the columns by calling
:py:obj:`elements`.
There are three ways to create an instance:
There are four ways to create an instance:
1. By transposition::
Expand Down Expand Up @@ -685,6 +685,11 @@ class sbitvec(_vec, _bit):
This should output::
[1, 0, 1]
4. Private input::
x = sbitvec.get_type(32).get_input_from(player)
"""
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
is_clear = False
Expand Down Expand Up @@ -904,6 +909,34 @@ def half_adder(self, other):
def __mul__(self, other):
if isinstance(other, int):
return self.from_vec(x * other for x in self.v)
if isinstance(other, sbitvec):
if len(other.v) == 1:
other = other.v[0]
elif len(self.v) == 1:
self, other = other, self.v[0]
else:
raise CompilerError('no operand of lenght 1: %d/%d',
(len(self.v), len(other.v)))
if not isinstance(other, sbits):
return NotImplemented
ops = []
for x in self.v:
if not util.is_zero(x):
assert x.n == other.n
ops.append(x)
if ops:
prods = [sbits.get_type(other.n)() for i in ops]
inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops)
res = []
i = 0
for x in self.v:
if util.is_zero(x):
res.append(0)
else:
res.append(prods[i])
i += 1
return sbitvec.from_vec(res)
__rmul__ = __mul__
def __add__(self, other):
return self.from_vec(x + y for x, y in zip(self.v, other))
def bit_and(self, other):
Expand Down Expand Up @@ -945,6 +978,13 @@ def expand(self, other, expand=True):
else:
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
return res
def demux(self):
if len(self) == 1:
return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]])
a = sbitvec.from_vec(self.v[:len(self) // 2]).demux()
b = sbitvec.from_vec(self.v[len(self) // 2:]).demux()
prod = [a * bb for bb in b.v]
return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod)))

class bit(object):
n = 1
Expand Down Expand Up @@ -1243,20 +1283,19 @@ def __mul__(self, other):
return other * self.v[0]
elif isinstance(other, sbitfixvec):
return NotImplemented
_, other_bits = self.expand(other, False)
my_bits, other_bits = self.expand(other, False)
matrix = []
m = float('inf')
for x in itertools.chain(self.v, other_bits):
for x in itertools.chain(my_bits, other_bits):
try:
m = min(m, x.n)
except:
pass
if m == 1:
op = operator.mul
else:
op = operator.and_
matrix = []
for i, b in enumerate(other_bits):
matrix.append([op(x, b) for x in self.v[:len(self.v)-i]])
if m == 1:
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
else:
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])
__rmul__ = __mul__
Expand Down Expand Up @@ -1366,7 +1405,7 @@ class cls(_fix):
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)

class sbitfixvec(_fix):
class sbitfixvec(_fix, _vec):
""" Vector of fixed-point numbers for parallel binary computation.
Use :py:obj:`set_precision()` to change the precision.
Expand Down
6 changes: 5 additions & 1 deletion Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def longest_paths_merge(self):
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
self.req_num = defaultdict(lambda: 0)
if not merge_nodes:
return 0

Expand All @@ -281,6 +282,7 @@ def longest_paths_merge(self):
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
self.req_num[t.__name__, 'round'] += 1

preorder = None

Expand Down Expand Up @@ -530,7 +532,9 @@ def eliminate_dead_code(self):
can_eliminate_defs = True
for reg in inst.get_def():
for dup in reg.duplicates:
if not dup.can_eliminate:
if not (dup.can_eliminate and reduce(
operator.and_,
(x.can_eliminate for x in dup.vector), True)):
can_eliminate_defs = False
break
# remove if instruction has result that isn't used
Expand Down
2 changes: 0 additions & 2 deletions Compiler/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ def sha3_256(x):
0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
implemented for computation modulo a power of two.
"""

global Keccak_f
Expand Down
3 changes: 2 additions & 1 deletion Compiler/circuit_oram.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from Compiler.path_oram import *
from Compiler.oram import *
from Compiler.path_oram import PathORAM, XOR
from Compiler.util import bit_compose

def first_diff(a_bits, b_bits):
Expand Down
7 changes: 7 additions & 0 deletions Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def build_option_parser(self):
default=defaults.binary,
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
)
parser.add_option(
"-G",
"--garbled-circuit",
dest="garbled",
action="store_true",
help="compile for binary circuits only (default: false)",
)
parser.add_option(
"-F",
"--field",
Expand Down
Loading

0 comments on commit cd25c2e

Please sign in to comment.