Distributed key generation for homomorphic encryption with active security.
mkskeller committed Feb 23, 2021
1 parent 08a80cf commit c9b03d8
Showing 136 changed files with 2,117 additions and 492 deletions.
10 changes: 9 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,14 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

-## 0.2.2 (Jan 21, 2020)
+## 0.2.3 (Feb 23, 2021)
+
+- Distributed key generation for homomorphic encryption with active security similar to [Rotaru et al.](https://eprint.iacr.org/2019/1300)
+- Homomorphic encryption parameters more similar to SCALE-MAMBA
+- Fixed security bug: all-zero secret keys in homomorphic encryption
+- Fixed security bug: missing check in binary Rep4
+- Fixed security bug: insufficient "blaming" (covert security) in CowGear and HighGear due to low default security parameter
+
+## 0.2.2 (Jan 21, 2021)

- Infrastructure for random element generation
- Programs generating as much preprocessing data as required by a particular high-level program
20 changes: 18 additions & 2 deletions Compiler/GC/instructions.py
@@ -490,6 +490,9 @@ class bitb(NonVectorInstruction):
code = opcodes['BITB']
arg_format = ['sbw']

def add_usage(self, req_node):
req_node.increment(('bit', 'bit'), 1)

class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable):
""" Reveal secret bit register vectors and copy result to clear bit
register vectors.
@@ -519,6 +522,10 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
arg_format = tools.cycle(['p','int','int','sbw'])
is_vec = lambda self: True

    def add_usage(self, req_node):
        # one (player, number of bits) entry per group of four arguments
        for i in range(0, len(self.args), 4):
            req_node.increment(('bit', 'input', self.args[i]), self.args[i + 1])

class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
base.Mergeable):
""" Copy private input to secret bit registers bit by bit. The input is
@@ -538,17 +545,26 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,

    def __init__(self, *args, **kwargs):
        self.arg_format = []
+        for x in self.get_arg_tuples(args):
+            self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3)
+        super(inputbvec, self).__init__(*args, **kwargs)
+
+    @staticmethod
+    def get_arg_tuples(args):
        i = 0
        while i < len(args):
-            self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (args[i] - 3)
+            yield args[i:i+args[i]]
            i += args[i]
        assert i == len(args)
-        super(inputbvec, self).__init__(*args, **kwargs)

def merge(self, other):
self.args += other.args
self.arg_format += other.arg_format

def add_usage(self, req_node):
for x in self.get_arg_tuples(self.args):
req_node.increment(('bit', 'input', x[2]), x[0] - 3)

class print_regb(base.VectorInstruction, base.IOInstruction):
""" Debug output of clear bit register.
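The new `get_arg_tuples` generator centralises the parsing of `inputbvec`'s flat argument list, where each group is prefixed by its own length, so `__init__` and the new `add_usage` share one traversal. A standalone sketch in plain Python (no compiler imports; the concrete values and register names are made up, and the meaning of the middle header field is not shown in this hunk):

```python
def get_arg_tuples(args):
    # each group starts with its own total length, so the flat list can
    # be walked without knowing the per-group register count in advance
    i = 0
    while i < len(args):
        yield args[i:i + args[i]]
        i += args[i]
    assert i == len(args)  # the last group must end exactly at the list end

# two made-up groups: 3 header fields (length, a further parameter, player)
# followed by one destination register per bit
args = [5, 0, 0, 'r0', 'r1',  # 2 bits from player 0
        4, 0, 1, 'r2']        # 1 bit from player 1
for x in get_arg_tuples(args):
    print('player', x[2], 'bits', x[0] - 3)  # as counted in add_usage
```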
23 changes: 11 additions & 12 deletions Compiler/allocator.py
@@ -371,6 +371,14 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
# hack
warned_about_mem.append(True)

def strict_mem_access(n, last_this_kind, last_other_kind):
if last_other_kind and last_this_kind and \
last_other_kind[-1] > last_this_kind[-1]:
last_this_kind[:] = []
last_this_kind.append(n)
for i in last_other_kind:
add_edge(i, n)

def keep_order(instr, n, t, arg_index=None):
if arg_index is None:
player = None
@@ -444,26 +452,17 @@ def keep_merged_order(instr, n, t):

        if isinstance(instr, ReadMemoryInstruction):
            if options.preserve_mem_order:
-                if last_mem_write and last_mem_read and last_mem_write[-1] > last_mem_read[-1]:
-                    last_mem_read[:] = []
-                last_mem_read.append(n)
-                for i in last_mem_write:
-                    add_edge(i, n)
+                strict_mem_access(n, last_mem_read, last_mem_write)
            else:
                mem_access(n, instr, last_mem_read_of, last_mem_write_of)
        elif isinstance(instr, WriteMemoryInstruction):
            if options.preserve_mem_order:
-                if last_mem_write and last_mem_read and last_mem_write[-1] < last_mem_read[-1]:
-                    last_mem_write[:] = []
-                last_mem_write.append(n)
-                for i in last_mem_read:
-                    add_edge(i, n)
+                strict_mem_access(n, last_mem_write, last_mem_read)
            else:
                mem_access(n, instr, last_mem_write_of, last_mem_read_of)
        elif isinstance(instr, matmulsm):
            if options.preserve_mem_order:
-                for i in last_mem_write:
-                    add_edge(i, n)
+                strict_mem_access(n, last_mem_read, last_mem_write)
            else:
                for i in last_mem_write_of.values():
                    for j in i:
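`strict_mem_access` folds the two near-identical read/write branches into one helper. The logic: a new access must be ordered after every recorded access of the other kind, and once the other kind has overtaken this kind, the older same-kind entries are already ordered transitively and can be dropped. A self-contained model with integers as instruction numbers and a plain list standing in for the allocator's dependency graph:

```python
edges = []

def add_edge(i, n):
    edges.append((i, n))

def strict_mem_access(n, last_this_kind, last_other_kind):
    # drop same-kind history once the other kind has accessed more
    # recently: those nodes are already ordered via the other kind
    if last_other_kind and last_this_kind and \
            last_other_kind[-1] > last_this_kind[-1]:
        last_this_kind[:] = []
    last_this_kind.append(n)
    for i in last_other_kind:
        add_edge(i, n)

reads, writes = [], []
strict_mem_access(0, writes, reads)  # write at instruction 0
strict_mem_access(1, reads, writes)  # read at 1 -> edge (0, 1)
strict_mem_access(2, reads, writes)  # read at 2 -> edge (0, 2)
strict_mem_access(3, writes, reads)  # write at 3 -> edges (1, 3), (2, 3)
print(edges)  # [(0, 1), (0, 2), (1, 3), (2, 3)]
```

The `matmulsm` case reuses the same helper with the read-side lists, since a matrix multiplication reads memory.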
5 changes: 5 additions & 0 deletions Compiler/comparison.py
@@ -452,6 +452,11 @@ def CarryOutRaw(a, b, c=0):
assert len(a) == len(b)
k = len(a)
from . import types
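    # linear option: ripple-carry chain with one secret AND per bit
    # position, O(k) rounds instead of the O(log k) tree below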
if program.linear_rounds():
carry = 0
for (ai, bi) in zip(a, b):
carry = bi.carry_out(ai, carry)
return carry
d = [program.curr_block.new_reg('s') for i in range(k)]
s = [program.curr_block.new_reg('s') for i in range(3)]
for i in range(k):
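The two branches trade rounds against multiplications: the new chain evaluates one carry (a single secret AND inside `carry_out`, shown in `Compiler/types.py` further down) per position over k rounds, while the tree-based code above runs in O(log k) rounds at a higher AND count. A cleartext sketch of the linear chain on Python ints (bit lists LSB-first; a made-up example outside the compiler):

```python
def maj(x, y, z):
    # majority of three bits = full-adder carry-out; one AND, and XORs
    # are free in the secret-sharing setting
    return x ^ ((x ^ y) & (x ^ z))

def carry_out_linear(a, b, c=0):
    carry = c
    for ai, bi in zip(a, b):   # one round per bit position
        carry = maj(ai, bi, carry)
    return carry

# 0b111 + 0b001 overflows three bits, so the carry out is 1
print(carry_out_linear([1, 1, 1], [1, 0, 0]))  # 1
print(carry_out_linear([0, 1, 1], [1, 0, 0]))  # 0 (6 + 1 still fits)
```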
106 changes: 57 additions & 49 deletions Compiler/ml.py
@@ -627,15 +627,14 @@ class Dense(DenseBase):
:param d_out: output dimension
"""
    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
        self.activation = activation
        if activation == 'id':
-            self.f = lambda x: x
+            self.activation_layer = None
        elif activation == 'relu':
-            self.f = relu
-            self.f_prime = relu_prime
-        elif activation == 'sigmoid':
-            self.f = sigmoid
-            self.f_prime = sigmoid_prime
+            self.activation_layer = Relu([N, d, d_out])
+        elif activation == 'square':
+            self.activation_layer = Square([N, d, d_out])
+        else:
+            raise CompilerError('activation not supported: %s', activation)

self.N = N
self.d_in = d_in
@@ -652,10 +651,16 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
        self.nabla_W = sfix.Matrix(d_in, d_out)
        self.nabla_b = sfix.Array(d_out)

-        self.f_input = MultiArray([N, d, d_out], sfix)

        self.debug = debug

+        l = self.activation_layer
+        if l:
+            self.f_input = l.X
+            l.Y = self.Y
+            l.nabla_Y = self.nabla_Y
+        else:
+            self.f_input = self.Y

def reset(self):
d_in = self.d_in
d_out = self.d_out
@@ -699,10 +704,8 @@ def _(i):

    def forward(self, batch=None):
        self.compute_f_input(batch=batch)
-        @multithread(self.n_threads, len(batch), 128)
-        def _(base, size):
-            self.Y.assign_part_vector(self.f(
-                self.f_input.get_part_vector(base, size)), base)
+        if self.activation_layer:
+            self.activation_layer.forward(batch)
if self.debug:
limit = self.debug
@for_range_opt(len(batch))
Expand Down Expand Up @@ -731,26 +734,11 @@ def backward(self, compute_nabla_X=True, batch=None):
nabla_W = self.nabla_W
nabla_b = self.nabla_b

-        if self.activation == 'id':
-            f_schur_Y = nabla_Y
+        if self.activation_layer:
+            self.activation_layer.backward(batch)
+            f_schur_Y = self.activation_layer.nabla_X
        else:
-            f_prime_bit = MultiArray([N, d, d_out], sint)
-            f_schur_Y = MultiArray([N, d, d_out], sfix)
-
-            @multithread(self.n_threads, f_prime_bit.total_size())
-            def _(base, size):
-                f_prime_bit.assign_vector(
-                    self.f_prime(self.f_input.get_vector(base, size)), base)
-
-            progress('f prime')
-
-            @multithread(self.n_threads, f_prime_bit.total_size())
-            def _(base, size):
-                f_schur_Y.assign_vector(nabla_Y.get_vector(base, size) *
-                                        f_prime_bit.get_vector(base, size),
-                                        base)
-
-            progress('f prime schur Y')
+            f_schur_Y = nabla_Y

if compute_nabla_X:
@multithread(self.n_threads, N)
@@ -841,35 +829,55 @@ def _(k):
def backward(self):
self.nabla_X = self.nabla_Y.schur(self.B)

-class Relu(NoVariableLayer):
-    """ Fixed-point ReLU layer.
-    :param shape: input/output shape (tuple/list of int)
-    """
+class ElementWiseLayer(NoVariableLayer):
+    def __init__(self, shape, inputs=None):
+        self.X = Tensor(shape, sfix)
+        self.Y = Tensor(shape, sfix)
+        self.nabla_X = Tensor(shape, sfix)
+        self.nabla_Y = Tensor(shape, sfix)
+        self.inputs = inputs

    def forward(self, batch=[0]):
-        assert len(batch) == 1
-        @multithread(self.n_threads, self.X[batch[0]].total_size())
+        @multithread(self.n_threads, len(batch), 128)
        def _(base, size):
-            tmp = relu(self.X[batch[0]].get_vector(base, size))
-            self.Y[batch[0]].assign_vector(tmp, base)
+            self.Y.assign_part_vector(self.f(
+                self.X.get_part_vector(base, size)), base)

-class Square(NoVariableLayer):
-    """ Fixed-point square layer.
+    def backward(self, batch):
+        f_prime_bit = MultiArray(self.X.sizes, self.prime_type)
+
+        @multithread(self.n_threads, f_prime_bit.total_size())
+        def _(base, size):
+            f_prime_bit.assign_vector(
+                self.f_prime(self.X.get_vector(base, size)), base)
+
+        progress('f prime')
+
+        @multithread(self.n_threads, f_prime_bit.total_size())
+        def _(base, size):
+            self.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size) *
+                                       f_prime_bit.get_vector(base, size),
+                                       base)
+
+        progress('f prime schur Y')
+
+class Relu(ElementWiseLayer):
+    """ Fixed-point ReLU layer.
+    :param shape: input/output shape (tuple/list of int)
+    """
-    def __init__(self, shape):
-        self.X = MultiArray(shape, sfix)
-        self.Y = MultiArray(shape, sfix)
+    f = staticmethod(relu)
+    f_prime = staticmethod(relu_prime)
+    prime_type = sint

-    def forward(self, batch=[0]):
-        assert len(batch) == 1
-        self.Y.assign_vector(self.X.get_part_vector(batch[0]) ** 2)
+class Square(ElementWiseLayer):
+    """ Fixed-point square layer.
+    :param shape: input/output shape (tuple/list of int)
+    """
+    f = staticmethod(lambda x: x ** 2)
+    f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x)
+    prime_type = sfix

class MaxPool(NoVariableLayer):
""" Fixed-point MaxPool layer.
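The refactor replaces Dense's inline activation handling with composable layer objects: `ElementWiseLayer` owns the generic threaded forward/backward plumbing, and subclasses contribute only `f`, `f_prime` and the value type of the derivative; `Dense` then aliases its own `Y` and `nabla_Y` buffers into the activation layer, so no extra copies are made. The same class pattern in plain NumPy terms (a sketch for intuition only, not the MP-SPDZ tensor machinery):

```python
import numpy as np

class ElementWiseLayer:
    def forward(self, X):
        self.X = X
        return self.f(X)

    def backward(self, nabla_Y):
        # chain rule: dL/dX = dL/dY * f'(X), elementwise
        return nabla_Y * self.f_prime(self.X)

class Relu(ElementWiseLayer):
    f = staticmethod(lambda x: np.maximum(x, 0))
    f_prime = staticmethod(lambda x: (x > 0).astype(x.dtype))

class Square(ElementWiseLayer):
    f = staticmethod(lambda x: x ** 2)
    f_prime = staticmethod(lambda x: 2 * x)

layer = Relu()
X = np.array([-1.0, 2.0, -3.0, 4.0])
print(layer.forward(X))                 # [0. 2. 0. 4.]
print(layer.backward(np.ones_like(X)))  # [0. 1. 0. 1.]
```

Note the `prime_type` split in the diff: ReLU's derivative is a 0/1 value, so it is stored as `sint` bits, while Square's derivative 2x stays fixed-point (`sfix`).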
13 changes: 12 additions & 1 deletion Compiler/program.py
@@ -150,6 +150,7 @@ def __init__(self, args, options=defaults):
self.use_split(int(options.split))
self._square = False
self._always_raw = False
self._linear_rounds = False
self.warn_about_mem = [True]
Program.prog = self
from . import instructions_base, instructions, types, comparison
@@ -461,6 +462,12 @@ def always_raw(self, change=None):
else:
self._always_raw = change

def linear_rounds(self, change=None):
if change is None:
return self._linear_rounds
else:
self._linear_rounds = change

def options_from_args(self):
""" Set a number of options from the command-line arguments. """
if 'trunc_pr' in self.args:
@@ -475,6 +482,8 @@ def options_from_args(self):
self.always_raw(True)
if 'edabit' in self.args:
self.use_edabit(True)
if 'linear_rounds' in self.args:
self.linear_rounds(True)

def disable_memory_warnings(self):
self.warn_about_mem.append(False)
@@ -855,7 +864,9 @@ def write_bytes(self, filename=None):
filename = self.program.programs_dir + '/Bytecode/' + filename
print('Writing to', filename)
f = open(filename, 'wb')
-        f.write(self.get_bytes())
+        for i in self._get_instructions():
+            if i is not None:
+                f.write(i.get_bytes())
f.close()

def new_reg(self, reg_type, size=None):
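A note on the new option: `linear_rounds` follows the established query-or-set pattern of `use_edabit` and `always_raw`, and `options_from_args` picks it up from the program's compile-time arguments like the existing `trunc_pr` and `edabit` switches, so it is presumably enabled as `./compile.py <program> linear_rounds` (an assumption by analogy; this diff only shows the flag being read from `self.args`). `Compiler/comparison.py` above then queries `program.linear_rounds()` to pick the ripple-carry branch. Separately, `write_bytes` now streams instructions to the file one at a time, skipping `None` entries, instead of first assembling the whole bytecode in memory, which presumably keeps memory usage down for large programs.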
7 changes: 6 additions & 1 deletion Compiler/types.py
@@ -267,7 +267,7 @@ def __mul__(self, other):
    def __pow__(self, exp):
        """ Exponentiation through square-and-multiply.
-        :param exp: compile-time (int) """
+        :param exp: any type allowing bit decomposition """
if isinstance(exp, int) and exp >= 0:
if exp == 0:
return self.__class__(1)
@@ -425,6 +425,10 @@ def half_adder(self, other):
:rtype: depending on inputs (secret if any of them is) """
return self ^ other, self & other

def carry_out(self, a, b):
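        # full-adder carry-out as the majority function:
        # maj(a, b, self) = a ^ ((a ^ b) & (a ^ self)) -- one secret AND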
s = a ^ b
return a ^ (s & (self ^ a))

class _gf2n(_bit):
""" :math:`\mathrm{GF}(2^n)` functionality. """

@@ -5247,6 +5251,7 @@ def reveal(self):
bit_decompose = lambda self,*args,**kwargs: self.read().bit_decompose(*args, **kwargs)

if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs)
bit_and = lambda self,other: self.read().bit_and(other)

def expand_to_vector(self, size=None):
if program.curr_block == self.last_write_block:
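The widened `__pow__` contract rests on square-and-multiply surviving bit decomposition: with the exponent available as bits (even secret ones), each step multiplies by either the running square or 1, selected by the bit. A cleartext model on Python ints (LSB-first bits; the arithmetic select `1 + b * (s - 1)` plays the role of `if_else` on secret values):

```python
def pow_by_bits(base, exp_bits, modulus):
    # square-and-multiply over a bit-decomposed exponent; the control
    # flow is independent of the bit values, as needed for secret exponents
    result = 1
    square = base
    for b in exp_bits:
        multiplier = 1 + b * (square - 1)  # b ? square : 1
        result = result * multiplier % modulus
        square = square * square % modulus
    return result

# 3**13 mod 1000 with 13 = 0b1101 decomposed LSB-first
print(pow_by_bits(3, [1, 0, 1, 1], 1000))  # 323 (3**13 = 1594323)
```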
12 changes: 6 additions & 6 deletions FHE/DiscreteGauss.cpp
@@ -35,13 +35,14 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const
void RandomVectors::set(int nn,int hh,double R)
{
  n=nn;
-  if (hh > 0)
-    h=hh;
+  h=hh;
  DG.set(R);
-  assert(h > 0);
}


+void RandomVectors::set_n(int nn)
+{
+  n = nn;
+}

vector<bigint> RandomVectors::sample_Gauss(PRNG& G, int stretch) const
{
@@ -54,8 +55,7 @@ vector<bigint> RandomVectors::sample_Gauss(PRNG& G, int stretch) const

vector<bigint> RandomVectors::sample_Hwt(PRNG& G) const
{
-  assert(h > 0);
-  if (h>n/2) { return sample_Gauss(G); }
+  if (h > n/2 or h <= 0) { return sample_Gauss(G); }
vector<bigint> ans(n);
for (int i=0; i<n; i++) { ans[i]=0; }
int cnt=0,j=0;
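A non-positive Hamming weight now means "no fixed-weight secret": `sample_Hwt` falls back to Gaussian sampling instead of asserting, and `unpack` in the header below accordingly stops rejecting `h <= 0`. This matches the changelog entry about parameters closer to SCALE-MAMBA. A behavioural sketch of the dispatch in Python (`sample_gauss` stands in for the discrete Gaussian sampler, and the ±1 entries are an assumption, as usual for fixed-weight lattice secrets; the C++ loop is truncated above):

```python
import random

def sample_hwt(n, h, sample_gauss):
    # fall back to a Gaussian secret when the weight is invalid,
    # mirroring: if (h > n/2 or h <= 0) { return sample_Gauss(G); }
    if h > n // 2 or h <= 0:
        return sample_gauss(n)
    ans = [0] * n
    cnt = 0
    while cnt < h:
        i = random.randrange(n)
        if ans[i] == 0:
            ans[i] = random.choice((-1, 1))  # assumed +/-1 entries
            cnt += 1
    return ans

gauss = lambda n: [round(random.gauss(0, 3.2)) for _ in range(n)]
print(sample_hwt(16, 4, gauss))  # weight-4 ternary vector
print(sample_hwt(16, 0, gauss))  # h <= 0: Gaussian fallback
```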
4 changes: 3 additions & 1 deletion FHE/DiscreteGauss.h
@@ -48,10 +48,11 @@ class RandomVectors
public:

  void set(int nn,int hh,double R); // R is input STANDARD DEVIATION
+  void set_n(int nn);

  void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
  void unpack(octetStream& o)
-    { o.get(n); o.get(h); DG.unpack(o); if(h <= 0) throw exception(); }
+    { o.get(n); o.get(h); DG.unpack(o); }

RandomVectors(int h, double R) : RandomVectors(0, h, R) {}
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
@@ -61,6 +62,7 @@

double get_R() const { return DG.get_R(); }
DiscreteGauss get_DG() const { return DG; }
int get_h() const { return h; }

// Sample from Discrete Gauss distribution
vector<bigint> sample_Gauss(PRNG& G, int stretch = 1) const;