Skip to content

Commit

Permalink
Convolutional neural network training.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Jul 2, 2021
1 parent f35447d commit 99c0549
Show file tree
Hide file tree
Showing 208 changed files with 3,612 additions and 1,862 deletions.
2 changes: 1 addition & 1 deletion BMR/network/Node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void Node::Broadcast2(SendBuffer& msg) {
}

void Node::_identify() {
char* msg = id_msg;
char msg[strlen(ID_HDR)+sizeof(_id)];
memcpy(msg, ID_HDR, strlen(ID_HDR));
memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id));
//printf("Node:: identifying myself:\n");
Expand Down
2 changes: 0 additions & 2 deletions BMR/network/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ class Node : public ServerUpdatable, public ClientUpdatable {
std::map<struct sockaddr_in*,int> _clientsmap;
bool* _clients_connected;
NodeUpdatable* _updatable;

char id_msg[strlen(ID_HDR)+sizeof(_id)];
};

#endif /* NETWORK_NODE_H_ */
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.5 (Jul 2, 2021)

- Training of convolutional neural networks
- Bit decomposition using edaBits
- Ability to force MAC checks from high-level code
- Ability to close client connection from high-level code
- Binary operators for comparison results
- Faster compilation for emulation
- More documentation
- Fixed security bug: insufficient LowGear secret key randomness
- Fixed security bug: skewed random bit generation

## 0.2.4 (Apr 19, 2021)

- ARM support
Expand Down
4 changes: 2 additions & 2 deletions Compiler/GC/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,15 @@ class xorm(NonVectorInstruction):
code = opcodes['XORM']
arg_format = ['int','sbw','sb','cb']

class xorcb(NonVectorInstruction):
class xorcb(BinaryVectorInstruction):
""" Bitwise XOR of two single clear bit registers.
:param: result (cbit)
:param: operand (cbit)
:param: operand (cbit)
"""
code = opcodes['XORCB']
arg_format = ['cbw','cb','cb']
arg_format = ['int','cbw','cb','cb']

class xorcbi(NonVectorInstruction):
""" Bitwise XOR of single clear bit register and immediate.
Expand Down
81 changes: 57 additions & 24 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_type(cls, length):
class bitsn(cls):
n = length
cls.types[length] = bitsn
bitsn.clear_type = cbits.get_type(length)
bitsn.__name__ = cls.__name__ + str(length)
return cls.types[length]
@classmethod
Expand Down Expand Up @@ -115,7 +116,11 @@ def load_mem(cls, address, mem_type=None, size=None):
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
@classmethod
def new(cls, value=None, n=None):
return cls.get_type(n)(value)
def __init__(self, value=None, n=None, size=None):
assert n == self.n or n is None
if size != 1 and size is not None:
raise Exception('invalid size for bit type: %s' % size)
self.n = n or self.n
Expand All @@ -125,7 +130,7 @@ def __init__(self, value=None, n=None, size=None):
if value is not None:
self.load_other(value)
def copy(self):
return type(self)(n=instructions_base.get_global_vector_size())
return type(self).new(n=instructions_base.get_global_vector_size())
def set_length(self, n):
if n > self.n:
raise Exception('too long: %d/%d' % (n, self.n))
Expand Down Expand Up @@ -154,6 +159,8 @@ def load_other(self, other):
bits = other.bit_decompose()
bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits))
other = self.bit_compose(bits)
assert(isinstance(other, type(self)))
assert(other.n == self.n)
self.load_other(other)
except:
raise CompilerError('cannot convert %s/%s from %s to %s' % \
Expand All @@ -176,6 +183,16 @@ def _new_by_number(self, i, size=1):
res.i = i
res.program = self.program
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)

class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
Expand All @@ -202,14 +219,16 @@ def store_in_dynamic_mem(self, address):
inst.stmsdci(self, cbits.conv(address))
def clear_op(self, other, c_inst, ci_inst, op):
if isinstance(other, cbits):
res = cbits(n=max(self.n, other.n))
res = cbits.get_type(max(self.n, other.n))()
c_inst(res, self, other)
return res
elif isinstance(other, sbits):
return NotImplemented
else:
if util.is_constant(other):
if other >= 2**31 or other < -2**31:
return op(self, cbits(other))
res = cbits(n=max(self.n, len(bin(other)) - 2))
res = cbits.get_type(max(self.n, len(bin(other)) - 2))()
ci_inst(res, self, other)
return res
else:
Expand All @@ -221,26 +240,33 @@ def clear_op(self, other, c_inst, ci_inst, op):
def __xor__(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
elif isinstance(other, cbits):
res = cbits.get_type(max(self.n, other.n))()
assert res.size == self.size
assert res.size == other.size
inst.xorcb(res.n, res, self, other)
return res
else:
self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor)
return self.clear_op(other, None, inst.xorcbi, operator.xor)
__radd__ = __add__
__rxor__ = __xor__
def __mul__(self, other):
if isinstance(other, cbits):
return NotImplemented
else:
try:
res = cbits(n=min(self.max_length, self.n+util.int_len(other)))
res = cbits.get_type(min(self.max_length,
self.n+util.int_len(other)))()
inst.mulcbi(res, self, other)
return res
except TypeError:
return NotImplemented
def __rshift__(self, other):
res = cbits(n=self.n-other)
res = cbits.new(n=self.n-other)
inst.shrcbi(res, self, other)
return res
def __lshift__(self, other):
res = cbits(n=self.n+other)
res = cbits.get_type(self.n+other)()
inst.shlcbi(res, self, other)
return res
def print_reg(self, desc=''):
Expand Down Expand Up @@ -504,16 +530,6 @@ def trans(cls, rows):
res = [cls.new(n=len(rows)) for i in range(n_columns)]
inst.trans(len(res), *(res + rows))
return res
def if_else(self, x, y):
"""
Vectorized oblivious selection::
sb32 = sbits.get_type(32)
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
This will output 1.
"""
return result_conv(x, y)(self & (x ^ y) ^ y)
@staticmethod
def bit_adder(*args, **kwargs):
return sbitint.bit_adder(*args, **kwargs)
Expand Down Expand Up @@ -610,7 +626,7 @@ def __init__(self, other=None, size=None):
elif isinstance(other, (list, tuple)):
self.v = self.bit_extend(sbitvec(other).v, n)
else:
self.v = sbits(other, n=n).bit_decompose(n)
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
@classmethod
def load_mem(cls, address):
Expand All @@ -630,6 +646,8 @@ def store_in_mem(self, address):
for i in range(n):
v[i].store_in_mem(address + i)
def reveal(self):
if len(self) > cbits.unit:
return self.elements()[0].reveal()
revealed = [cbit() for i in range(len(self))]
for i in range(len(self)):
try:
Expand Down Expand Up @@ -784,15 +802,23 @@ class bit(object):

def result_conv(x, y):
try:
def f(res):
try:
return t.conv(res)
except:
return res
if util.is_constant(x):
if util.is_constant(y):
return lambda x: x
else:
return type(y).conv
t = type(y)
return f
if util.is_constant(y):
return type(x).conv
t = type(x)
return f
if type(x) is type(y):
return type(x).conv
t = type(x)
return f
except AttributeError:
pass
return lambda x: x
Expand All @@ -807,13 +833,19 @@ def if_else(self, x, y):
This will output 5.
"""
return result_conv(x, y)(self * (x ^ y) ^ y)
assert self.n == 1
diff = x ^ y
if isinstance(diff, cbits):
return result_conv(x, y)(self & (diff) ^ y)
else:
return result_conv(x, y)(self * (diff) ^ y)

class cbit(bit, cbits):
pass

sbits.bit_type = sbit
cbits.bit_type = cbit
sbit.clear_type = cbit

class bitsBlock(oram.Block):
value_type = sbits
Expand Down Expand Up @@ -881,7 +913,7 @@ def round(self, k, m, kappa=None, nearest=None, signed=None):
return self.get_type(k - m).compose(res_bits)
def int_div(self, other, bit_length=None):
k = bit_length or max(self.n, other.n)
return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k)
return (library.IntDiv(self.cast(k), other.cast(k), k) >> k).cast(k)
def Norm(self, k, f, kappa=None, simplex_flag=False):
absolute_val = abs(self)
#next 2 lines actually compute the SufOR for little indian encoding
Expand Down Expand Up @@ -1100,7 +1132,8 @@ def output(self):
bits = self.v.bit_decompose(self.k)
sign = bits[-1]
v += (sign << (self.k)) * -1
inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0))
inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0),
cbits(0), cbits(0))

class sbitfix(_fix):
""" Secret signed integer in one binary register.
Expand Down
4 changes: 2 additions & 2 deletions Compiler/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def dealloc_reg(self, reg, inst, free):
for x in itertools.chain(dup.duplicates, base.duplicates):
to_check.add(x)

free[reg.reg_type, base.size].add(self.alloc[base])
free[reg.reg_type, base.size].append(self.alloc[base])
if inst.is_vec() and base.vector:
self.defined[base] = inst
for i in base.vector:
Expand Down Expand Up @@ -604,4 +604,4 @@ def run(self, instructions):
elif op == 1:
instructions[i] = None
inst.args[0].link(inst.args[1])
instructions[:] = filter(lambda x: x is not None, instructions)
instructions[:] = list(filter(lambda x: x is not None, instructions))
2 changes: 1 addition & 1 deletion Compiler/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def sha3_256(x):
from circuit import sha3_256
a = sbitvec.from_vec([])
b = sbitvec(sint(0xcc), 8)
b = sbitvec(sint(0xcc), 8, 8)
for x in a, b:
sha3_256(x).elements()[0].reveal().print_reg()
Expand Down
1 change: 1 addition & 0 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def require_ring_size(k, op):
if int(program.options.ring) < k:
raise CompilerError('ring size too small for %s, compile '
'with \'-R %d\' or more' % (op, k))
program.curr_tape.require_bit_length(k)

@instructions_base.cisc
def LTZ(s, a, k, kappa):
Expand Down
Loading

0 comments on commit 99c0549

Please sign in to comment.