diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 84ba909b3..beddd64cf 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,6 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 8e16c3077..ae69cb7f5 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -28,7 +28,7 @@ RealProgramParty* RealProgramParty::singleton = 0; template RealProgramParty::RealProgramParty(int argc, const char** argv) : - garble_processor(garble_machine), dummy_proc({{}, 0}) + garble_processor(garble_machine), dummy_proc({}, 0) { assert(singleton == 0); singleton = this; @@ -157,6 +157,9 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : MC->Check(*P); data_sent = P->total_comm().sent; + if (online_opts.verbose) + P->total_comm().print(); + this->machine.write_memory(this->N.my_num()); } diff --git a/CHANGELOG.md b/CHANGELOG.md index 744d0ff1b..18cc92ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.3.2 (Mai 27, 2022) + +- Secure shuffling +- O(n log n) radix sorting +- Documented BGV encryption interface +- Optimized matrix multiplication in dealer protocol +- Fixed security bug in homomorphic encryption parameter generation +- Fixed Security bug in Temi matrix multiplication + ## 0.3.1 (Apr 19, 2022) - Protocol in dealer model diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index b34e68c82..fdd987225 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -382,7 +382,6 @@ class sbits(bits): reg_type = 'sb' is_clear = False clear_type = cbits - default_type = cbits load_inst = (inst.ldmsbi, inst.ldmsb) store_inst = (inst.stmsbi, inst.stmsb) bitdec = inst.bitdecs @@ -404,6 +403,9 @@ def new(value=None, n=None): else: return sbits.get_type(n)(value) @staticmethod + def _new(value): + return value + @staticmethod def get_random_bit(): res = sbit() inst.bitb(res) @@ -909,6 +911,7 @@ class cbit(bit, cbits): sbits.bit_type = sbit cbits.bit_type = cbit sbit.clear_type = cbit +sbits.default_type = sbits class bitsBlock(oram.Block): value_type = sbits diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 8a10ee58c..5f5b82dbc 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -17,6 +17,7 @@ import itertools import operator +import math from . import tools from random import randint from functools import reduce @@ -2406,6 +2407,70 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +@base.gf2n +class secshuffle(base.VectorInstruction, base.DataInstruction): + """ Secure shuffling. + + :param: destination (sint) + :param: source (sint) + """ + __slots__ = [] + code = base.opcodes['SECSHUFFLE'] + arg_format = ['sw','s','int'] + + def __init__(self, *args, **kwargs): + super(secshuffle_class, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class gensecshuffle(base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (regint) + :param: size (int) + + """ + __slots__ = [] + code = base.opcodes['GENSECSHUFFLE'] + arg_format = ['ciw','int'] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'input', 0), float('inf')) + +class applyshuffle(base.VectorInstruction, base.DataInstruction): + """ Generate secure shuffle to bit used several times. + + :param: destination (sint) + :param: source (sint) + :param: number of elements to be treated as one (int) + :param: handle (regint) + :param: reverse (0/1) + + """ + __slots__ = [] + code = base.opcodes['APPLYSHUFFLE'] + arg_format = ['sw','s','int','ci','int'] + + def __init__(self, *args, **kwargs): + super(applyshuffle, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + assert len(args[0]) > args[2] + + def add_usage(self, req_node): + req_node.increment((self.field_type, 'triple', 0), float('inf')) + +class delshuffle(base.Instruction): + """ Delete secure shuffle. + + :param: handle (regint) + + """ + code = base.opcodes['DELSHUFFLE'] + arg_format = ['ci'] + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 8ae0b86fc..d598d8a71 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -106,6 +106,11 @@ CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + # Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, # Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Compiler/oram.py b/Compiler/oram.py index 543fc4aab..d4b434385 100644 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -348,7 +348,7 @@ def __iter__(self): def __len__(self): return 2 + len(self.x) def __repr__(self): - return '{empty=%s}' % self.is_empty if self.is_empty \ + return '{empty=%s}' % self.is_empty if util.is_one(self.is_empty) \ else '{%s: %s}' % (self.v, self.x) def __add__(self, other): try: @@ -466,12 +466,14 @@ class AbstractORAM(object): def get_array(size, t, *args, **kwargs): return t.dynamic_array(size, t, *args, **kwargs) def read(self, index): - return self._read(self.value_type.hard_conv(index)) + res = self._read(self.index_type.hard_conv(index)) + res = [self.value_type._new(x) for x in res] + return res def write(self, index, value): + value = util.tuplify(value) + value = [self.value_type.conv(x) for x in value] new_value = [self.value_type.get_type(length).hard_conv(v) \ - for length,v in zip(self.entry_size, value \ - if isinstance(value, (tuple, list)) \ - else (value,))] + for length,v in zip(self.entry_size, value)] return self._write(self.index_type.hard_conv(index), *new_value) def access(self, index, new_value, write, new_empty=False): return self._access(self.index_type.hard_conv(index), @@ -795,7 +797,8 @@ def batch_init(self, values): for i,value in enumerate(values): index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.ram[i] = Entry(index, new_value, value_type=self.value_type) @@ -986,7 +989,8 @@ def batch_init(self, values): for i,value in enumerate(values): index = self.value_type.hard_conv(i) new_value = [self.value_type.hard_conv(v) \ - for v in (value if isinstance(value, (tuple, list)) \ + for v in (value if isinstance( + value, (tuple, list, Array)) \ else (value,))] self.__setitem__(index, new_value) def __repr__(self): @@ -1062,11 +1066,12 @@ def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ stop_timer(1) start_timer() self.root = RefBucket(1, self) - self.index = self.index_structure(size, self.D, value_type, init_rounds, True) + self.index = self.index_structure(size, self.D, self.index_type, + init_rounds, True) - self.read_value = Array(self.value_length, value_type) + self.read_value = Array(self.value_length, value_type.default_type) self.read_non_empty = MemValue(self.value_type.bit_type(0)) - self.state = MemValue(self.value_type(0)) + self.state = MemValue(self.value_type.default_type(0)) @method_block def add_to_root(self, state, is_empty, v, *x): if len(x) != self.value_length: @@ -1106,10 +1111,10 @@ def evict2(self, p_bucket1, p_bucket2, d): self.evict_bucket(RefBucket(p_bucket2, self), d) @method_block def read_and_renew_index(self, u): - l_star = random_block(self.D, self.value_type) + l_star = random_block(self.D, self.index_type) if use_insecure_randomness: new_path = regint.get_random(self.D) - l_star = self.value_type(new_path) + l_star = self.index_type(new_path) self.state.write(l_star) return self.index.update(u, l_star, evict=False).reveal() @method_block @@ -1120,7 +1125,7 @@ def read_and_remove_levels(self, u, read_path): parallel = get_parallel(self.index_size, *self.internal_value_type()) @map_sum(get_n_threads_for_tree(self.size), parallel, levels, \ self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type] * self.value_length) + [self.value_type.default_type] * self.value_length) def process(level): b_index = regint(cint(2**(self.D) + read_path) >> cint(self.D - level)) bucket = RefBucket(b_index, self) @@ -1142,9 +1147,9 @@ def f(): Program.prog.curr_tape.start_new_basicblock() crash() def internal_value_type(self): - return self.value_type, self.value_length + 1 + return self.value_type.default_type, self.value_length + 1 def internal_entry_size(self): - return self.value_type, [self.D] + list(self.entry_size) + return self.value_type.default_type, [self.D] + list(self.entry_size) def n_buckets(self): return 2**(self.D+1) @method_block @@ -1176,8 +1181,9 @@ def add(self, entry, state=None, evict=True): #print 'pre-add', self maybe_start_timer(4) self.add_to_root(state, entry.empty(), \ - self.value_type(entry.v.read()), \ - *(self.value_type(i.read()) for i in entry.x)) + self.index_type(entry.v.read()), \ + *(self.value_type.default_type(i.read()) + for i in entry.x)) maybe_stop_timer(4) #print 'pre-evict', self if evict: @@ -1228,21 +1234,27 @@ def batch_init(self, values): raise CompilerError('Batch initialization only possible with sint.') depth = log2(m) - leaves = [0] * m - entries = [0] * m - indexed_values = [0] * m + leaves = self.value_type.Array(m) + indexed_values = \ + self.value_type.Matrix(m, len(util.tuplify(values[0])) + 1) # assign indices 0, ..., m-1 - for i,value in enumerate(values): + @for_range(m) + def _(i): + value = values[i] index = MemValue(self.value_type.hard_conv(i)) new_value = [MemValue(self.value_type.hard_conv(v)) \ for v in (value if isinstance(value, (tuple, list)) \ else (value,))] indexed_values[i] = [index] + new_value - + entries = sint.Matrix(self.bucket_size * 2 ** self.D, + len(Entry(0, list(indexed_values[0]), False))) + # assign leaves - for i,index_value in enumerate(indexed_values): + @for_range(len(indexed_values)) + def _(i): + index_value = list(indexed_values[i]) leaves[i] = random_block(self.D, self.value_type) index = index_value[0] @@ -1252,18 +1264,20 @@ def batch_init(self, values): # save unsorted leaves for position map unsorted_leaves = [MemValue(self.value_type(leaf)) for leaf in leaves] - permutation.sort(leaves, comp=permutation.normal_comparator) + leaves.sort() bucket_sz = 0 # B[i] = (pos, leaf, "last in bucket" flag) for i-th entry - B = [[0]*3 for i in range(m)] + B = sint.Matrix(m, 3) B[0] = [0, leaves[0], 0] B[-1] = [None, None, sint(1)] - s = 0 + s = MemValue(sint(0)) - for i in range(1, m): + @for_range_opt(m - 1) + def _(j): + i = j + 1 eq = leaves[i].equal(leaves[i-1]) - s = (s + eq) * eq + s.write((s + eq) * eq) B[i][0] = s B[i][1] = leaves[i] B[i-1][2] = 1 - eq @@ -1271,7 +1285,7 @@ def batch_init(self, values): #last_in_bucket[i-1] = 1 - eq # shuffle - permutation.shuffle(B, value_type=sint) + B.secure_shuffle() #cint(0).print_reg('shuf') sz = MemValue(0) #cint(0) @@ -1279,7 +1293,8 @@ def batch_init(self, values): empty_positions = Array(nleaves, self.value_type) empty_leaves = Array(nleaves, self.value_type) - for i in range(m): + @for_range(m) + def _(i): if_then(reveal(B[i][2])) #if B[i][2] == 1: #cint(i).print_reg('last') @@ -1291,12 +1306,13 @@ def batch_init(self, values): empty_positions[szval] = B[i][0] #pos[i][0] #empty_positions[szval].reveal().print_reg('ps0') empty_leaves[szval] = B[i][1] #pos[i][1] - sz += 1 + sz.iadd(1) end_if() - pos_bits = [] + pos_bits = self.value_type.Matrix(self.bucket_size * nleaves, 2) - for i in range(nleaves): + @for_range_opt(nleaves) + def _(i): leaf = empty_leaves[i] # split into 2 if bucket size can't fit into one field elem if self.bucket_size + Program.prog.security > 128: @@ -1315,46 +1331,39 @@ def batch_init(self, values): bucket_bits = [b for sl in zip(bits2,bits) for b in sl] else: bucket_bits = floatingpoint.B2U(empty_positions[i]+1, self.bucket_size, Program.prog.security)[0] - pos_bits += [[b, leaf] for b in bucket_bits] + assert len(bucket_bits) == self.bucket_size + for j, b in enumerate(bucket_bits): + pos_bits[i * self.bucket_size + j] = [b, leaf] # sort to get empty positions first - permutation.sort(pos_bits, comp=permutation.bitwise_list_comparator) + pos_bits.sort(n_bits=1) # now assign positions to empty entries - empty_entries = [0] * (self.bucket_size*2**self.D - m) - - for i in range(self.bucket_size*2**self.D - m): + @for_range(len(entries) - m) + def _(i): vtype, vlength = self.internal_value_type() leaf = vtype(pos_bits[i][1]) # set leaf in empty entry for assigning after shuffle - value = tuple([leaf] + [vtype(0) for j in range(vlength)]) + value = tuple([leaf] + [vtype(0) for j in range(vlength - 1)]) entry = Entry(vtype(0), value, vtype.hard_conv(True), vtype) - empty_entries[i] = entry + entries[m + i] = entry # now shuffle, reveal positions and place entries - entries = entries + empty_entries - while len(entries) & (len(entries)-1) != 0: - entries.append(None) - permutation.shuffle(entries, value_type=sint) - entries = [entry for entry in entries if entry is not None] - clear_leaves = [MemValue(entry.x[0].reveal()) for entry in entries] + entries.secure_shuffle() + clear_leaves = Array.create_from( + Entry(entries.get_columns()).x[0].reveal()) Program.prog.curr_tape.start_new_basicblock() bucket_sizes = Array(2**self.D, regint) for i in range(2**self.D): bucket_sizes[i] = 0 - k = 0 - for entry,leaf in zip(entries, clear_leaves): - leaf = leaf.read() - k += 1 - - # for some reason leaf_buckets is in bit-reversed order - bits = bit_decompose(leaf, self.D) - rev_leaf = sum(b*2**i for i,b in enumerate(bits[::-1])) - bucket = RefBucket(rev_leaf + (1 << self.D), self) - # hack: 1*entry ensures MemValues are converted to sints - bucket.bucket.ram[bucket_sizes[leaf]] = 1*entry + + @for_range_opt(len(entries)) + def _(k): + leaf = clear_leaves[k] + bucket = RefBucket(leaf + (1 << self.D), self) + bucket.bucket.ram[bucket_sizes[leaf]] = Entry(entries[k]) bucket_sizes[leaf] += 1 self.index.batch_init([leaf.read() for leaf in unsorted_leaves]) @@ -1599,16 +1608,20 @@ def __setitem__(self, index, value): def batch_init(self, values): """ Initialize m values with indices 0, ..., m-1 """ m = len(values) - n_entries = max(1, m/self.entries_per_block) - new_values = [0] * n_entries + n_entries = max(1, m//self.entries_per_block) + new_values = sint.Matrix(n_entries, self.elements_per_block) + values = Array.create_from(values) - for i in range(n_entries): + @for_range(n_entries) + def _(i): block = [0] * self.elements_per_block for j in range(self.elements_per_block): base = i * self.entries_per_block + j * self.entries_per_element for k in range(self.entries_per_element): - if base + k < m: - block[j] += values[base + k] << (k * self.entry_size) + @if_(base + k < m) + def _(): + block[j] += \ + values[base + k] << (k * sum(self.entry_size)) new_values[i] = block @@ -1667,7 +1680,8 @@ def OptimalORAM(size,*args,**kwargs): experiments. :param size: number of elements - :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` + :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / + :py:class:`sfix` """ if optimal_threshold is None: if n_threads == 1: @@ -1784,7 +1798,7 @@ def test_batch_init(oram_type, N): oram = oram_type(N, value_type) print('initialized') print_reg(cint(0), 'init') - oram.batch_init([value_type(i) for i in range(N)]) + oram.batch_init(Array.create_from(sint(regint.inc(N)))) print_reg(cint(0), 'done') @for_range(N) def f(i): diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py index fb1601c3d..b9e3952ba 100644 --- a/Compiler/path_oram.py +++ b/Compiler/path_oram.py @@ -111,24 +111,6 @@ def bucket_size_sorter(x, y): return 1 - reduce(lambda x,y: x*y, t.bit_decompose(2*Z)[:Z]) -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - - - Returns the network switching config so it may be re-used later. """ - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = permutation.configure_waksman(permutation.random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type(b) for b in c] - permutation.waksman(x, config, reverse=reverse) - permutation.waksman(x, config, reverse=reverse) - - return config - - def LT(a, b): a_bits = bit_decompose(a) b_bits = bit_decompose(b) @@ -472,10 +454,15 @@ def f(): print_ln() # shuffle entries and levels - while len(merged_entries) & (len(merged_entries)-1) != 0: - merged_entries.append(None) #self.root.bucket.empty_entry(False)) - permutation.rec_shuffle(merged_entries, value_type=self.value_type) - merged_entries = [e for e in merged_entries if e is not None] + flat = [] + for x in merged_entries: + flat += list(x[0]) + [x[1]] + flat = self.value_type(flat) + assert len(flat) % len(merged_entries) == 0 + l = len(flat) // len(merged_entries) + shuffled = flat.secure_shuffle(l) + merged_entries = [[Entry(shuffled[i*l:(i+1)*l-1]), shuffled[(i+1)*l-1]] + for i in range(len(shuffled) // l)] # need to copy entries/levels to memory for re-positioning entries_ram = RAM(self.temp_size, self.entry_type, self.get_array) diff --git a/Compiler/permutation.py b/Compiler/permutation.py index 6e1273ec4..07d3a3e70 100644 --- a/Compiler/permutation.py +++ b/Compiler/permutation.py @@ -10,16 +10,6 @@ from Compiler.program import Program _Array = Array -SORT_BITS = [] -insecure_random = Random(0) - -def predefined_comparator(x, y): - """ Assumes SORT_BITS is populated with the required sorting network bits """ - if predefined_comparator.sort_bits_iter is None: - predefined_comparator.sort_bits_iter = iter(SORT_BITS) - return next(predefined_comparator.sort_bits_iter) -predefined_comparator.sort_bits_iter = None - def list_comparator(x, y): """ Uses the first element in the list for comparison """ return x[0] < y[0] @@ -37,10 +27,6 @@ def bitwise_comparator(x, y): def cond_swap_bit(x,y, b): """ swap if b == 1 """ - if x is None: - return y, None - elif y is None: - return x, None if isinstance(x, list): t = [(xi - yi) * b for xi,yi in zip(x, y)] return [xi - ti for xi,ti in zip(x, t)], \ @@ -87,23 +73,6 @@ def odd_even_merge_sort(a, comp=bitwise_comparator): else: raise CompilerError('Length of list must be power of two') -def merge(a, b, comp): - """ General length merge (pads to power of 2) """ - while len(a) & (len(a)-1) != 0: - a.append(None) - while len(b) & (len(b)-1) != 0: - b.append(None) - if len(a) < len(b): - a += [None] * (len(b) - len(a)) - elif len(b) < len(a): - b += [None] * (len(b) - len(b)) - t = a + b - odd_even_merge(t, comp) - for i,v in enumerate(t[::]): - if v is None: - t.remove(None) - return t - def sort(a, comp): """ Pads to power of 2, sorts, removes padding """ length = len(a) @@ -112,47 +81,12 @@ def sort(a, comp): odd_even_merge_sort(a, comp) del a[length:] -def recursive_merge(a, comp): - """ Recursively merge a list of sorted lists (initially sorted by size) """ - if len(a) == 1: - return - # merge smallest two lists, place result in correct position, recurse - t = merge(a[0], a[1], comp) - del a[0] - del a[0] - added = False - for i,c in enumerate(a): - if len(c) >= len(t): - a.insert(i, t) - added = True - break - if not added: - a.append(t) - recursive_merge(a, comp) - -def random_perm(n): - """ Generate a random permutation of length n - - WARNING: randomness fixed at compile-time, this is NOT secure - """ - if not Program.prog.options.insecure: - raise CompilerError('no secure implementation of Waksman permution, ' - 'use --insecure to activate') - a = list(range(n)) - for i in range(n-1, 0, -1): - j = insecure_random.randint(0, i) - t = a[i] - a[i] = a[j] - a[j] = t - return a - -def inverse(perm): - inv = [None] * len(perm) - for i, p in enumerate(perm): - inv[p] = i - return inv +# The following functionality for shuffling isn't used any more as it +# has been moved to the virtual machine. The code has been kept for +# reference. -def configure_waksman(perm): +def configure_waksman(perm, n_iter=[0]): + top = n_iter == [0] n = len(perm) if n == 2: return [(perm[0], perm[0])] @@ -175,6 +109,7 @@ def configure_waksman(perm): via = 0 j0 = j while True: + n_iter[0] += 1 #print ' I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2) i = inv_perm[j] @@ -209,8 +144,11 @@ def configure_waksman(perm): assert sorted(p0) == list(range(n//2)) assert sorted(p1) == list(range(n//2)) - p0_config = configure_waksman(p0) - p1_config = configure_waksman(p1) + p0_config = configure_waksman(p0, n_iter) + p1_config = configure_waksman(p1, n_iter) + if top: + print(n_iter[0], 'iterations for Waksman') + assert O[0] == 0, 'not a Waksman network' return [I + O] + [a+b for a,b in zip(p0_config, p1_config)] def waksman(a, config, depth=0, start=0, reverse=False): @@ -358,23 +296,10 @@ def _(i): # nblocks /= 2 # depth -= 1 -def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False): - n = len(x) - if n & (n-1) != 0: - raise CompilerError('shuffle requires n a power of 2') - if config is None: - config = configure_waksman(random_perm(n)) - for i,c in enumerate(config): - config[i] = [value_type.bit_type(b) for b in c] - waksman(x, config, reverse=reverse) - waksman(x, config, reverse=reverse) - -def config_shuffle(n, value_type): - """ Compute config for oblivious shuffling. - - Take mod 2 for active sec. """ - perm = random_perm(n) +def config_from_perm(perm, value_type): + n = len(perm) + assert(list(sorted(perm))) == list(range(n)) if n & (n-1) != 0: # pad permutation to power of 2 m = 2**int(math.ceil(math.log(n, 2))) @@ -394,103 +319,3 @@ def _(i): for j,b in enumerate(c): config[i * len(perm) + j] = b return config - -def shuffle(x, config=None, value_type=sgf2n, reverse=False): - """ Simulate secure shuffling with Waksman network for 2 players. - WARNING: This is not a properly secure implementation but has roughly the right complexity. - - Returns the network switching config so it may be re-used later. """ - n = len(x) - m = 2**int(math.ceil(math.log(n, 2))) - assert n == m, 'only working for powers of two' - if config is None: - config = config_shuffle(n, value_type) - - if isinstance(x, list): - if isinstance(x[0], list): - length = len(x[0]) - assert len(x) == length - for i in range(length): - xi = Array(m, value_type.reg_type) - for j in range(n): - xi[j] = x[j][i] - for j in range(n, m): - xi[j] = value_type(0) - iter_waksman(xi, config, reverse=reverse) - iter_waksman(xi, config, reverse=reverse) - for j, y in enumerate(xi): - x[j][i] = y - else: - xa = Array(m, value_type.reg_type) - for i in range(n): - xa[i] = x[i] - for i in range(n, m): - xa[i] = value_type(0) - iter_waksman(xa, config, reverse=reverse) - iter_waksman(xa, config, reverse=reverse) - x[:] = xa - elif isinstance(x, Array): - if len(x) != m and config is None: - raise CompilerError('Non-power of 2 Array input not yet supported') - iter_waksman(x, config, reverse=reverse) - iter_waksman(x, config, reverse=reverse) - else: - raise CompilerError('Invalid type for shuffle:', type(x)) - - return config - -def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None): - """ Shuffle a list of ORAM entries. - - Randomly permutes the first "perm_size" entries, leaving the rest (empty - entry padding) in the same position. """ - n = len(x) - l = len(x[0]) - if n & (n-1) != 0: - raise CompilerError('Entries must be padded to power of two length.') - if perm_size is None: - perm_size = n - - xarrays = [Array(n, value_type.reg_type) for i in range(l)] - for i in range(n): - for j,value in enumerate(x[i]): - if isinstance(value, MemValue): - xarrays[j][i] = value.read() - else: - xarrays[j][i] = value - - if config is None: - config = config_shuffle(perm_size, value_type) - for xi in xarrays: - shuffle(xi, config, value_type, reverse) - for i in range(n): - x[i] = entry_cls(xarrays[j][i] for j in range(l)) - return config - - -def sort_zeroes(bits, x, n_ones, value_type): - """ Return Array of values in "x" where the corresponding bit in "bits" is - a 0. - - The total number of zeroes in "bits" must be known. - "bits" and "x" must be Arrays. """ - config = config_shuffle(len(x), value_type) - shuffle(bits, config=config, value_type=value_type) - shuffle(x, config=config, value_type=value_type) - result = Array(n_ones, value_type.reg_type) - - sz = MemValue(0) - last_x = MemValue(value_type(0)) - #for i,b in enumerate(bits): - #if_then(b.reveal() == 0) - #result[sz.read()] = x[i] - #sz += 1 - #end_if() - @for_range(len(bits)) - def f(i): - found = (bits[i].reveal() == 0) - szval = sz.read() - result[szval] = last_x + (x[i] - last_x) * found - sz.write(sz + found) - last_x.write(result[szval]) - return result diff --git a/Compiler/sorting.py b/Compiler/sorting.py new file mode 100644 index 000000000..248b3ea07 --- /dev/null +++ b/Compiler/sorting.py @@ -0,0 +1,54 @@ +import itertools +from Compiler import types, library, instructions + +def dest_comp(B): + Bt = B.transpose() + Bt_flat = Bt.get_vector() + St_flat = Bt.value_type.Array(len(Bt_flat)) + St_flat.assign(Bt_flat) + @library.for_range(len(St_flat) - 1) + def _(i): + St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + Tt_flat = Bt.get_vector() * St_flat.get_vector() + Tt = types.Matrix(*Bt.sizes, B.value_type) + Tt.assign_vector(Tt_flat) + return sum(Tt) - 1 + +def reveal_sort(k, D, reverse=False): + assert len(k) == len(D) + library.break_point() + shuffle = types.sint.get_secure_shuffle(len(k)) + k_prime = k.get_vector().secure_permute(shuffle).reveal() + idx = types.Array.create_from(k_prime) + if reverse: + D.assign_vector(D.get_slice_vector(idx)) + library.break_point() + D.secure_permute(shuffle, reverse=True) + else: + D.secure_permute(shuffle) + library.break_point() + v = D.get_vector() + D.assign_slice_vector(idx, v) + library.break_point() + instructions.delshuffle(shuffle) + +def radix_sort(k, D, n_bits=None, signed=True): + assert len(k) == len(D) + bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) + if signed and len(bs) > 1: + bs[-1][:] = bs[-1][:].bit_not() + B = types.sint.Matrix(len(k), 2) + h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + @library.for_range(len(bs)) + def _(i): + b = bs[i] + B.set_column(0, 1 - b.get_vector()) + B.set_column(1, b.get_vector()) + c = types.Array.create_from(dest_comp(B)) + reveal_sort(c, h, reverse=False) + @library.if_e(i < len(bs) - 1) + def _(): + reveal_sort(h, bs[i + 1], reverse=True) + @library.else_ + def _(): + reveal_sort(h, D, reverse=True) diff --git a/Compiler/types.py b/Compiler/types.py index 1d06f3f71..098f493f0 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -1937,6 +1937,11 @@ def matrix_mul(cls, A, B, n, res_params=None): matmuls(res, A, B, n_rows, n, n_cols) return res + @staticmethod + def _new(self): + # mirror sfix + return self + @no_doc def __init__(self, reg_type, val=None, size=None): if isinstance(val, self.clear_type): @@ -2093,6 +2098,12 @@ def square(self): else: return self * self + @set_instruction_type + def secure_shuffle(self, unit_size=1): + res = type(self)(size=self.size) + secshuffle(res, self, unit_size) + return res + @set_instruction_type @vectorize def reveal(self): @@ -2741,6 +2752,17 @@ def private_division(self, divisor, active=True, dividend_length=None, return w + @staticmethod + def get_secure_shuffle(n): + res = regint() + gensecshuffle(res, n) + return res + + def secure_permute(self, shuffle, unit_size=1, reverse=False): + res = sint(size=self.size) + applyshuffle(res, self, unit_size, shuffle, reverse) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -4291,6 +4313,10 @@ class revealed_fix(self.clear_type): k = self.k return revealed_fix._new(val) + def bit_decompose(self, n_bits=None): + """ Bit decomposition. """ + return self.v.bit_decompose(n_bits or self.k) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4312,6 +4338,8 @@ class sfix(_fix): int_type = sint bit_type = sintbit clear_type = cfix + get_type = staticmethod(lambda n: sint) + default_type = sint @vectorized_classmethod def get_input_from(cls, player): @@ -4385,6 +4413,10 @@ def expand_to_vector(self, size): def coerce(self, other): return parse_type(other, k=self.k, f=self.f) + def hard_conv_me(self, cls): + assert cls == sint + return self.v + def mul_no_reduce(self, other, res_params=None): assert self.f == other.f assert self.k == other.k @@ -4409,6 +4441,14 @@ def reveal_to(self, player): return personal(player, cfix._new(self.v.reveal_to(player)._v, self.k, self.f)) + def secure_shuffle(self, *args, **kwargs): + return self._new(self.v.secure_shuffle(*args, **kwargs), + k=self.k, f=self.f) + + def secure_permute(self, *args, **kwargs): + return self._new(self.v.secure_permute(*args, **kwargs), + k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -5395,13 +5435,21 @@ def get(self, indices): regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) - def get_slice_vector(self, slice): + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 assert len(slice) <= self.total_size() base = regint.inc(len(slice), slice.address, 1, 1) - inc = regint.inc(len(slice), 0, 1, 1, 1) + inc = regint.inc(len(slice), self.address, 1, 1, 1) addresses = slice.value_type.load_mem(base) + inc - return self.value_type.load_mem(self.address + addresses) + return addresses + + def get_slice_vector(self, slice): + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(addresses) def expand_to_vector(self, index, size): """ Create vector from single entry. @@ -5514,6 +5562,14 @@ def shuffle(self): """ Insecure shuffle in place. """ self.assign_vector(self.get(regint.inc(len(self)).shuffle())) + def secure_shuffle(self): + """ Secure shuffle in place according to the security model. """ + self.assign_vector(self.get_vector().secure_shuffle()) + + def secure_permute(self, *args, **kwargs): + """ Secure permutate in place according to the security model. """ + self.assign_vector(self.get_vector().secure_permute(*args, **kwargs)) + def randomize(self, *args): """ Randomize according to data type. """ self.assign_vector(self.value_type.get_random(*args, size=len(self))) @@ -5570,15 +5626,26 @@ def reveal_to(self, player): """ return personal(player, self.create_from(self[:].reveal_to(player)._v)) - def sort(self, n_threads=None): + def sort(self, n_threads=None, batcher=False, n_bits=None): """ - Sort in place using Batchers' odd-even merge mergesort - with complexity :math:`O(n (\log n)^2)`. + Sort in place using radix sort with complexity :math:`O(n \log + n)` for :py:class:`sint` and :py:class:`sfix`, and Batcher's + odd-even mergesort with :math:`O(n (\log n)^2)` for + :py:class:`sfloat`. :param n_threads: number of threads to use (single thread by - default) + default), need to use Batcher's algorithm for several threads + :param batcher: use Batcher's odd-even mergesort in any case + :param n_bits: number of bits in keys (default: global bit length) """ - library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + if batcher or self.value_type.n_elements() > 1: + library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + else: + if n_threads or 1 > 1: + raise CompilerError('multi-threaded sorting only implemented ' + 'with Batcher\'s odd-even mergesort') + import sorting + sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): # compatibility with registers @@ -5619,6 +5686,8 @@ def __getitem__(self, index): :return: :py:class:`Array` if one-dimensional, :py:class:`SubMultiArray` otherwise""" if isinstance(index, slice) and index == slice(None): return self.get_vector() + if isinstance(index, int) and index < 0: + index += self.sizes[0] key = program.curr_block, str(index) if key not in self.sub_cache: if util.is_constant(index) and \ @@ -5673,6 +5742,10 @@ def f(i): def total_size(self): return reduce(operator.mul, self.sizes) * self.value_type.n_elements() + def part_size(self): + return reduce(operator.mul, self.sizes[1:]) * \ + self.value_type.n_elements() + def get_vector(self, base=0, size=None): """ Return vector with content. Not implemented for floating-point. @@ -5731,13 +5804,21 @@ def get_slice_vector(self, slice): :param slice: regint array """ + addresses = self.get_slice_addresses(slice) + return self.value_type.load_mem(self.address + addresses) + + def assign_slice_vector(self, slice, vector): + addresses = self.get_slice_addresses(slice) + vector.store_in_mem(self.address + addresses) + + def get_slice_addresses(self, slice): assert self.value_type.n_elements() == 1 part_size = reduce(operator.mul, self.sizes[1:]) assert len(slice) * part_size <= self.total_size() base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) addresses = slice.value_type.load_mem(base) * part_size + inc - return self.value_type.load_mem(self.address + addresses) + return addresses def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 @@ -6218,6 +6299,31 @@ def diag(self): n = self.sizes[0] return self.array.get(regint.inc(n, 0, n + 1)) + def secure_shuffle(self): + """ Securely shuffle rows (first index). """ + self.assign_vector(self.get_vector().secure_shuffle(self.part_size())) + + def secure_permute(self, permutation, reverse=False): + """ Securely permute rows (first index). """ + self.assign_vector(self.get_vector().secure_permute( + permutation, self.part_size(), reverse)) + + def sort(self, key_indices=None, n_bits=None): + """ Sort sub-arrays (different first index) in place. + + :param key_indices: indices to sorting keys, for example + ``(1, 2)`` to sort three-dimensional array ``a`` by keys + ``a[*][1][2]``. Default is ``(0, ..., 0)`` of correct length. + :param n_bits: number of bits in keys (default: global bit length) + + """ + if key_indices is None: + key_indices = (0,) * (len(self.sizes) - 1) + key_indices = (None,) + util.tuplify(key_indices) + import sorting + keys = self.get_vector_by_indices(*key_indices) + sorting.radix_sort(keys, self, n_bits=n_bits) + def randomize(self, *args): """ Randomize according to data type. """ if self.total_size() < program.options.budget: @@ -6334,6 +6440,18 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): MultiArray.__init__(self, [rows, columns], value_type, debug=debug, \ address=address) + @staticmethod + def create_from(rows): + rows = list(rows) + if isinstance(rows[0], (list, tuple)): + t = type(rows[0][0]) + else: + t = type(rows[0]) + res = Matrix(len(rows), len(rows[0]), t) + for i in range(len(rows)): + res[i].assign(rows[i]) + return res + def get_column(self, index): """ Get column as vector. @@ -6344,6 +6462,9 @@ def get_column(self, index): self.sizes[1]) return self.value_type.load_mem(addresses) + def get_columns(self): + return (self.get_column(i) for i in range(self.sizes[1])) + def get_column_by_row_indices(self, rows, column): assert self.value_type.n_elements() == 1 addresses = rows * self.sizes[1] + \ diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 4e1e4c4b0..fc5571b1d 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -47,17 +47,8 @@ inline void receive(client_socket* socket, octet* data, size_t len) #else typedef ssl_ctx client_ctx; +typedef ssl_socket client_socket; -class client_socket : public ssl_socket -{ -public: - client_socket(boost::asio::io_service& io_service, - boost::asio::ssl::context& ctx, int plaintext_socket, string other, - string me, bool client) : - ssl_socket(io_service, ctx, plaintext_socket, other, me, client) - { - } -}; #endif /** diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 1efe1e228..b0a287444 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -58,7 +58,8 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - for (unsigned int i = 0; i < this->size(); i++) + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) (*this)[i] += y[i]; return *this; } @@ -67,9 +68,11 @@ class AddableVector: public vector { if (this->size() != y.size()) throw out_of_range("vector length mismatch"); - AddableVector res(y.size()); - for (unsigned int i = 0; i < this->size(); i++) - res[i] = (*this)[i] - y[i]; + AddableVector res; + res.reserve(y.size()); + size_t n = this->size(); + for (unsigned int i = 0; i < n; i++) + res.push_back((*this)[i] - y[i]); return res; } diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp index 9afef83ce..00e051318 100644 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -31,6 +31,12 @@ word check_pk_id(word a, word b) } +void Ciphertext::Scale() +{ + Scale(params->get_plaintext_modulus()); +} + + void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1) { if (c0.params!=c1.params) { throw params_mismatch(); } @@ -115,9 +121,28 @@ void Ciphertext::add(octetStream& os) *this += tmp; } +void Ciphertext::rerandomize(const FHE_PK& pk) +{ + Rq_Element tmp(*params); + SeededPRNG G; + vector r(params->FFTD()[0].m()); + bigint p = pk.p(); + assert(p != 0); + for (auto& x : r) + { + G.get(x, params->p0().numBits() - p.numBits() - 1); + x *= p; + } + tmp.from(r, 0); + Scale(); + cc0 += tmp; + auto zero = pk.encrypt(*params); + zero.Scale(pk.p()); + *this += zero; +} + template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); -template void mul(Ciphertext& ans,const Plaintext& a,const Ciphertext& c); - - +template void mul(Ciphertext& ans, const Plaintext& a, + const Ciphertext& c); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index d455f1268..11a23e2ab 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -15,6 +15,12 @@ template void mul(Ciphertext& ans,const Ciphertext& c, void add(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1); void mul(Ciphertext& ans,const Ciphertext& c0,const Ciphertext& c1,const FHE_PK& pk); +/** + * BGV ciphertext. + * The class allows adding two ciphertexts as well as adding a plaintext and + * a ciphertext via operator overloading. The multiplication of two ciphertexts + * requires the public key and thus needs a separate function. + */ class Ciphertext { Rq_Element cc0,cc1; @@ -54,6 +60,7 @@ class Ciphertext // Scale down an element from level 1 to level 0, if at level 0 do nothing void Scale(const bigint& p) { cc0.Scale(p); cc1.Scale(p); } + void Scale(); // Throws error if ans,c0,c1 etc have different params settings // - Thus programmer needs to ensure this rather than this being done @@ -90,6 +97,12 @@ class Ciphertext template Ciphertext& operator*=(const Plaintext_& other) { ::mul(*this, *this, other); return *this; } + /** + * Ciphertext multiplication. + * @param pk public key + * @param x second ciphertext + * @returns product ciphertext + */ Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const { Ciphertext res(*params); ::mul(res, *this, x, pk); return res; } @@ -98,14 +111,18 @@ class Ciphertext return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this}; } + /// Re-randomize for circuit privacy. + void rerandomize(const FHE_PK& pk); + int level() const { return cc0.level(); } - // pack/unpack (like IO) also assume params are known and already set - // correctly + /// Append to buffer void pack(octetStream& o) const { cc0.pack(o); cc1.pack(o); o.store(pk_id); } - void unpack(octetStream& o) - { cc0.unpack(o); cc1.unpack(o); o.get(pk_id); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& o) + { cc0.unpack(o, *params); cc1.unpack(o, *params); o.get(pk_id); } void output(ostream& s) const { cc0.output(s); cc1.output(s); s.write((char*)&pk_id, sizeof(pk_id)); } diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp index 9cc1a0840..958cd28cc 100644 --- a/FHE/Diagonalizer.cpp +++ b/FHE/Diagonalizer.cpp @@ -64,8 +64,11 @@ Diagonalizer::MatrixVector Diagonalizer::dediag( { auto& c = products.at(i); for (int j = 0; j < n_matrices; j++) + { + res.at(j).entries.init(); for (size_t k = 0; k < n_rows; k++) res.at(j)[{k, i}] = c.element(j * n_rows + k); + } } return res; } diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index c71a4c5da..d3b67b506 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -7,6 +7,11 @@ +FFT_Data::FFT_Data() : + twop(-1) +{ +} + void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { R=Rg; diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index c5d6b2063..4fb37ed48 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -50,7 +50,7 @@ class FFT_Data void pack(octetStream& o) const; void unpack(octetStream& o); - FFT_Data() { ; } + FFT_Data(); FFT_Data(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 20dfb1bb5..742c85452 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -12,6 +12,11 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) { } +FHE_SK::FHE_SK(const FHE_Params& pms) : + FHE_SK(pms, pms.get_plaintext_modulus()) +{ +} + FHE_SK& FHE_SK::operator+=(const FHE_SK& c) { @@ -38,6 +43,11 @@ void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G) } +FHE_PK::FHE_PK(const FHE_Params& pms) : + FHE_PK(pms, pms.get_plaintext_modulus()) +{ +} + Rq_Element FHE_PK::sample_secret_key(PRNG& G) { Rq_Element sk = FHE_SK(*this).s(); @@ -179,32 +189,51 @@ Ciphertext FHE_PK::encrypt(const Plaintext& template Ciphertext FHE_PK::encrypt( const Plaintext& mess) const +{ + return encrypt(Rq_Element(*params, mess)); +} + +Ciphertext FHE_PK::encrypt(const Rq_Element& mess) const { Random_Coins rc(*params); PRNG G; G.ReSeed(); rc.generate(G); - return encrypt(mess, rc); + Ciphertext res(*params); + quasi_encrypt(res, mess, rc); + return res; } template void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const { - if (&c.get_params()!=params) { throw params_mismatch(); } if (T::characteristic_two ^ (pr == 2)) throw pr_mismatch(); + Rq_Element ans = quasi_decrypt(c); + mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); +} + +Rq_Element FHE_SK::quasi_decrypt(const Ciphertext& c) const +{ + if (&c.get_params()!=params) { throw params_mismatch(); } + Rq_Element ans; mul(ans,c.c1(),sk); sub(ans,c.c0(),ans); ans.change_rep(polynomial); - mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); + return ans; } +Plaintext_ FHE_SK::decrypt(const Ciphertext& c) +{ + return decrypt(c, params->get_plaintext_field_data()); +} + template Plaintext FHE_SK::decrypt(const Ciphertext& c, const FD& FieldD) { @@ -299,12 +328,12 @@ void FHE_PK::unpack(octetStream& o) o.consume((octet*) tag, 8); if (memcmp(tag, "PKPKPKPK", 8)) throw runtime_error("invalid serialization of public key"); - a0.unpack(o); - b0.unpack(o); + a0.unpack(o, *params); + b0.unpack(o, *params); if (params->n_mults() > 0) { - Sw_a.unpack(o); - Sw_b.unpack(o); + Sw_a.unpack(o, *params); + Sw_b.unpack(o, *params); } pr.unpack(o); } @@ -322,7 +351,6 @@ bool FHE_PK::operator!=(const FHE_PK& x) const return false; } - void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const { @@ -345,8 +373,6 @@ void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) throw runtime_error("incorrect key pair"); } - - void FHE_PK::check(const FHE_Params& params, const bigint& pr) const { if (this->pr != pr) @@ -361,6 +387,24 @@ void FHE_PK::check(const FHE_Params& params, const bigint& pr) const } } +bigint FHE_SK::get_noise(const Ciphertext& c) +{ + sk.lower_level(); + Ciphertext cc = c; + if (cc.level()) + cc.Scale(); + Rq_Element tmp = quasi_decrypt(cc); + bigint res; + bigint q = tmp.get_modulus(); + bigint half_q = q / 2; + for (auto& x : tmp.to_vec_bigint()) + { +// cout << numBits(x) << "/" << (x > half_q) << "/" << (x < 0) << " "; + res = max(res, x > half_q ? x - q : x); + } + return res; +} + template void FHE_PK::encrypt(Ciphertext&, const Plaintext_& mess, diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 30ecc2925..f342e203b 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -12,6 +12,10 @@ class FHE_PK; class Ciphertext; +/** + * BGV secret key. + * The class allows addition. + */ class FHE_SK { Rq_Element sk; @@ -29,6 +33,8 @@ class FHE_SK // secret key always on lower level void assign(const Rq_Element& s) { sk=s; sk.lower_level(); } + FHE_SK(const FHE_Params& pms); + FHE_SK(const FHE_Params& pms, const bigint& p) : sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; } @@ -38,8 +44,11 @@ class FHE_SK const Rq_Element& s() const { return sk; } + /// Append to buffer void pack(octetStream& os) const { sk.pack(os); pr.pack(os); } - void unpack(octetStream& os) { sk.unpack(os); pr.unpack(os); } + + /// Read from buffer. Assumes parameters are set correctly + void unpack(octetStream& os) { sk.unpack(os, *params); pr.unpack(os); } // Assumes Ring and prime of mess have already been set correctly // Ciphertext c must be at level 0 or an error occurs @@ -50,9 +59,14 @@ class FHE_SK template Plaintext decrypt(const Ciphertext& c, const FD& FieldD); + /// Decryption for cleartexts modulo prime + Plaintext_ decrypt(const Ciphertext& c); + template void decrypt_any(Plaintext_& mess, const Ciphertext& c); + Rq_Element quasi_decrypt(const Ciphertext& c) const; + // Three stage procedure for Distributed Decryption // - First stage produces my shares // - Second stage adds in another players shares, do this once for each other player @@ -62,7 +76,6 @@ class FHE_SK void dist_decrypt_1(vector& vv,const Ciphertext& ctx,int player_number,int num_players) const; void dist_decrypt_2(vector& vv,const vector& vv1) const; - friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); /* Add secret keys @@ -82,10 +95,15 @@ class FHE_SK template void check(const FHE_PK& pk, const FD& FieldD); + bigint get_noise(const Ciphertext& c); + friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; } }; +/** + * BGV public key. + */ class FHE_PK { Rq_Element a0,b0; @@ -104,8 +122,10 @@ class FHE_PK ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } - - FHE_PK(const FHE_Params& pms, const bigint& p = 0) + + FHE_PK(const FHE_Params& pms); + + FHE_PK(const FHE_Params& pms, const bigint& p) : a0(pms.FFTD(),evaluation,evaluation), b0(pms.FFTD(),evaluation,evaluation), Sw_a(pms.FFTD(),evaluation,evaluation), @@ -143,8 +163,11 @@ class FHE_PK template Ciphertext encrypt(const Plaintext& mess, const Random_Coins& rc) const; + + /// Encryption template Ciphertext encrypt(const Plaintext& mess) const; + Ciphertext encrypt(const Rq_Element& mess) const; friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); @@ -156,8 +179,10 @@ class FHE_PK void check_noise(const FHE_SK& sk) const; void check_noise(const Rq_Element& x, bool check_modulo = false) const; - // params setting is done out of these IO/pack/unpack functions + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); bool operator!=(const FHE_PK& x) const; @@ -170,21 +195,39 @@ class FHE_PK void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); +/** + * BGV key pair + */ class FHE_KeyPair { public: + /// Public key FHE_PK pk; + /// Secret key FHE_SK sk; - FHE_KeyPair(const FHE_Params& params, const bigint& pr = 0) : + FHE_KeyPair(const FHE_Params& params, const bigint& pr) : pk(params, pr), sk(params, pr) { } + /// Initialization + FHE_KeyPair(const FHE_Params& params) : + pk(params), sk(params) + { + } + void generate(PRNG& G) { KeyGen(pk, sk, G); } + + /// Generate fresh keys + void generate() + { + SeededPRNG G; + generate(G); + } }; template diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 5fb07f233..5a0f3991c 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -1,5 +1,6 @@ #include "FHE_Params.h" +#include "NTL-Subs.h" #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" #include "Protocols/HemiOptions.h" @@ -67,6 +68,7 @@ void FHE_Params::pack(octetStream& o) const Bval.pack(o); o.store(sec_p); o.store(matrix_dim); + fd.pack(o); } void FHE_Params::unpack(octetStream& o) @@ -80,6 +82,7 @@ void FHE_Params::unpack(octetStream& o) Bval.unpack(o); o.get(sec_p); o.get(matrix_dim); + fd.unpack(o); } bool FHE_Params::operator!=(const FHE_Params& other) const @@ -92,3 +95,37 @@ bool FHE_Params::operator!=(const FHE_Params& other) const else return false; } + +void FHE_Params::basic_generation_mod_prime(int plaintext_length) +{ + if (n_mults() == 0) + generate_semi_setup(plaintext_length, 0, *this, fd, false); + else + { + Parameters parameters(1, plaintext_length, 0); + parameters.generate_setup(*this, fd); + } +} + +template<> +const FFT_Data& FHE_Params::get_plaintext_field_data() const +{ + return fd; +} + +template<> +const P2Data& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +template<> +const PPData& FHE_Params::get_plaintext_field_data() const +{ + throw not_implemented(); +} + +bigint FHE_Params::get_plaintext_modulus() const +{ + return fd.get_prime(); +} diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 8821e2e29..4733245ca 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -15,6 +15,9 @@ #include "Tools/random.h" #include "Protocols/config.h" +/** + * Cryptosystem parameters + */ class FHE_Params { protected: @@ -29,8 +32,15 @@ class FHE_Params bigint Bval; int matrix_dim; + FFT_Data fd; + public: + /** + * Initialization. + * @param n_mults number of ciphertext multiplications (0/1) + * @param drown_sec parameter for function privacy (default 40) + */ FHE_Params(int n_mults = 1, int drown_sec = DEFAULT_SECURITY); int n_mults() const { return FFTData.size() - 1; } @@ -59,10 +69,24 @@ class FHE_Params int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer void unpack(octetStream& o); bool operator!=(const FHE_Params& other) const; + + /** + * Generate parameter for computation modulo a prime + * @param plaintext_length bit length of prime + */ + void basic_generation_mod_prime(int plaintext_length); + + template + const FD& get_plaintext_field_data() const; + + bigint get_plaintext_modulus() const; }; #endif diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 22705bedf..794e7431d 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -107,10 +107,12 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, bool round_up) { +#ifdef VERBOSE cout << "Need ciphertext modulus of length " << lgp0; if (params.n_mults() > 0) cout << "+" << lgp1; cout << " and " << phi_N(m) << " slots" << endl; +#endif int extra_slack = 0; if (round_up) @@ -125,8 +127,10 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, } extra_slack = i - 1; lgp0 += extra_slack; +#ifdef VERBOSE cout << "Rounding up to " << lgp0 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } Ring R; @@ -148,11 +152,15 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int& lgp0, int lgp1, int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool round_up, FHE_Params& params) { + (void) lg2pi, (void) n; + +#ifdef VERBOSE if (n >= 2 and n <= 10) cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] << ", for p1: " << lg2p1 - lg2pi[9 + n - 2] << endl; cout << "p0 needs " << int(ceil(1. * lg2p0 / 64)) << " words" << endl; cout << "p1 needs " << int(ceil(1. * lg2p1 / 64)) << " words" << endl; +#endif int extra_slack = 0; if (round_up) @@ -171,11 +179,15 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, extra_slack = 2 * i; lg2p0 += i; lg2p1 += i; +#ifdef VERBOSE cout << "Rounding up to " << lg2p0 << "+" << lg2p1 << ", giving extra slack of " << extra_slack << " bits" << endl; +#endif } +#ifdef VERBOSE cout << "Total length: " << lg2p0 + lg2p1 << endl; +#endif return extra_slack; } @@ -215,12 +227,21 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, { double phi_m_bound = NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1); + +#ifdef VERBOSE cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl; +#endif + if (phi_N(m) < phi_m_bound) { int old_m = m; + (void) old_m; m = 2 << int(ceil(log2(phi_m_bound))); + +#ifdef VERBOSE cout << "m = " << old_m << " too small, increasing it to " << m << endl; +#endif + generate_prime(p, numBits(p), m); } else @@ -244,6 +265,8 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p, void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, const string& i, const bigint& pr0) { + (void) i; + if (lg2pr==0) { throw invalid_params(); } bigint step=m; @@ -260,13 +283,14 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, assert(numBits(pr) == lg2pr); } +#ifdef VERBOSE cout << "\t pr" << i << " = " << pr << " : " << numBits(pr) << endl; + cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; +#endif assert(pr % m == 1); assert(pr % p == 1); assert(numBits(pr) == lg2pr); - - cout << "Minimal MAX_MOD_SZ = " << int(ceil(1. * lg2pr / 64)) << endl; } /* @@ -626,6 +650,9 @@ void char_2_dimension(int& m, int& lg2) case 16: m = 4369; break; + case 15: + m = 4681; + break; case 12: m = 4095; break; diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index e2df9583f..f4502317e 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -167,7 +167,7 @@ bigint NoiseBounds::min_p0(const bigint& p1) bigint NoiseBounds::min_p1() { - return drown * B_KS + 1; + return max(bigint(drown * B_KS), bigint((phi_m * p) << 10)); } bigint NoiseBounds::opt_p1() @@ -181,8 +181,10 @@ bigint NoiseBounds::opt_p1() // solve mpf_class s = (-b + sqrt(b * b - 4 * a * c)) / (2 * a); bigint res = ceil(s); +#ifdef VERBOSE cout << "Optimal p1 vs minimal: " << numBits(res) << "/" << numBits(min_p1()) << endl; +#endif return res; } @@ -194,8 +196,10 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1) { min_p0 *= 2; min_p1 *= 2; +#ifdef VERBOSE cout << "increasing lengths: " << numBits(min_p0) << "/" << numBits(min_p1) << endl; +#endif } lg2p1 = numBits(min_p1); lg2p0 = numBits(min_p0); diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index ccd50808a..565c663ef 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -42,6 +42,8 @@ class SemiHomomorphicNoiseBounds bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } static double min_phi_m(int log_q, double sigma); static double min_phi_m(int log_q, const FHE_Params& params); + + bigint get_B_clean() { return B_clean; } }; // as per ePrint 2012:642 for slack = 0 diff --git a/FHE/P2Data.cpp b/FHE/P2Data.cpp index 7d9a8ca47..ac4ae6f16 100644 --- a/FHE/P2Data.cpp +++ b/FHE/P2Data.cpp @@ -55,13 +55,13 @@ void P2Data::check_dimensions() const // cout << "Ai: " << Ai.size() << "x" << Ai[0].size() << endl; if (A.size() != Ai.size()) throw runtime_error("forward and backward mapping dimensions mismatch"); - if (A.size() != A[0].size()) + if (A.size() != A.at(0).size()) throw runtime_error("forward mapping not square"); - if (Ai.size() != Ai[0].size()) + if (Ai.size() != Ai.at(0).size()) throw runtime_error("backward mapping not square"); - if ((int)A[0].size() != slots * gf2n_short::degree()) + if ((int)A.at(0).size() != slots * gf2n_short::degree()) throw runtime_error( - "mapping dimension incorrect: " + to_string(A[0].size()) + "mapping dimension incorrect: " + to_string(A.at(0).size()) + " != " + to_string(slots) + " * " + to_string(gf2n_short::degree())); } diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 84cbb9d19..4eba6e8f0 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -11,10 +11,43 @@ +template +Plaintext::Plaintext(const FHE_Params& params) : + Plaintext(params.get_plaintext_field_data(), Both) +{ +} + + +template +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).phi_m(); +} + +template +int Plaintext::degree() const +{ + return (*Field_Data).phi_m(); +} + + +template<> +unsigned int Plaintext::num_slots() const +{ + return (*Field_Data).num_slots(); +} + +template<> +int Plaintext::degree() const +{ + return (*Field_Data).degree(); +} + + template<> void Plaintext::from(const Generator& source) const { - b.resize(degree); + b.resize(degree()); for (auto& x : b) { source.get(bigint::tmp); @@ -31,7 +64,7 @@ void Plaintext::from_poly() const Ring_Element e(*Field_Data,polynomial); e.from(b); e.change_rep(evaluation); - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::from_poly() const for (unsigned int i=0; iget_prD()}; type=Both; @@ -90,7 +123,7 @@ template<> void Plaintext::from_poly() const { if (type!=Polynomial) { return; } - a.resize(n_slots); + a.resize(num_slots()); (*Field_Data).backward(a,b); type=Both; } @@ -106,34 +139,13 @@ void Plaintext::to_poly() const -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).phi_m(); - degree = n_slots; -} - - -template<> -void Plaintext::set_sizes() -{ n_slots = (*Field_Data).num_slots(); - degree = (*Field_Data).degree(); -} - - template void Plaintext::allocate(PT_Type type) const { if (type != Evaluation) - b.resize(degree); + b.resize(degree()); if (type != Polynomial) - a.resize(n_slots); + a.resize(num_slots()); this->type = type; } @@ -141,7 +153,7 @@ void Plaintext::allocate(PT_Type type) const template void Plaintext::allocate_slots(const bigint& value) { - b.resize(degree); + b.resize(degree()); for (auto& x : b) x.allocate_slots(value); } @@ -236,7 +248,7 @@ void Plaintext::randomize(PRNG& G,condition cond) type=Polynomial; break; case Diagonal: - a.resize(n_slots); + a.resize(num_slots()); a[0].randomize(G); for (unsigned int i=1; i::randomize(PRNG& G,condition cond) break; default: // Gen a plaintext with 0/1 in each slot - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t) b[0].generateUniform(G, n_bits, false); } else - for (int i = 0; i < n_slots; i++) + for (size_t i = 0; i < num_slots(); i++) b[i].generateUniform(G, n_bits, false); break; default: @@ -288,7 +300,7 @@ void Plaintext::assign_zero(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::assign_one(PT_Type t) allocate(); if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& z,const Plaintext& z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext& x, z.allocate(); if (z.type!=Polynomial) { - z.a.resize(z.n_slots); + z.a.resize(z.num_slots()); for (unsigned int i=0; i& z,const Plaintext::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::negate() { if (type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::equals(const Plaintext& x) const if (type!=Polynomial and x.type!=Polynomial) { - a.resize(n_slots); + a.resize(num_slots()); for (unsigned int i=0; i::unpack(octetStream& o) unsigned int size; o.get(size); allocate(); - if (size != b.size()) + if (size != b.size() and size != 0) throw length_error("unexpected length received"); - for (unsigned int i = 0; i < b.size(); i++) + for (unsigned int i = 0; i < size; i++) b[i] = o.get(); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index 52ff8b6d4..c8fb93c73 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -18,6 +18,7 @@ */ #include "FHE/Generator.h" +#include "FHE/FFT_Data.h" #include "Math/fixint.h" #include @@ -25,6 +26,8 @@ using namespace std; class FHE_PK; class Rq_Element; +class FHE_Params; +class FFT_Data; template class AddableVector; // Forward declaration as apparently this is needed for friends in templates @@ -38,13 +41,19 @@ enum condition { Full, Diagonal, Bits }; enum PT_Type { Polynomial, Evaluation, Both }; +/** + * BGV plaintext. + * Use ``Plaintext_mod_prime`` instead of filling in the templates. + * The plaintext is held in one of the two representations or both, + * polynomial and evaluation. The latter is the one allowing element-wise + * multiplication over a vector. + * Plaintexts can be added, subtracted, and multiplied via operator overloading. + */ template class Plaintext { typedef typename FD::poly_type S; - int n_slots; - int degree; mutable vector a; // The thing in evaluation/FFT form mutable vector b; // Now in polynomial form @@ -58,33 +67,47 @@ class Plaintext const FD *Field_Data; - void set_sizes(); + int degree() const; public: const FD& get_field() const { return *Field_Data; } - unsigned int num_slots() const { return n_slots; } + + /// Number of slots + unsigned int num_slots() const; Plaintext(const FD& FieldD, PT_Type type = Polynomial) - { Field_Data=&FieldD; set_sizes(); allocate(type); } + { Field_Data=&FieldD; allocate(type); } Plaintext(const FD& FieldD, const Rq_Element& other); + /// Initialization + Plaintext(const FHE_Params& params); + void allocate(PT_Type type) const; void allocate() const { allocate(type); } void allocate_slots(const bigint& value); int get_min_alloc(); - // Access evaluation representation + /** + * Read slot. + * @param i slot number + * @returns slot content + */ T element(int i) const { if (type==Polynomial) { from_poly(); } return a[i]; } + /** + * Write to slot + * @param i slot number + * @param e new slot content + */ void set_element(int i,const T& e) { if (type==Polynomial) { throw not_implemented(); } - a.resize(n_slots); + a.resize(num_slots()); a[i]=e; type=Evaluation; } @@ -171,10 +194,10 @@ class Plaintext bool is_diagonal() const; - /* Pack and unpack into an octetStream - * For unpack we assume the FFTD has been assigned correctly already - */ + /// Append to buffer void pack(octetStream& o) const; + + /// Read from buffer. Assumes parameters are set correctly void unpack(octetStream& o); size_t report_size(ReportType type); @@ -185,4 +208,6 @@ class Plaintext template using Plaintext_ = Plaintext; +typedef Plaintext_ Plaintext_mod_prime; + #endif diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp index c1c318b8d..3b63f3069 100644 --- a/FHE/Ring.cpp +++ b/FHE/Ring.cpp @@ -24,7 +24,7 @@ void Ring::unpack(octetStream& o) o.get(pi_inv); o.get(poly); } - else + else if (mm != 0) init(*this, mm); } diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 554d4dc10..39690fa6a 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -87,7 +87,6 @@ void Ring_Element::negate() void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { - if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } if (a.element.empty()) { @@ -100,6 +99,8 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) return; } + if (a.rep!=b.rep) { throw rep_mismatch(); } + if (&ans == &a) { ans += b; diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index 531df90f7..d6a14aabd 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -5,7 +5,7 @@ #include "Math/modp.hpp" Rq_Element::Rq_Element(const FHE_PK& pk) : - Rq_Element(pk.get_params().FFTD()) + Rq_Element(pk.get_params().FFTD(), evaluation, evaluation) { } @@ -347,6 +347,12 @@ size_t Rq_Element::report_size(ReportType type) const return sz; } +void Rq_Element::unpack(octetStream& o, const FHE_Params& params) +{ + set_data(params.FFTD()); + unpack(o); +} + void Rq_Element::print_first_non_zero() const { vector v = to_vec_bigint(); diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index a58cb7de0..4e0cdf97b 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -69,8 +69,9 @@ class Rq_Element a({b0}), lev(n_mults()) {} template - Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : - Rq_Element(params) + Rq_Element(const FHE_Params& params, const Plaintext& plaintext, + RepType r0 = polynomial, RepType r1 = polynomial) : + Rq_Element(params, r0, r1) { from(plaintext.get_iterator()); } @@ -159,6 +160,9 @@ class Rq_Element void pack(octetStream& o) const; void unpack(octetStream& o); + // without prior initialization + void unpack(octetStream& o, const FHE_Params& params); + void output(ostream& s) const; void input(istream& s); diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 43ad7e842..3df98c851 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -57,7 +57,7 @@ void Multiplier::multiply_and_add(Plaintext_& res, template void Multiplier::add(Plaintext_& res, const Ciphertext& c, - OT_ROLE role, int n_summands) + OT_ROLE role, int) { o.reset_write_head(); @@ -67,20 +67,10 @@ void Multiplier::add(Plaintext_& res, const Ciphertext& c, G.ReSeed(); timers["Mask randomization"].start(); product_share.randomize(G); - bigint B = 6 * machine.setup().params.get_R(); - B *= machine.setup().FieldD.get_prime(); - B <<= machine.setup().params.secp(); - // slack - B *= NonInteractiveProof::slack(machine.sec, - machine.setup().params.phi_m()); - B <<= machine.extra_slack; - B *= n_summands; - rc.generateUniform(G, 0, B, B); + mask = c; + mask.rerandomize(other_pk); timers["Mask randomization"].stop(); - timers["Encryption"].start(); - other_pk.encrypt(mask, product_share, rc); - timers["Encryption"].stop(); - mask += c; + mask += product_share; mask.pack(o); res -= product_share; } diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index 59223ad03..019711829 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -75,6 +75,8 @@ void secure_init(T& setup, Player& P, U& machine, + OnlineOptions::singleton.prime.get_str() + "-" + to_string(CowGearOptions::singleton.top_gear()) + "-P" + to_string(P.my_num()) + "-" + to_string(P.num_players()); + string reason; + try { ifstream file(filename); @@ -82,12 +84,30 @@ void secure_init(T& setup, Player& P, U& machine, os.input(file); os.get(machine.extra_slack); setup.unpack(os); + } + catch (exception& e) + { + reason = e.what(); + } + + try + { setup.check(P, machine); } - catch (...) + catch (exception& e) + { + reason = e.what(); + } + + if (not reason.empty()) { - cout << "Finding parameters for security " << sec << " and field size ~2^" - << plaintext_length << endl; + if (OnlineOptions::singleton.verbose) + cerr << "Generating parameters for security " << sec + << " and field size ~2^" << plaintext_length + << " because no suitable material " + "from a previous run was found (" << reason << ")" + << endl; + setup = {}; setup.generate(P, machine, plaintext_length, sec); setup.check(P, machine); octetStream os; diff --git a/GC/NoShare.h b/GC/NoShare.h index 49f93ac42..917e71c5e 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -50,11 +50,6 @@ class NoValue : public ValueInterface return "no"; } - static string type_short() - { - return "no"; - } - static DataFieldType field_type() { throw not_implemented(); @@ -66,7 +61,7 @@ class NoValue : public ValueInterface static void fail() { - throw runtime_error("VM does not support binary circuits"); + throw runtime_error("functionality not available"); } NoValue() {} @@ -143,6 +138,11 @@ class NoShare : public ShareInterface return 0; } + static int length() + { + return 0; + } + static void fail() { NoValue::fail(); diff --git a/Machines/dealer-ring-party.cpp b/Machines/dealer-ring-party.cpp index 4bc8fab1a..890a24ab5 100644 --- a/Machines/dealer-ring-party.cpp +++ b/Machines/dealer-ring-party.cpp @@ -5,6 +5,7 @@ #include "Protocols/DealerShare.h" #include "Protocols/DealerInput.h" +#include "Protocols/Dealer.h" #include "Processor/RingMachine.hpp" #include "Processor/Machine.hpp" @@ -12,6 +13,7 @@ #include "Protocols/DealerPrep.hpp" #include "Protocols/DealerInput.hpp" #include "Protocols/DealerMC.hpp" +#include "Protocols/DealerMatrixPrep.hpp" #include "Protocols/Beaver.hpp" #include "Semi.hpp" #include "GC/DealerPrep.h" diff --git a/Machines/mama-party.cpp b/Machines/mama-party.cpp index f270b87ce..87bf15eaa 100644 --- a/Machines/mama-party.cpp +++ b/Machines/mama-party.cpp @@ -21,5 +21,5 @@ using MamaShare_ = MamaShare; int main(int argc, const char** argv) { ez::ezOptionParser opt; - DishonestMajorityFieldMachine(argc, argv, opt); + DishonestMajorityFieldMachine(argc, argv, opt); } diff --git a/Makefile b/Makefile index 3c2be0090..03366f89d 100644 --- a/Makefile +++ b/Makefile @@ -244,6 +244,7 @@ paper-example.x: $(VM) $(OT) $(FHEOFFLINE) binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o l2h-example.x: $(VM) $(OT) Machines/Tinier.o +he-example.x: $(FHEOFFLINE) mascot-offline.x: $(VM) $(TINIER) cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index c0b2373ed..489ec5ae9 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -24,7 +24,12 @@ class FixedVec typedef T value_type; typedef FixedVec Scalar; - static const int length = L; + static const int vector_length = L; + + static int length() + { + return L * T::length(); + } static int size() { diff --git a/Math/Setup.cpp b/Math/Setup.cpp index dc76e47d7..715d480d6 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -136,7 +136,7 @@ void write_online_setup(string dirname, const bigint& p) if (mkdir_p(ss.str().c_str()) == -1) { cerr << "mkdir_p(" << ss.str() << ") failed\n"; - throw file_error(ss.str()); + throw file_error("cannot create " + dirname); } // Output the data @@ -167,6 +167,6 @@ string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, res += "-" + to_string(log2mod); res += "/"; if (mkdir_p(res.c_str()) < 0) - throw file_error(res); + throw file_error("cannot create " + res); return res; } diff --git a/Math/Z2k.h b/Math/Z2k.h index cdde3f40c..e8d2ba532 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -439,6 +439,12 @@ void Z2::randomize(PRNG& G, int n) template void Z2::randomize_part(PRNG& G, int n) { + if (n >= N_BITS) + { + randomize(G); + return; + } + *this = {}; G.get_octets((octet*)a, DIV_CEIL(n, 8)); a[DIV_CEIL(n, 64) - 1] &= mp_limb_t(-1LL) >> (N_LIMB_BITS - 1 - (n - 1) % N_LIMB_BITS); diff --git a/Math/Z2k.hpp b/Math/Z2k.hpp index 876aef939..ef2f84c98 100644 --- a/Math/Z2k.hpp +++ b/Math/Z2k.hpp @@ -67,7 +67,10 @@ Z2::Z2(const IntBase& x) : template bool Z2::get_bit(int i) const { - return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + if (i < N_BITS) + return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS)); + else + return false; } template diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 17fcdf24c..9dd0b7f04 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -174,7 +174,8 @@ void Zp_Data::unpack(octetStream& o) int m; o.get(m); montgomery = m; - init(pr, m); + if (pr != 0) + init(pr, m); } bool Zp_Data::operator!=(const Zp_Data& other) const diff --git a/Math/gf2n.cpp b/Math/gf2n.cpp index 44e424794..d39a8593e 100644 --- a/Math/gf2n.cpp +++ b/Math/gf2n.cpp @@ -44,6 +44,19 @@ int fields_2[num_2_fields][4] = { 128, 7, 2, 1 }, }; +template +string gf2n_::options() +{ + string res = to_string(fields_2[0][0]); + for (int i = 1; i < num_2_fields; i++) + { + int n = fields_2[i][0]; + if (n <= MAX_N_BITS) + res += ", " + to_string(n); + } + return res; +} + template void gf2n_::init_tables() { @@ -113,7 +126,7 @@ void gf2n_::init_field(int nn) if (j==-1) { - throw gf2n_not_supported(nn); + throw gf2n_not_supported(nn, options()); } n=nn; diff --git a/Math/gf2n.h b/Math/gf2n.h index 3ec8849af..56377072a 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -86,6 +86,8 @@ class gf2n_ : public ValueInterface static bool allows(Dtype type) { (void) type; return true; } + static string options(); + static const true_type invertible; static const true_type characteristic_two; @@ -154,6 +156,8 @@ class gf2n_ : public ValueInterface gf2n_ operator*(int x) const { return *this * gf2n_(x); } gf2n_ invert() const; + + gf2n_ operator-() const { return *this; } void negate() { return; } /* Bitwise Ops */ diff --git a/Math/gfpvar.h b/Math/gfpvar.h index a3b475f8c..7d332fdd8 100644 --- a/Math/gfpvar.h +++ b/Math/gfpvar.h @@ -107,6 +107,12 @@ class gfpvar_ a = other.get(); } + template + gfpvar_(const Z2& other) : + gfpvar_(bigint(other)) + { + } + void assign(const void* buffer); void assign_zero(); diff --git a/Networking/AllButLastPlayer.h b/Networking/AllButLastPlayer.h index 22482c481..3d6d18344 100644 --- a/Networking/AllButLastPlayer.h +++ b/Networking/AllButLastPlayer.h @@ -50,17 +50,12 @@ class AllButLastPlayer : public Player void Broadcast_Receive_no_stats(vector& os) const { - vector to_send(P.num_players(), os[P.my_num()]); - vector> channels(P.num_players(), - vector(P.num_players(), true)); - for (auto& x: channels) - x.back() = false; - channels.back() = vector(P.num_players(), false); - vector to_receive; - P.send_receive_all(channels, to_send, to_receive); - for (int i = 0; i < P.num_players() - 1; i++) - if (i != P.my_num()) - os[i] = to_receive[i]; + vector senders(P.num_players(), true), receivers(P.num_players(), + true); + senders.back() = false; + receivers.back() = false; + P.partial_broadcast(senders, receivers, os); + os.resize(num_players()); } }; diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index 43b2ada5c..faf8fda63 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -212,8 +212,8 @@ void CryptoPlayer::partial_broadcast(const vector& my_senders, for (int offset = 1; offset < num_players(); offset++) { int other = get_player(offset); - bool receive = my_senders[other]; - if (my_receivers[other]) + bool receive = my_senders.at(other); + if (my_receivers.at(other)) { this->senders[other]->request(os[my_num()]); sent += os[my_num()].get_length(); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index a7935f305..3a8942148 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -811,14 +811,6 @@ NamedCommStats NamedCommStats::operator -(const NamedCommStats& other) const return res; } -size_t NamedCommStats::total_data() -{ - size_t res = 0; - for (auto& x : *this) - res += x.second.data; - return res; -} - void NamedCommStats::print(bool newline) { for (auto it = begin(); it != end(); it++) diff --git a/Networking/Player.h b/Networking/Player.h index a547d4795..cf8579c0e 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -157,7 +157,6 @@ class NamedCommStats : public map NamedCommStats& operator+=(const NamedCommStats& other); NamedCommStats operator+(const NamedCommStats& other) const; NamedCommStats operator-(const NamedCommStats& other) const; - size_t total_data(); void print(bool newline = false); void reset(); #ifdef VERBOSE_COMM diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index dd7466055..3d40e2ca7 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -230,7 +230,7 @@ void Sub_Data_Files::prune() my_input_buffers.prune(); for (int j = 0; j < num_players; j++) input_buffers[j].prune(); - for (auto it : extended) + for (auto& it : extended) it.second.prune(); dabit_buffer.prune(); if (part != 0) diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 09c6e056a..2eb8a63a6 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -293,7 +293,7 @@ int InputBase::get_player(SubProcessor& Proc, int arg, bool player_from_re if (player_from_reg) { assert(Proc.Proc); - auto res = Proc.Proc->read_Ci(arg); + auto res = Proc.Proc->sync_Ci(arg); if (res >= Proc.P.num_players()) throw runtime_error("player id too large: " + to_string(res)); return res; diff --git a/Processor/Instruction.h b/Processor/Instruction.h index 5279b2584..fd91e35d3 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -13,6 +13,7 @@ using namespace std; template class Machine; template class Processor; +template class SubProcessor; class ArithmeticProcessor; class SwitchableOutput; @@ -107,6 +108,11 @@ enum CONV2DS = 0xAC, CHECK = 0xAF, PRIVATEOUTPUT = 0xAD, + // Shuffling + SECSHUFFLE = 0xFA, + GENSECSHUFFLE = 0xFB, + APPLYSHUFFLE = 0xFC, + DELSHUFFLE = 0xFD, // Data access TRIPLE = 0x50, BIT = 0x51, @@ -250,6 +256,7 @@ enum GMULS = 0x1A6, GMULRS = 0x1A7, GDOTPRODS = 0x1A8, + GSECSHUFFLE = 0x1FA, // Data access GTRIPLE = 0x150, GBIT = 0x151, @@ -388,6 +395,9 @@ class Instruction : public BaseInstruction template void print(SwitchableOutput& out, T* v, T* p = 0, T* s = 0, T* z = 0, T* nan = 0) const; + + template + typename T::clear sanitize(SubProcessor& proc, int reg) const; }; #endif diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 2a5dce70c..5bed37037 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -157,6 +157,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case LISTEN: case CLOSECLIENTCONNECTION: case CRASH: + case DELSHUFFLE: r[0]=get_int(s); break; // instructions with 2 registers + 1 integer operand @@ -203,6 +204,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case DIGESTC: case INPUTMASK: case GINPUTMASK: + case SECSHUFFLE: + case GSECSHUFFLE: get_ints(r, s, 2); n = get_int(s); break; @@ -230,6 +233,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CONDPRINTSTR: case CONDPRINTSTRB: case RANDOMS: + case GENSECSHUFFLE: r[0]=get_int(s); n = get_int(s); break; @@ -269,6 +273,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) // instructions with 5 register operands case PRINTFLOATPLAIN: case PRINTFLOATPLAINB: + case APPLYSHUFFLE: get_vector(5, start, s); break; case INCINT: @@ -558,6 +563,7 @@ int BaseInstruction::get_reg_type() const case CONVCBITVEC: case INTOUTPUT: case ACCEPTCLIENTCONNECTION: + case GENSECSHUFFLE: return INT; case PREP: case GPREP: @@ -835,11 +841,13 @@ inline void Instruction::execute(Processor& Proc) const { for (int i = 0; i < size; i++) Proc.write_Ci(r[0] + i, - Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get()); + Proc.sync( + Integer::convert_unsigned(Proc.read_Cp(r[1] + i)).get())); } else if (n <= 64) for (int i = 0; i < size; i++) - Proc.write_Ci(r[0] + i, Integer(Proc.read_Cp(r[1] + i), n).get()); + Proc.write_Ci(r[0] + i, + Proc.sync(Integer(Proc.read_Cp(r[1] + i), n).get())); else throw Processor_Error(to_string(n) + "-bit conversion impossible; " "integer registers only have 64 bits"); @@ -856,40 +864,32 @@ inline void Instruction::execute(Processor& Proc) const n++; break; case LDMCI: - Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.read_Ci(r[1]))); + Proc.write_Cp(r[0], Proc.machine.Mp.read_C(Proc.sync_Ci(r[1]))); break; case STMC: Proc.machine.Mp.write_C(n,Proc.read_Cp(r[0])); n++; break; case STMCI: - Proc.machine.Mp.write_C(Proc.read_Ci(r[1]), Proc.read_Cp(r[0])); + Proc.machine.Mp.write_C(Proc.sync_Ci(r[1]), Proc.read_Cp(r[0])); break; case MOVC: Proc.write_Cp(r[0],Proc.read_Cp(r[1])); break; case DIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / Proc.read_Cp(r[2])); + Proc.write_Cp(r[0], Proc.read_Cp(r[1]) / sanitize(Proc.Procp, r[2])); break; case GDIVC: - if (Proc.read_C2(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); - Proc.write_C2(r[0], Proc.read_C2(r[1]) / Proc.read_C2(r[2])); + Proc.write_C2(r[0], Proc.read_C2(r[1]) / sanitize(Proc.Proc2, r[2])); break; case FLOORDIVC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Division by zero from register"); Proc.temp.aa.from_signed(Proc.read_Cp(r[1])); - Proc.temp.aa2.from_signed(Proc.read_Cp(r[2])); + Proc.temp.aa2.from_signed(sanitize(Proc.Procp, r[2])); Proc.write_Cp(r[0], bigint(Proc.temp.aa / Proc.temp.aa2)); break; case MODC: - if (Proc.read_Cp(r[2]).is_zero()) - throw Processor_Error("Modulo by zero from register"); to_bigint(Proc.temp.aa, Proc.read_Cp(r[1])); - to_bigint(Proc.temp.aa2, Proc.read_Cp(r[2])); + to_bigint(Proc.temp.aa2, sanitize(Proc.Procp, r[2])); mpz_fdiv_r(Proc.temp.aa.get_mpz_t(), Proc.temp.aa.get_mpz_t(), Proc.temp.aa2.get_mpz_t()); Proc.temp.ansp.convert_destroy(Proc.temp.aa); Proc.write_Cp(r[0],Proc.temp.ansp); @@ -948,7 +948,7 @@ inline void Instruction::execute(Processor& Proc) const Procp.protocol.randoms_inst(Procp.get_S(), *this); return; case INPUTMASKREG: - Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.read_Ci(r[2])); + Procp.DataF.get_input(Proc.get_Sp_ref(r[0]), Proc.temp.rrp, Proc.sync_Ci(r[2])); Proc.write_Cp(r[1], Proc.temp.rrp); break; case INPUTMASK: @@ -1034,7 +1034,7 @@ inline void Instruction::execute(Processor& Proc) const return; case MATMULSM: Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this, - Proc.read_Ci(r[1]), Proc.read_Ci(r[2])); + Proc.sync_Ci(r[1]), Proc.sync_Ci(r[2])); return; case CONV2DS: Proc.Procp.protocol.conv2ds(Proc.Procp, *this); @@ -1042,6 +1042,21 @@ inline void Instruction::execute(Processor& Proc) const case TRUNC_PR: Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); return; + case SECSHUFFLE: + Proc.Procp.secure_shuffle(*this); + return; + case GSECSHUFFLE: + Proc.Proc2.secure_shuffle(*this); + return; + case GENSECSHUFFLE: + Proc.write_Ci(r[0], Proc.Procp.generate_secure_shuffle(*this)); + return; + case APPLYSHUFFLE: + Proc.Procp.apply_shuffle(*this, Proc.read_Ci(start.at(3))); + return; + case DELSHUFFLE: + Proc.Procp.delete_shuffle(Proc.read_Ci(r[0])); + return; case CHECK: { CheckJob job; @@ -1056,14 +1071,14 @@ inline void Instruction::execute(Processor& Proc) const Proc.PC += (signed int) n; break; case JMPI: - Proc.PC += (signed int) Proc.read_Ci(r[0]); + Proc.PC += (signed int) Proc.sync_Ci(r[0]); break; case JMPNZ: - if (Proc.read_Ci(r[0]) != 0) + if (Proc.sync_Ci(r[0]) != 0) { Proc.PC += (signed int) n; } break; case JMPEQZ: - if (Proc.read_Ci(r[0]) == 0) + if (Proc.sync_Ci(r[0]) == 0) { Proc.PC += (signed int) n; } break; case PRINTREG: @@ -1123,7 +1138,7 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.join_tape(r[0]); break; case CRASH: - if (Proc.read_Ci(r[0])) + if (Proc.sync_Ci(r[0])) throw crash_requested(); break; case STARTGRIND: @@ -1146,7 +1161,7 @@ inline void Instruction::execute(Processor& Proc) const // *** case LISTEN: // listen for connections at port number n - Proc.external_clients.start_listening(Proc.read_Ci(r[0])); + Proc.external_clients.start_listening(Proc.sync_Ci(r[0])); break; case ACCEPTCLIENTCONNECTION: { @@ -1335,4 +1350,15 @@ void Instruction::print(SwitchableOutput& out, T* v, T* p, T* s, T* z, T* nan) c out << "]"; } +template +typename T::clear Instruction::sanitize(SubProcessor& proc, int reg) const +{ + if (not T::real_shares(proc.P)) + return 1; + auto& res = proc.get_C_ref(reg); + if (res.is_zero()) + throw Processor_Error("Division by zero from register"); + return res; +} + #endif diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index e0299c2f3..ce90e1b27 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -30,7 +30,7 @@ void Machine::init_binary_domains(int security_parameter, int lg2) if (not is_same()) { - if (sgf2n::clear::degree() < security_parameter) + if (sgf2n::mac_key_type::length() < security_parameter) { cerr << "Security parameter needs to be at most n in GF(2^n)." << endl; @@ -469,7 +469,10 @@ void Machine::run(const string& progname) for (auto& x : comm_stats) rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds - << " rounds (party " << my_number << ")" << endl; + << " rounds (party " << my_number; + if (threads.size() > 1) + cerr << "; rounds counted double due to multi-threading"; + cerr << ")" << endl; auto& P = *this->P; Bundle bundle(P); diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index d4c66e9aa..85ee25d0b 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -36,7 +36,9 @@ OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& op 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - ("Bit length of GF(2^n) field (default: " + to_string(V::default_degree()) + ")").c_str(), // Help description. + ("Bit length of GF(2^n) field (default: " + + to_string(V::default_degree()) + "; options are " + + V::options() + ")").c_str(), // Help description. "-lg2", // Flag token. "--lg2" // Flag token. ); diff --git a/Processor/Processor.h b/Processor/Processor.h index 38ea7f258..927e93279 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -20,6 +20,7 @@ #include "Tools/CheckVector.h" #include "GC/Processor.h" #include "GC/ShareThread.h" +#include "Protocols/SecureShuffle.h" class Program; @@ -31,6 +32,8 @@ class SubProcessor DataPositions bit_usage; + SecureShuffle shuffler; + void resize(size_t size) { C.resize(size); S.resize(size); } template friend class Processor; @@ -70,6 +73,11 @@ class SubProcessor size_t b); void conv2ds(const Instruction& instruction); + void secure_shuffle(const Instruction& instruction); + size_t generate_secure_shuffle(const Instruction& instruction); + void apply_shuffle(const Instruction& instruction, int handle); + void delete_shuffle(int handle); + void input_personal(const vector& args); void send_personal(const vector& args); void private_output(const vector& args); @@ -127,6 +135,10 @@ class ArithmeticProcessor : public ProcessorBase ArithmeticProcessor(OnlineOptions opts, int thread_num) : thread_num(thread_num), sent(0), rounds(0), opts(opts) {} + virtual ~ArithmeticProcessor() + { + } + bool use_stdin() { return thread_num == 0 and opts.interactive; @@ -146,6 +158,11 @@ class ArithmeticProcessor : public ProcessorBase CheckVector& get_Ci() { return Ci; } + virtual long sync_Ci(size_t) const + { + throw not_implemented(); + } + void shuffle(const Instruction& instruction); void bitdecint(const Instruction& instruction); }; @@ -241,6 +258,10 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); + // synchronize in asymmetric protocols + long sync_Ci(size_t i) const; + long sync(long x) const; + private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d74594b3d..861e8cfe0 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -9,6 +9,7 @@ #include "Processor/ProcessorBase.hpp" #include "GC/Processor.hpp" #include "GC/ShareThread.hpp" +#include "Protocols/SecureShuffle.hpp" #include #include @@ -23,6 +24,7 @@ SubProcessor::SubProcessor(ArithmeticProcessor& Proc, typename T::MAC_Check& template SubProcessor::SubProcessor(typename T::MAC_Check& MC, Preprocessing& DataF, Player& P, ArithmeticProcessor* Proc) : + shuffler(*this), Proc(Proc), MC(MC), P(P), DataF(DataF), protocol(P), input(*this, MC), bit_prep(bit_usage) { @@ -340,6 +342,9 @@ void Processor::read_socket_private(int client_id, // Tolerent to no file if no shares yet persisted. template void Processor::read_shares_from_file(int start_file_posn, int end_file_pos_register, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename; filename = "Persistence/Transactions-P" + to_string(P.my_num()) + ".data"; @@ -370,6 +375,9 @@ template void Processor::write_shares_to_file(long start_pos, const vector& data_registers) { + if (not sint::real_shares(P)) + return; + string filename = binary_file_io.filename(P.my_num()); unsigned int size = data_registers.size(); @@ -633,6 +641,33 @@ void SubProcessor::conv2ds(const Instruction& instruction) } } +template +void SubProcessor::secure_shuffle(const Instruction& instruction) +{ + SecureShuffle(S, instruction.get_size(), instruction.get_n(), + instruction.get_r(0), instruction.get_r(1), *this); +} + +template +size_t SubProcessor::generate_secure_shuffle(const Instruction& instruction) +{ + return shuffler.generate(instruction.get_n()); +} + +template +void SubProcessor::apply_shuffle(const Instruction& instruction, int handle) +{ + shuffler.apply(S, instruction.get_size(), instruction.get_start()[2], + instruction.get_start()[0], instruction.get_start()[1], handle, + instruction.get_start()[4]); +} + +template +void SubProcessor::delete_shuffle(int handle) +{ + shuffler.del(handle); +} + template void SubProcessor::input_personal(const vector& args) { @@ -690,4 +725,25 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } +template +long Processor::sync_Ci(size_t i) const +{ + return sync(read_Ci(i)); +} + +template +long Processor::sync(long x) const +{ + if (not sint::symmetric) + { + // send number to dealer + if (P.my_num() == 0) + P.send_long(P.num_players() - 1, x); + if (not sint::real_shares(P)) + return P.receive_long(0); + } + + return x; +} + #endif diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index 626942212..8527f98f7 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -50,7 +50,10 @@ RingMachine::RingMachine(int argc, const char** argv, case L: \ machine.template run, V>(); \ break; - X(64) X(72) X(128) X(192) + X(64) +#ifndef FEWER_RINGS + X(72) X(128) X(192) +#endif #ifdef RING_SIZE X(RING_SIZE) #endif diff --git a/Programs/Source/dijkstra_example.mpc b/Programs/Source/dijkstra_example.mpc new file mode 100644 index 000000000..950fe331f --- /dev/null +++ b/Programs/Source/dijkstra_example.mpc @@ -0,0 +1,50 @@ +# example code for graph with vertices 0,1,2 and with following weights +# 0 -> 1: 5 +# 0 -> 2: 20 +# 1 -> 2: 10 + +# output should be the following +# from 0 to 0 at cost 0 via vertex 0 +# from 0 to 1 at cost 5 via vertex 0 +# from 0 to 2 at cost 15 via vertex 1 + +from oram import OptimalORAM +from dijkstra import dijkstra + +# structure for edges +# contains tuples of form (neighbor, cost, last neighbor bit) +edges = OptimalORAM(4, # number of edges + entry_size=(2, # enough bits for vertices + 5, # enough bits for costs + 1) # always one +) + +# first edge from vertex 0 +edges[0] = (1, 5, 0) +# second and last edge from vertex 0 +edges[1] = (2, 20, 1) +# edge from vertex 1 +edges[2] = (2, 10, 1) +# dummy edge from vertex 2 to itself +edges[3] = (2, 0, 1) + +# structure assigning edge list indices to vertices +e_index = OptimalORAM(3, # number vertices + entry_size=2) # enough bits for edge indices + +# edges from 0 start at 0 +e_index[0] = 0 +# edges from 1 start at 2 +e_index[1] = 2 +# edges from 2 start at 3 +e_index[2] = 3 + +source = sint(0) + +res = dijkstra(source, edges, e_index, OptimalORAM) + +@for_range(res.size) +def _(i): + import util + print_ln('from %s to %s at cost %s via vertex %s', source.reveal(), i, + res[i][0].reveal(), res[i][1].reveal()) diff --git a/Programs/Source/dijkstra_tutorial.mpc b/Programs/Source/dijkstra_tutorial.mpc deleted file mode 100644 index 7ab220237..000000000 --- a/Programs/Source/dijkstra_tutorial.mpc +++ /dev/null @@ -1,9 +0,0 @@ -import dijkstra -from path_oram import OptimalORAM - -n = 1000 - -dist = dijkstra.test_dijkstra_on_cycle(n, OptimalORAM) - -for i in range(n): - print_ln('%s: %s', i, dist[i][0].reveal()) diff --git a/Protocols/Dealer.h b/Protocols/Dealer.h new file mode 100644 index 000000000..cc2c45baf --- /dev/null +++ b/Protocols/Dealer.h @@ -0,0 +1,36 @@ +/* + * Dealer.h + * + */ + +#ifndef PROTOCOLS_DEALER_H_ +#define PROTOCOLS_DEALER_H_ + +#include "Beaver.h" + +template +class Dealer : public Beaver +{ + SeededPRNG G; + +public: + Dealer(Player& P) : + Beaver(P) + { + } + + T get_random() + { + if (T::real_shares(this->P)) + return G.get(); + else + return {}; + } + + vector get_relevant_players() + { + return vector(1, this->P.num_players() - 1); + } +}; + +#endif /* PROTOCOLS_DEALER_H_ */ diff --git a/Protocols/DealerInput.h b/Protocols/DealerInput.h index 7d0699da4..7f0a26dd5 100644 --- a/Protocols/DealerInput.h +++ b/Protocols/DealerInput.h @@ -24,6 +24,7 @@ class DealerInput : public InputBase DealerInput(SubProcessor& proc, typename T::MAC_Check&); DealerInput(typename T::MAC_Check&, Preprocessing&, Player& P); DealerInput(Player& P); + DealerInput(SubProcessor*, Player& P); ~DealerInput(); bool is_dealer(int player = -1); diff --git a/Protocols/DealerInput.hpp b/Protocols/DealerInput.hpp index 26bfb9a1a..8b1ea855a 100644 --- a/Protocols/DealerInput.hpp +++ b/Protocols/DealerInput.hpp @@ -10,7 +10,7 @@ template DealerInput::DealerInput(SubProcessor& proc, typename T::MAC_Check&) : - DealerInput(proc.P) + DealerInput(&proc, proc.P) { } @@ -23,6 +23,13 @@ DealerInput::DealerInput(typename T::MAC_Check&, Preprocessing&, template DealerInput::DealerInput(Player& P) : + DealerInput(0, P) +{ +} + +template +DealerInput::DealerInput(SubProcessor* proc, Player& P) : + InputBase(proc), P(P), to_send(P), shares(P.num_players()), from_dealer(false), sub_player(P) { @@ -68,8 +75,8 @@ void DealerInput::add_mine(const typename T::open_type& input, if (is_dealer()) { make_share(shares.data(), input, P.num_players() - 1, 0, G); - for (int i = 1; i < P.num_players(); i++) - shares.at(i - 1).pack(to_send[i]); + for (int i = 0; i < P.num_players() - 1; i++) + shares.at(i).pack(to_send[i]); from_dealer = true; } else diff --git a/Protocols/DealerMC.h b/Protocols/DealerMC.h index 5311f8132..4e6681366 100644 --- a/Protocols/DealerMC.h +++ b/Protocols/DealerMC.h @@ -25,6 +25,7 @@ class DealerMC : public MAC_Check_Base void prepare_open(const T& secret); void exchange(const Player& P); typename T::open_type finalize_raw(); + array finalize_several(int n); DealerMC& get_part_MC() { diff --git a/Protocols/DealerMC.hpp b/Protocols/DealerMC.hpp index a9ddc035c..0f63b93dc 100644 --- a/Protocols/DealerMC.hpp +++ b/Protocols/DealerMC.hpp @@ -73,4 +73,11 @@ typename T::open_type DealerMC::finalize_raw() return {}; } +template +array DealerMC::finalize_several(int n) +{ + assert(sub_player); + return internal.finalize_several(n); +} + #endif /* PROTOCOLS_DEALERMC_HPP_ */ diff --git a/Protocols/DealerMatrixPrep.h b/Protocols/DealerMatrixPrep.h new file mode 100644 index 000000000..787397255 --- /dev/null +++ b/Protocols/DealerMatrixPrep.h @@ -0,0 +1,32 @@ +/* + * DealerMatrixPrep.h + * + */ + +#ifndef PROTOCOLS_DEALERMATRIXPREP_H_ +#define PROTOCOLS_DEALERMATRIXPREP_H_ + +#include "ShareMatrix.h" + +template +class DealerMatrixPrep : public BufferPrep> +{ + typedef BufferPrep> super; + typedef typename T::LivePrep LivePrep; + + int n_rows, n_inner, n_cols; + + LivePrep* prep; + +public: + DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep&, DataPositions& usage); + + void set_protocol(typename ShareMatrix::Protocol&) + { + } + + void buffer_triples(); +}; + +#endif /* PROTOCOLS_DEALERMATRIXPREP_H_ */ diff --git a/Protocols/DealerMatrixPrep.hpp b/Protocols/DealerMatrixPrep.hpp new file mode 100644 index 000000000..faf98ec77 --- /dev/null +++ b/Protocols/DealerMatrixPrep.hpp @@ -0,0 +1,87 @@ +/* + * DealerMatrixPrep.hpp + * + */ + +#include "DealerMatrixPrep.h" + +template +DealerMatrixPrep::DealerMatrixPrep(int n_rows, int n_inner, int n_cols, + typename T::LivePrep& prep, DataPositions& usage) : + super(usage), n_rows(n_rows), n_inner(n_inner), n_cols(n_cols), + prep(&prep) +{ +} + +template +void append_shares(vector& os, + ValueMatrix& M, PRNG& G) +{ + size_t n = os.size(); + for (auto& value : M.entries) + { + T sum; + for (size_t i = 0; i < n - 2; i++) + { + auto share = G.get(); + sum += share; + share.pack(os[i]); + } + (value - sum).pack(os[n - 2]); + } +} + +template +ShareMatrix receive_shares(octetStream& o, int n, int m) +{ + ShareMatrix res(n, m); + for (size_t i = 0; i < res.entries.size(); i++) + res.entries.v.push_back(o.get()); + return res; +} + +template +void DealerMatrixPrep::buffer_triples() +{ + assert(this->prep); + assert(this->prep->proc); + auto& P = this->prep->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + int batch_size = 100; + if (not T::real_shares(P)) + { + SeededPRNG G; + ValueMatrix A(n_rows, n_inner), B(n_inner, n_cols), + C(n_rows, n_cols); + for (int i = 0; i < P.num_players() - 1; i++) + os[i].reserve( + batch_size * T::size() + * (A.entries.size() + B.entries.size() + + C.entries.size())); + for (int i = 0; i < batch_size; i++) + { + A.randomize(G); + B.randomize(G); + C = A * B; + append_shares(os, A, G); + append_shares(os, B, G); + append_shares(os, C, G); + this->triples.push_back({{{n_rows, n_inner}, {n_inner, n_cols}, + {n_rows, n_cols}}}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < batch_size; i++) + { + auto& o = to_receive.back(); + this->triples.push_back({{receive_shares(o, n_rows, n_inner), + receive_shares(o, n_inner, n_cols), + receive_shares(o, n_rows, n_cols)}}); + } + } +} diff --git a/Protocols/DealerPrep.h b/Protocols/DealerPrep.h index ae28ec691..417fdbac7 100644 --- a/Protocols/DealerPrep.h +++ b/Protocols/DealerPrep.h @@ -11,6 +11,13 @@ template class DealerPrep : virtual public BitPrep { + friend class DealerMatrixPrep; + + template + void buffer_inverses(true_type); + template + void buffer_inverses(false_type); + template void buffer_edabits(int n_bits, true_type); template @@ -23,8 +30,14 @@ class DealerPrep : virtual public BitPrep } void buffer_triples(); + void buffer_inverses(); void buffer_bits(); + void buffer_inputs(int player) + { + this->buffer_inputs_as_usual(player, this->proc); + } + void buffer_dabits(ThreadQueues* = 0); void buffer_edabits(int n_bits, ThreadQueues*); void buffer_sedabits(int n_bits, ThreadQueues*); diff --git a/Protocols/DealerPrep.hpp b/Protocols/DealerPrep.hpp index d4a0a91dd..cc010dd71 100644 --- a/Protocols/DealerPrep.hpp +++ b/Protocols/DealerPrep.hpp @@ -45,6 +45,57 @@ void DealerPrep::buffer_triples() } } +template +void DealerPrep::buffer_inverses() +{ + buffer_inverses(T::invertible); +} + +template +template +void DealerPrep::buffer_inverses(false_type) +{ + throw not_implemented(); +} + +template +template +void DealerPrep::buffer_inverses(true_type) +{ + assert(this->proc); + auto& P = this->proc->P; + vector senders(P.num_players()); + senders.back() = true; + octetStreams os(P), to_receive(P); + if (this->proc->input.is_dealer()) + { + SeededPRNG G; + vector> shares(P.num_players() - 1); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + { + T tuple[2]; + while (tuple[0] == 0) + tuple[0] = G.get(); + tuple[1] = tuple[0].invert(); + for (auto& value : tuple) + { + make_share(shares.data(), typename T::clear(value), + P.num_players() - 1, 0, G); + for (int i = 1; i < P.num_players(); i++) + shares.at(i - 1).pack(os[i - 1]); + } + this->inverses.push_back({}); + } + P.send_receive_all(senders, os, to_receive); + } + else + { + P.send_receive_all(senders, os, to_receive); + for (int i = 0; i < OnlineOptions::singleton.batch_size; i++) + this->inverses.push_back(to_receive.back().get>().get()); + } +} + template void DealerPrep::buffer_bits() { diff --git a/Protocols/DealerShare.h b/Protocols/DealerShare.h index 38900ff37..e59e19494 100644 --- a/Protocols/DealerShare.h +++ b/Protocols/DealerShare.h @@ -13,12 +13,16 @@ template class DealerPrep; template class DealerInput; template class DealerMC; template class DirectDealerMC; +template class DealerMatrixPrep; +template class Hemi; namespace GC { class DealerSecret; } +template class Dealer; + template class DealerShare : public SemiShare { @@ -30,22 +34,26 @@ class DealerShare : public SemiShare typedef DealerMC MAC_Check; typedef DirectDealerMC Direct_MC; - typedef Beaver Protocol; + typedef Hemi Protocol; typedef DealerInput Input; typedef DealerPrep LivePrep; typedef ::PrivateOutput PrivateOutput; + typedef DealerMatrixPrep MatrixPrep; + typedef Dealer BasicProtocol; + static false_type dishonest_majority; const static bool needs_ot = false; + const static bool symmetric = false; static string type_short() { return "DD" + string(1, T::type_char()); } - static int threshold(int) + static bool real_shares(const Player& P) { - throw runtime_error("undefined threshold"); + return P.my_num() != P.num_players() - 1; } static This constant(const T& other, int my_num, diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index c0a269d1a..e5bb9e9e5 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -33,6 +33,7 @@ class FakeShare : public T, public ShareInterface static const bool has_trunc_pr = true; static const bool dishonest_majority = false; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index f43260ea1..0aa61bcba 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -13,22 +13,24 @@ * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public Semi +class Hemi : public T::BasicProtocol { - map, HemiMatrixPrep*> matrix_preps; + map, typename T::MatrixPrep*> matrix_preps; DataPositions matrix_usage; + MatrixMC mc; + ShareMatrix matrix_multiply(const ShareMatrix& A, const ShareMatrix& B, SubProcessor& processor); public: Hemi(Player& P) : - Semi(P) + T::BasicProtocol(P) { } ~Hemi(); - HemiMatrixPrep& get_matrix_prep(const array& dimensions, + typename T::MatrixPrep& get_matrix_prep(const array& dimensions, SubProcessor& processor); void matmulsm(SubProcessor& processor, CheckVector& source, diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index 1b3d8f5ba..1549e2cf4 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -21,12 +21,12 @@ Hemi::~Hemi() } template -HemiMatrixPrep& Hemi::get_matrix_prep(const array& dims, +typename T::MatrixPrep& Hemi::get_matrix_prep(const array& dims, SubProcessor& processor) { if (matrix_preps.find(dims) == matrix_preps.end()) matrix_preps.insert({dims, - new HemiMatrixPrep(dims[0], dims[1], dims[2], + new typename T::MatrixPrep(dims[0], dims[1], dims[2], dynamic_cast(processor.DataF), matrix_usage)}); return *matrix_preps.at(dims); @@ -52,22 +52,27 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int k = 0; k < dim[1]; k++) + if (not T::real_shares(processor.P)) { - for (int i = 0; i < dim[0]; i++) + matrix_multiply(A, B, processor); + return; + } + + for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { auto kk = Proc->get_Ci().at(dim[4] + k); auto ii = Proc->get_Ci().at(dim[3] + i); - A[{i, k}] = source.at(a + ii * dim[7] + kk); + A.entries.v.push_back(source.at(a + ii * dim[7] + kk)); } + for (int k = 0; k < dim[1]; k++) for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); auto ll = Proc->get_Ci().at(dim[5] + k); - B[{k, j}] = source.at(b + ll * dim[8] + jj); + B.entries.v.push_back(source.at(b + ll * dim[8] + jj)); } - } auto res = matrix_multiply(A, B, processor); @@ -94,13 +99,16 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[1] = min(max_inner, A.n_cols - i); subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); - MatrixMC mc; beaver.init(prep, mc); beaver.init_mul(); - beaver.prepare_mul(A.from(0, i, subdim.data()), - B.from(i, j, subdim.data() + 1)); - beaver.exchange(); - C.add_from_col(j, beaver.finalize_mul()); + bool for_real = T::real_shares(processor.P); + beaver.prepare_mul(A.from(0, i, subdim.data(), for_real), + B.from(i, j, subdim.data() + 1, for_real)); + if (for_real) + { + beaver.exchange(); + C.add_from_col(j, beaver.finalize_mul()); + } } } @@ -150,6 +158,15 @@ void Hemi::conv2ds(SubProcessor& processor, array dim({{1, weights_h * weights_w * n_channels_in, batch_size * output_h * output_w}}); ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); + if (not T::real_shares(processor.P)) + { + matrix_multiply(A, B, processor); + return; + } + + A.entries.init(); + B.entries.init(); + for (int i_batch = 0; i_batch < batch_size; i_batch ++) { size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; diff --git a/Protocols/HemiShare.h b/Protocols/HemiShare.h index 4a85cbe34..ddf7e186f 100644 --- a/Protocols/HemiShare.h +++ b/Protocols/HemiShare.h @@ -10,6 +10,7 @@ template class HemiPrep; template class Hemi; +template class HemiMatrixPrep; template class HemiShare : public SemiShare @@ -26,6 +27,9 @@ class HemiShare : public SemiShare typedef typename conditional, Beaver>::type Protocol; typedef HemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = true; static true_type triple_matmul; diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index 19d5e72d5..fccd2ef57 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -298,7 +298,8 @@ void TreeSum::start(vector& values, const Player& P) { // send from the root player os.reset_write_head(); - for (unsigned int i=0; i::~Direct_MAC_Check() { template void direct_add_openings(vector& values, const PlayerBase& P, vector& os) { - for (unsigned int i=0; i(); } template diff --git a/Protocols/MAC_Check_Base.h b/Protocols/MAC_Check_Base.h index e855214fd..1f745251e 100644 --- a/Protocols/MAC_Check_Base.h +++ b/Protocols/MAC_Check_Base.h @@ -13,6 +13,7 @@ using namespace std; #include "Tools/PointerVector.h" template class Preprocessing; +template class MatrixMC; /** * Abstract base class for opening protocols @@ -20,6 +21,8 @@ template class Preprocessing; template class MAC_Check_Base { + friend class MatrixMC; + protected: /* MAC Share */ typename T::mac_key_type::Scalar alphai; @@ -59,6 +62,7 @@ class MAC_Check_Base /// Get next opened value virtual typename T::clear finalize_open(); virtual typename T::open_type finalize_raw(); + array finalize_several(size_t n); /// Check whether all ``shares`` are ``value`` virtual void CheckFor(const typename T::open_type& value, const vector& shares, const Player& P); diff --git a/Protocols/MAC_Check_Base.hpp b/Protocols/MAC_Check_Base.hpp index 59c6c5dec..47528e006 100644 --- a/Protocols/MAC_Check_Base.hpp +++ b/Protocols/MAC_Check_Base.hpp @@ -70,6 +70,13 @@ typename T::open_type MAC_Check_Base::finalize_raw() return values.next(); } +template +array MAC_Check_Base::finalize_several(size_t n) +{ + assert(values.left() >= n); + return {{values.skip(0), values.skip(n)}}; +} + template void MAC_Check_Base::CheckFor(const typename T::open_type& value, const vector& shares, const Player& P) diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index e6f3a8a6a..7c94b5d81 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,8 +42,12 @@ class MaliciousRep3Share : public Rep3Share typedef GC::MaliciousRepSecret bit_type; + // indicate security relevance of field size + typedef T mac_key_type; + const static bool expensive = true; static const bool has_trunc_pr = false; + static const bool malicious = true; static string type_short() { diff --git a/Protocols/MaliciousRepMC.hpp b/Protocols/MaliciousRepMC.hpp index 17eec6f11..631ef7667 100644 --- a/Protocols/MaliciousRepMC.hpp +++ b/Protocols/MaliciousRepMC.hpp @@ -160,7 +160,7 @@ template void CommMaliciousRepMC::POpen_Begin(vector& values, const vector& S, const Player& P) { - assert(T::length == 2); + assert(T::vector_length == 2); (void)values; os.resize(2); for (auto& o : os) diff --git a/Protocols/MaliciousShamirShare.h b/Protocols/MaliciousShamirShare.h index ceedc9157..332996ddd 100644 --- a/Protocols/MaliciousShamirShare.h +++ b/Protocols/MaliciousShamirShare.h @@ -45,6 +45,8 @@ class MaliciousShamirShare : public ShamirShare typedef GC::MaliciousCcdSecret bit_type; #endif + static const bool malicious = true; + static string type_short() { return "M" + super::type_short(); diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index afb456621..786276974 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -122,6 +122,7 @@ class Rep3Share : public RepShare const static bool expensive = false; const static bool variable_players = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index 5e197804d..7befb7f4f 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -37,6 +37,8 @@ class Rep4Share : public RepShare typedef GC::Rep4Secret bit_type; + static const bool malicious = true; + static string type_short() { return "R4" + string(1, T::type_char()); diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index ba5b85c8e..48b014408 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -121,6 +121,8 @@ class ProtocolBase virtual void cisc(SubProcessor&, const Instruction&) { throw runtime_error("CISC instructions not implemented"); } + + virtual vector get_relevant_players(); }; /** @@ -146,7 +148,7 @@ class Replicated : public ReplicatedBase, public ProtocolBase static void assign(T& share, const typename T::clear& value, int my_num) { - assert(T::length == 2); + assert(T::vector_length == 2); share.assign_zero(); if (my_num < 2) share[my_num] = value; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 2d9eba572..f398da7fe 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -28,7 +28,7 @@ ProtocolBase::ProtocolBase() : template Replicated::Replicated(Player& P) : ReplicatedBase(P) { - assert(T::length == 2); + assert(T::vector_length == 2); } template @@ -152,6 +152,16 @@ T ProtocolBase::get_random() return res; } +template +vector ProtocolBase::get_relevant_players() +{ + vector res; + int n = dynamic_cast(*this).P.num_players(); + for (int i = 0; i < T::threshold(n) + 1; i++) + res.push_back(i); + return res; +} + template void Replicated::init_mul() { diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 9bb3c30a3..9e1498df0 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -71,7 +71,7 @@ class ReplicatedInput : public PrepLessInput ReplicatedInput(SubProcessor* proc, Player& P) : PrepLessInput(proc), proc(proc), P(P), protocol(P) { - assert(T::length == 2); + assert(T::vector_length == 2); expect.resize(P.num_players()); this->reset_all(P); } diff --git a/Protocols/ReplicatedMC.hpp b/Protocols/ReplicatedMC.hpp index e72c0d839..4d875a3b2 100644 --- a/Protocols/ReplicatedMC.hpp +++ b/Protocols/ReplicatedMC.hpp @@ -28,7 +28,7 @@ void ReplicatedMC::POpen_Begin(vector&, template void ReplicatedMC::prepare(const vector& S) { - assert(T::length == 2); + assert(T::vector_length == 2); o.reset_write_head(); to_send.reset_write_head(); to_send.reserve(S.size() * T::value_type::size()); diff --git a/Protocols/SecureShuffle.h b/Protocols/SecureShuffle.h new file mode 100644 index 000000000..a90c6e64f --- /dev/null +++ b/Protocols/SecureShuffle.h @@ -0,0 +1,53 @@ +/* + * SecureShuffle.h + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_H_ +#define PROTOCOLS_SECURESHUFFLE_H_ + +#include +using namespace std; + +template class SubProcessor; + +template +class SecureShuffle +{ + SubProcessor& proc; + vector to_shuffle; + vector> config; + vector tmp; + int unit_size; + + vector>>> shuffles; + size_t n_shuffle; + bool exact; + + void player_round(int config_player); + void generate(int config_player, int n_shuffle); + + void waksman(vector& a, int depth, int start); + void cond_swap(T& x, T& y, const T& b); + + void iter_waksman(bool reverse = false); + void waksman_round(int size, bool inwards, bool reverse); + + void pre(vector& a, size_t n, size_t input_base); + void post(vector& a, size_t n, size_t input_base); + +public: + SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc); + + SecureShuffle(SubProcessor& proc); + + int generate(int n_shuffle); + + void apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse); + + void del(int handle); +}; + +#endif /* PROTOCOLS_SECURESHUFFLE_H_ */ diff --git a/Protocols/SecureShuffle.hpp b/Protocols/SecureShuffle.hpp new file mode 100644 index 000000000..d2b0676ac --- /dev/null +++ b/Protocols/SecureShuffle.hpp @@ -0,0 +1,328 @@ +/* + * SecureShuffle.hpp + * + */ + +#ifndef PROTOCOLS_SECURESHUFFLE_HPP_ +#define PROTOCOLS_SECURESHUFFLE_HPP_ + +#include "SecureShuffle.h" +#include "Tools/Waksman.h" + +#include +#include + +template +SecureShuffle::SecureShuffle(SubProcessor& proc) : + proc(proc), unit_size(0), n_shuffle(0), exact(false) +{ +} + +template +SecureShuffle::SecureShuffle(vector& a, size_t n, int unit_size, + size_t output_base, size_t input_base, SubProcessor& proc) : + proc(proc), unit_size(unit_size) +{ + pre(a, n, input_base); + + for (auto i : proc.protocol.get_relevant_players()) + player_round(i); + + post(a, n, output_base); +} + +template +void SecureShuffle::apply(vector& a, size_t n, int unit_size, size_t output_base, + size_t input_base, int handle, bool reverse) +{ + this->unit_size = unit_size; + + pre(a, n, input_base); + + auto& shuffle = shuffles.at(handle); + assert(shuffle.size() == proc.protocol.get_relevant_players().size()); + + if (reverse) + for (auto it = shuffle.end(); it > shuffle.begin(); it--) + { + this->config = *(it - 1); + iter_waksman(reverse); + } + else + for (auto& config : shuffle) + { + this->config = config; + iter_waksman(reverse); + } + + post(a, n, output_base); +} + +template +void SecureShuffle::del(int handle) +{ + shuffles.at(handle).clear(); +} + +template +void SecureShuffle::pre(vector& a, size_t n, size_t input_base) +{ + n_shuffle = n / unit_size; + assert(unit_size * n_shuffle == n); + size_t n_shuffle_pow2 = (1u << int(ceil(log2(n_shuffle)))); + exact = (n_shuffle_pow2 == n_shuffle) or not T::malicious; + to_shuffle.clear(); + + if (exact) + { + to_shuffle.resize(n_shuffle_pow2 * unit_size); + for (size_t i = 0; i < n; i++) + to_shuffle[i] = a[input_base + i]; + } + else + { + // sorting power of two elements together with indicator bits + to_shuffle.resize((unit_size + 1) << int(ceil(log2(n_shuffle)))); + for (size_t i = 0; i < n_shuffle; i++) + { + for (int j = 0; j < unit_size; j++) + to_shuffle[i * (unit_size + 1) + j] = a[input_base + + i * unit_size + j]; + to_shuffle[i * (unit_size + 1) + unit_size] = T::constant(1, + proc.P.my_num(), proc.MC.get_alphai()); + } + this->unit_size++; + } +} + +template +void SecureShuffle::post(vector& a, size_t n, size_t output_base) +{ + if (exact) + for (size_t i = 0; i < n; i++) + a[output_base + i] = to_shuffle[i]; + else + { + auto& MC = proc.MC; + MC.init_open(proc.P); + int shuffle_unit_size = this->unit_size; + int unit_size = shuffle_unit_size - 1; + for (size_t i = 0; i < to_shuffle.size() / shuffle_unit_size; i++) + MC.prepare_open(to_shuffle.at((i + 1) * shuffle_unit_size - 1)); + MC.exchange(proc.P); + size_t i_shuffle = 0; + for (size_t i = 0; i < n_shuffle; i++) + { + auto bit = MC.finalize_open(); + if (bit == 1) + { + // only output real elements + for (int j = 0; j < unit_size; j++) + a.at(output_base + i_shuffle * unit_size + j) = + to_shuffle.at(i * shuffle_unit_size + j); + i_shuffle++; + } + } + if (i_shuffle != n_shuffle) + throw runtime_error("incorrect shuffle"); + } +} + +template +void SecureShuffle::player_round(int config_player) +{ + generate(config_player, n_shuffle); + iter_waksman(); +} + +template +int SecureShuffle::generate(int n_shuffle) +{ + int res = shuffles.size(); + shuffles.push_back({}); + auto& shuffle = shuffles.back(); + + for (auto i : proc.protocol.get_relevant_players()) + { + generate(i, n_shuffle); + shuffle.push_back(config); + } + + return res; +} + +template +void SecureShuffle::generate(int config_player, int n) +{ + auto& P = proc.P; + auto& input = proc.input; + input.reset_all(P); + int n_pow2 = 1 << int(ceil(log2(n))); + Waksman waksman(n_pow2); + + if (P.my_num() == config_player) + { + vector perm; + int shuffle_size = n; + for (int j = 0; j < n_pow2; j++) + perm.push_back(j); + SeededPRNG G; + for (int i = 0; i < shuffle_size; i++) + { + int j = G.get_uint(shuffle_size - i); + swap(perm[i], perm[i + j]); + } + + auto config_bits = waksman.configure(perm); + for (size_t i = 0; i < config_bits.size(); i++) + { + auto& x = config_bits[i]; + for (size_t j = 0; j < x.size(); j++) + if (waksman.matters(i, j)) + input.add_mine(int(x[j])); + else + assert(x[j] == 0); + } + } + else + for (size_t i = 0; i < waksman.n_bits(); i++) + input.add_other(config_player); + + input.exchange(); + config.clear(); + typename T::Protocol checker(P); + checker.init(proc.DataF, proc.MC); + checker.init_dotprod(); + auto one = T::constant(1, P.my_num(), proc.MC.get_alphai()); + for (size_t i = 0; i < waksman.n_rounds(); i++) + { + config.push_back({}); + for (int j = 0; j < n_pow2; j++) + { + if (waksman.matters(i, j)) + { + config.back().push_back(input.finalize(config_player)); + if (T::malicious) + checker.prepare_dotprod(config.back().back(), + one - config.back().back()); + } + else + config.back().push_back({}); + } + } + + if (T::malicious) + { + checker.next_dotprod(); + checker.exchange(); + assert( + typename T::clear( + proc.MC.open(checker.finalize_dotprod(waksman.n_bits()), + P)) == 0); + checker.check(); + } +} + +template +void SecureShuffle::waksman(vector& a, int depth, int start) +{ + int n = a.size(); + + if (n == 2) + { + cond_swap(a[0], a[1], config.at(depth).at(start)); + return; + } + + vector a0(n / 2), a1(n / 2); + for (int i = 0; i < n / 2; i++) + { + a0.at(i) = a.at(2 * i); + a1.at(i) = a.at(2 * i + 1); + + cond_swap(a0[i], a1[i], config.at(depth).at(i + start + n / 2)); + } + + waksman(a0, depth + 1, start); + waksman(a1, depth + 1, start + n / 2); + + for (int i = 0; i < n / 2; i++) + { + a.at(2 * i) = a0.at(i); + a.at(2 * i + 1) = a1.at(i); + cond_swap(a[2 * i], a[2 * i + 1], config.at(depth).at(i + start)); + } +} + +template +void SecureShuffle::cond_swap(T& x, T& y, const T& b) +{ + auto diff = proc.protocol.mul(x - y, b); + x -= diff; + y += diff; +} + +template +void SecureShuffle::iter_waksman(bool reverse) +{ + int n = to_shuffle.size() / unit_size; + + for (int depth = 0; depth < log2(n); depth++) + waksman_round(depth, true, reverse); + + for (int depth = log2(n) - 2; depth >= 0; depth--) + waksman_round(depth, false, reverse); +} + +template +void SecureShuffle::waksman_round(int depth, bool inwards, bool reverse) +{ + int n = to_shuffle.size() / unit_size; + assert((int) config.at(depth).size() == n); + int nblocks = 1 << depth; + int size = n / (2 * nblocks); + bool outwards = !inwards; + proc.protocol.init_mul(); + vector> indices; + indices.reserve(n / 2); + Waksman waksman(n); + for (int k = 0; k < n / 2; k++) + { + int j = k % size; + int i = k / size; + int base = 2 * i * size; + int in1 = base + j + j * inwards; + int in2 = in1 + inwards + size * outwards; + int out1 = base + j + j * outwards; + int out2 = out1 + outwards + size * inwards; + int i_bit = base + j + size * (outwards ^ reverse); + bool run = waksman.matters(depth, i_bit); + if (run) + { + for (int l = 0; l < unit_size; l++) + proc.protocol.prepare_mul(config.at(depth).at(i_bit), + to_shuffle.at(in1 * unit_size + l) + - to_shuffle.at(in2 * unit_size + l)); + } + indices.push_back({{in1, in2, out1, out2, run}}); + } + proc.protocol.exchange(); + tmp.resize(to_shuffle.size()); + for (int k = 0; k < n / 2; k++) + { + auto idx = indices.at(k); + for (int l = 0; l < unit_size; l++) + { + T diff; + if (idx[4]) + diff = proc.protocol.finalize_mul(); + tmp.at(idx[2] * unit_size + l) = to_shuffle.at( + idx[0] * unit_size + l) - diff; + tmp.at(idx[3] * unit_size + l) = to_shuffle.at( + idx[1] * unit_size + l) + diff; + } + } + swap(tmp, to_shuffle); +} + +#endif /* PROTOCOLS_SECURESHUFFLE_HPP_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index b306d5c3d..432b599bb 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -78,6 +78,7 @@ class SemiShare : public T, public ShareInterface const static bool variable_players = true; const static bool expensive = false; static const bool has_trunc_pr = true; + static const bool malicious = false; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/ShamirShare.h b/Protocols/ShamirShare.h index aea0bb97a..bf40cb287 100644 --- a/Protocols/ShamirShare.h +++ b/Protocols/ShamirShare.h @@ -49,6 +49,7 @@ class ShamirShare : public T, public ShareInterface const static bool dishonest_majority = false; const static bool variable_players = true; const static bool expensive = false; + const static bool malicious = true; static string type_short() { diff --git a/Protocols/Share.h b/Protocols/Share.h index e2a9f0bb5..9ca86cea7 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -56,6 +56,7 @@ class Share_ : public ShareInterface const static bool dishonest_majority = T::dishonest_majority; const static bool variable_players = T::variable_players; const static bool has_mac = true; + static const bool malicious = true; static int size() { return T::size() + V::size(); } diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index e5af8dddd..a8ef7a224 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -40,12 +40,17 @@ class ShareInterface static const bool has_trunc_pr = false; static const bool has_split = false; static const bool has_mac = false; + static const bool malicious = false; static const false_type triple_matmul; + const static bool symmetric = true; + static const int default_length = 1; - static string type_short() { return "undef"; } + static string type_short() { throw runtime_error("don't call this"); } + + static bool real_shares(const Player&) { return true; } template static void split(vector, vector, int, T*, int, @@ -63,6 +68,8 @@ class ShareInterface template static void generate_mac_key(T&, U&) {} + + static int threshold(int) { throw runtime_error("undefined threshold"); } }; #endif /* PROTOCOLS_SHAREINTERFACE_H_ */ diff --git a/Protocols/ShareMatrix.h b/Protocols/ShareMatrix.h index 7f84213e6..b31aa7085 100644 --- a/Protocols/ShareMatrix.h +++ b/Protocols/ShareMatrix.h @@ -14,6 +14,124 @@ using namespace std; template class MatrixMC; +template +class NonInitVector +{ + template friend class NonInitVector; + + size_t size_; +public: + AddableVector v; + + NonInitVector(size_t size) : + size_(size) + { + v.reserve(size); + } + + template + NonInitVector(const NonInitVector& other) : + size_(other.size()), v(other.v) + { + } + + size_t size() const + { + return size_; + } + + void init() + { + v.resize(size_); + } + + void check() const + { +#ifdef DEBUG_MATRIX + assert(not v.empty()); +#endif + } + + typename vector::iterator begin() + { + check(); + return v.begin(); + } + + typename vector::iterator end() + { + check(); + return v.end(); + } + + T& at(size_t index) + { + check(); + return v.at(index); + } + + const T& at(size_t index) const + { +#ifdef DEBUG_MATRIX + assert(index < size()); +#endif + return (*this)[index]; + } + + T& operator[](size_t index) + { + check(); + return v[index]; + } + + const T& operator[](size_t index) const + { + check(); + return v[index]; + } + + NonInitVector operator-(const NonInitVector& other) const + { + assert(size() == other.size()); + NonInitVector res(size()); + if (other.v.empty()) + return *this; + else if (v.empty()) + { + res.init(); + res.v = res.v - other.v; + } + else + res.v = v - other.v; + return res; + } + + NonInitVector& operator+=(const NonInitVector& other) + { + assert(size() == other.size()); + if (not other.v.empty()) + { + if (v.empty()) + *this = other; + else + v += other.v; + } + return *this; + } + + bool operator!=(const NonInitVector& other) const + { + return v != other.v; + } + + void randomize(PRNG& G) + { + v.clear(); + for (size_t i = 0; i < size(); i++) + v.push_back(G.get()); + } +}; + template class ValueMatrix : public ValueInterface { @@ -21,7 +139,7 @@ class ValueMatrix : public ValueInterface public: int n_rows, n_cols; - AddableVector entries; + NonInitVector entries; static DataFieldType field_type() { @@ -48,15 +166,19 @@ class ValueMatrix : public ValueInterface T& operator[](const pair& indices) { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } const T& operator[](const pair& indices) const { +#ifdef DEBUG_MATRIX assert(indices.first < n_rows); assert(indices.second < n_cols); +#endif return entries.at(indices.first * n_cols + indices.second); } @@ -80,6 +202,9 @@ class ValueMatrix : public ValueInterface { assert(n_cols == other.n_rows); This res(n_rows, other.n_cols); + if (entries.v.empty() or other.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < n_rows; i++) for (int j = 0; j < other.n_cols; j++) for (int k = 0; k < n_cols; k++) @@ -103,9 +228,9 @@ class ValueMatrix : public ValueInterface ValueMatrix transpose() const { ValueMatrix res(this->n_cols, this->n_rows); - for (int i = 0; i < this->n_rows; i++) - for (int j = 0; j < this->n_cols; j++) - res[{j, i}] = (*this)[{i, j}]; + for (int j = 0; j < this->n_cols; j++) + for (int i = 0; i < this->n_rows; i++) + res.entries.v.push_back((*this)[{i, j}]); return res; } @@ -139,7 +264,7 @@ class ShareMatrix : public ValueMatrix, public ShareInterface { This res(other.n_rows, other.n_cols); for (size_t i = 0; i < other.entries.size(); i++) - res.entries[i] = T::constant(other.entries[i], my_num, key); + res.entries.v.push_back(T::constant(other.entries[i], my_num, key)); res.check(); return res; } @@ -167,24 +292,29 @@ class ShareMatrix : public ValueMatrix, public ShareInterface ShareMatrix from_col(int start, int size) const { ShareMatrix res(this->n_rows, min(size, this->n_cols - start)); + res.entries.clear(); for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{i, start + j}]; + res.entries.v.push_back((*this)[{i, start + j}]); return res; } - ShareMatrix from(int start_row, int start_col, int* sizes) const + ShareMatrix from(int start_row, int start_col, int* sizes, bool for_real = + true) const { ShareMatrix res(min(sizes[0], this->n_rows - start_row), min(sizes[1], this->n_cols - start_col)); + if (not for_real) + return res; for (int i = 0; i < res.n_rows; i++) for (int j = 0; j < res.n_cols; j++) - res[{i, j}] = (*this)[{start_row + i, start_col + j}]; + res.entries.v.push_back((*this)[{start_row + i, start_col + j}]); return res; } void add_from_col(int start, const ShareMatrix& other) { + this->entries.init(); for (int i = 0; i < this->n_rows; i++) for (int j = 0; j < other.n_cols; j++) (*this)[{i, start + j}] += other[{i, j}]; @@ -197,6 +327,9 @@ ShareMatrix operator*(const ValueMatrix& a, { assert(a.n_cols == b.n_rows); ShareMatrix res(a.n_rows, b.n_cols); + if (a.entries.v.empty() or b.entries.v.empty()) + return res; + res.entries.init(); for (int i = 0; i < a.n_rows; i++) for (int j = 0; j < b.n_cols; j++) for (int k = 0; k < a.n_cols; k++) @@ -208,9 +341,22 @@ ShareMatrix operator*(const ValueMatrix& a, template class MatrixMC : public MAC_Check_Base> { - typename T::MAC_Check inner; + typename T::MAC_Check& inner; public: + MatrixMC() : + inner( + *(OnlineOptions::singleton.direct ? + new typename T::Direct_MC : + new typename T::MAC_Check)) + { + } + + ~MatrixMC() + { + delete &inner; + } + void exchange(const Player& P) { inner.init_open(P); @@ -224,8 +370,15 @@ class MatrixMC : public MAC_Check_Base> for (auto& share : this->secrets) { this->values.push_back({share.n_rows, share.n_cols}); - for (auto& entry : this->values.back().entries) - entry = inner.finalize_open(); + if (share.entries.v.empty()) + for (size_t i = 0; i < share.entries.size(); i++) + inner.finalize_open(); + else + { + auto range = inner.finalize_several(share.entries.size()); + auto& v = this->values.back().entries.v; + v.insert(v.begin(), range[0], range[1]); + } } } }; diff --git a/Protocols/TemiShare.h b/Protocols/TemiShare.h index f4f37dcd6..049881ffe 100644 --- a/Protocols/TemiShare.h +++ b/Protocols/TemiShare.h @@ -25,6 +25,9 @@ class TemiShare : public HemiShare typedef typename conditional, Beaver>::type Protocol; typedef TemiPrep LivePrep; + typedef HemiMatrixPrep MatrixPrep; + typedef Semi BasicProtocol; + static const bool needs_ot = false; static const bool local_mul = false; diff --git a/Protocols/fake-stuff.hpp b/Protocols/fake-stuff.hpp index bae415c4a..63b058e08 100644 --- a/Protocols/fake-stuff.hpp +++ b/Protocols/fake-stuff.hpp @@ -130,7 +130,6 @@ template void make_share(DealerShare* Sa, const T& a, int N, const U&, PRNG& G) { make_share((SemiShare*) Sa, a, N - 1, U(), G); - Sa[N - 1] = {}; } template @@ -273,6 +272,11 @@ inline string mac_filename(string directory, int playerno) + to_string(playerno); } +template <> +inline void write_mac_key(const string&, int, int, GC::NoValue) +{ +} + template void write_mac_key(const string& directory, int i, int nplayers, U key) { @@ -301,6 +305,11 @@ void read_mac_key(const string& directory, const Names& N, T& key) read_mac_key(directory, N.my_num(), N.num_players(), key); } +template <> +inline void read_mac_key(const string&, int, int, GC::NoValue&) +{ +} + template void read_mac_key(const string& directory, int player_num, int nplayers, U& key) { @@ -367,7 +376,7 @@ typename T::mac_key_type read_generate_write_mac_key(Player& P, } template -void read_global_mac_key(const string& directory, int nparties, U& key, false_type) +void read_global_mac_key(const string& directory, int nparties, U& key) { U pp; key.assign_zero(); @@ -383,15 +392,9 @@ void read_global_mac_key(const string& directory, int nparties, U& key, false_ty cout << "Final Keys : " << key << endl; } -template -void read_global_mac_key(const string&, int, U&, true_type) -{ -} - -template -void read_global_mac_key(const string& directory, int nparties, U& key) +template <> +inline void read_global_mac_key(const string&, int, GC::NoValue&) { - read_global_mac_key(directory, nparties, key, is_same()); } template @@ -579,14 +582,14 @@ void plain_edabits(vector& as, as.resize(max_size); bs.clear(); bs.resize(length); - bigint value; + Z2 value; for (int j = 0; j < max_size; j++) { if (not zero) - G.get_bigint(value, length, true); + value.randomize_part(G, length); as[j] = value; for (int k = 0; k < length; k++) - bs[k] ^= BitVec(bigint((value >> k) & 1).get_si()) << j; + bs[k] ^= BitVec(value.get_bit(k)) << j; } } diff --git a/README.md b/README.md index a3f6741fd..cd0e9781c 100644 --- a/README.md +++ b/README.md @@ -101,8 +101,9 @@ The following table lists all protocols that are fully supported. | Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Temi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | -| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | +| Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep3 / PS / SY](#honest-majority) | [Rep3 / CCD / PS](#honest-majority) | [BMR](#bmr) | | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | +| Malicious, honest supermajority | [Rep4](#honest-majority) | [Rep4](#honest-majority) | [Rep4](#honest-majority) | N/A | | Semi-honest, dealer | [Dealer](#dealer-model) | [Dealer](#dealer-model) | [Dealer](#dealer-model) | N/A | Modulo prime and modulo 2^k are the two settings that allow @@ -280,6 +281,8 @@ compute the preprocessing time for a particular computation. - Python 3.5 or later - NTL library for homomorphic encryption (optional; tested with NTL 10.5) - If using macOS, Sierra or later + - Windows/VirtualBox: see [this + issue](https://github.com/data61/MP-SPDZ/issues/557) for a discussion #### Compilation diff --git a/Tools/Exceptions.cpp b/Tools/Exceptions.cpp index c7f8c371f..ec39b7728 100644 --- a/Tools/Exceptions.cpp +++ b/Tools/Exceptions.cpp @@ -84,7 +84,9 @@ not_enough_to_buffer::not_enough_to_buffer(const string& type, const string& fil { } -gf2n_not_supported::gf2n_not_supported(int n) : - runtime_error("GF(2^" + to_string(n) + ") not supported") +gf2n_not_supported::gf2n_not_supported(int n, string options) : + runtime_error( + "GF(2^" + to_string(n) + ") not supported" + + (options.empty() ? "" : ", options are " + options)) { } diff --git a/Tools/Exceptions.h b/Tools/Exceptions.h index bb347c6a8..a3ca3a5d0 100644 --- a/Tools/Exceptions.h +++ b/Tools/Exceptions.h @@ -281,7 +281,7 @@ class insufficient_memory : public runtime_error class gf2n_not_supported : public runtime_error { public: - gf2n_not_supported(int n); + gf2n_not_supported(int n, string options = ""); }; #endif diff --git a/Tools/PointerVector.h b/Tools/PointerVector.h index 32d1b46ee..404c4ee92 100644 --- a/Tools/PointerVector.h +++ b/Tools/PointerVector.h @@ -30,6 +30,15 @@ class PointerVector : public CheckVector { return (*this)[i++]; } + T* skip(size_t n) + { + i += n; + return &(*this)[i]; + } + size_t left() + { + return this->size() - i; + } }; #endif /* TOOLS_POINTERVECTOR_H_ */ diff --git a/Tools/Waksman.cpp b/Tools/Waksman.cpp new file mode 100644 index 000000000..a54b7766d --- /dev/null +++ b/Tools/Waksman.cpp @@ -0,0 +1,91 @@ +/* + * Waksman.cpp + * + */ + +#include "Waksman.h" + +#include +#include +#include + +template +void append(vector& x, const vector& y) +{ + x.insert(x.end(), y.begin(), y.end()); +} + +vector > Waksman::configure(const vector& perm) +{ + int n = perm.size(); + assert(n > 1); + + if (n == 2) + return {{perm[0] == 1, perm[0] == 1}}; + + vector I(n / 2); + vector O(n / 2, -1); + vector p0(n / 2, -1), p1(n / 2, -1), inv_perm(n); + + for (int i = 0; i < n; i++) + inv_perm[perm[i]] = i; + + while (true) + { + auto it = find(O.begin(), O.end(), -1); + if (it == O.end()) + break; + int j = 2 * (it - O.begin()); + O.at(j / 2) = 0; + int j0 = j; + + while (true) + { + int i = inv_perm.at(j); + p0.at(i / 2) = j / 2; + I.at(i / 2) = i % 2; + O.at(j / 2) = j % 2; + if (i % 2 == 1) + i--; + else + i++; + j = perm.at(i); + if (j % 2 == 1) + j--; + else + j++; + p1.at(i / 2) = perm.at(i) / 2; + if (j == j0) + break; + } + + if ((find(p1.begin(), p1.end(), -1) == p1.end()) + and (find(p0.begin(), p0.end(), -1) == p0.end())) + break; + } + + auto p0_config = configure(p0); + auto p1_config = configure(p1); + + vector> res; + res.push_back(I); + for (auto& x : O) + res.back().push_back(x); + + assert(p0_config.size() == p1_config.size()); + + for (size_t i = 0; i < p0_config.size(); i++) + { + res.push_back(p0_config.at(i)); + append(res.back(), p1_config.at(i)); + } + + assert(res.size() == Waksman(perm.size()).n_rounds()); + return res; +} + +Waksman::Waksman(int n_elements) : + n_elements(n_elements), nr(log2(n_elements)) +{ + assert(n_elements == (1 << nr)); +} diff --git a/Tools/Waksman.h b/Tools/Waksman.h new file mode 100644 index 000000000..521e990f9 --- /dev/null +++ b/Tools/Waksman.h @@ -0,0 +1,39 @@ +/* + * Waksman.h + * + */ + +#ifndef TOOLS_WAKSMAN_H_ +#define TOOLS_WAKSMAN_H_ + +#include +using namespace std; + +class Waksman +{ + int n_elements; + int nr; + +public: + static vector> configure(const vector& perm); + + Waksman(int n_elements); + + size_t n_rounds() const + { + return nr; + } + + bool matters(int i, int j) const + { + int block = n_elements >> i; + return block == 2 or j % block != block / 2; + } + + size_t n_bits() const + { + return nr * n_elements - (1 << (nr - 1)) + 1; + } +}; + +#endif /* TOOLS_WAKSMAN_H_ */ diff --git a/Utils/he-example.cpp b/Utils/he-example.cpp new file mode 100644 index 000000000..179028a5b --- /dev/null +++ b/Utils/he-example.cpp @@ -0,0 +1,97 @@ +/* + * he-example.cpp + * + */ + +#include "FHE/FHE_Params.h" +#include "FHE/NTL-Subs.h" +#include "FHE/FHE_Keys.h" +#include "FHE/Plaintext.h" + +void first_phase(string filename, int n_mults, int circuit_sec); +void second_phase(string filename); + +int main() +{ + for (int n_mults = 0; n_mults < 2; n_mults++) + for (int sec = 0; sec <= 120; sec += 40) + { + string filename = "mp-spdz-he"; + first_phase(filename, n_mults, sec); + second_phase(filename); + } +} + +void first_phase(string filename, int n_mults, int circuit_sec) +{ + // specify number of multiplications (at most one) and function privacy parameter + // increase the latter to accommodate more operations + FHE_Params params(n_mults, circuit_sec); + + // generate parameters for computation modulo a 32-bit prime + params.basic_generation_mod_prime(32); + + // find computation modulus (depends on parameter generation) + cout << "computation modulo " << params.get_plaintext_modulus() << endl; + + // generate key pair + FHE_KeyPair pair(params); + pair.generate(); + + Plaintext_mod_prime plaintext(params); + + // set first two plaintext slots + plaintext.set_element(0, 4); + plaintext.set_element(1, -1); + + // encrypt + Ciphertext ciphertext = pair.pk.encrypt(plaintext); + + // store for second phase + octetStream os; + params.pack(os); + pair.pk.pack(os); + ciphertext.pack(os); + plaintext.pack(os); + pair.sk.pack(os); + ofstream out(filename); + os.output(out); +} + +void second_phase(string filename) +{ + // read from file + ifstream in(filename); + octetStream os; + os.input(in); + FHE_Params params; + FHE_PK pk(params); + FHE_SK sk(params); + Plaintext_mod_prime plaintext(params); + Ciphertext ciphertext(params); + + // parameter must be set correctly first + params.unpack(os); + pk.unpack(os); + ciphertext.unpack(os); + plaintext.unpack(os); + + if (params.n_mults() == 0) + // public-private multiplication is always available + ciphertext *= plaintext; + else + // private-private multiplication only with matching parameters + ciphertext = ciphertext.mul(pk, ciphertext); + + // re-randomize for circuit privacy + ciphertext.rerandomize(pk); + + // read secret key and decrypt + sk.unpack(os); + plaintext = sk.decrypt(ciphertext); + + cout << "should be 16: " << plaintext.element(0) << endl; + cout << "should be 1: " << plaintext.element(1) << endl; + assert(plaintext.element(0) == 16); + assert(plaintext.element(1) == 1); +} diff --git a/doc/Doxyfile b/doc/Doxyfile index 9820ba50c..5f1143e31 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h ../Math/gfp.h ../Math/gfpvar.h ../Math/Z2k.h ../FHE/Ciphertext.h ../FHE/FHE_Keys.h ../FHE/FHE_Params.h ../FHE/Plaintext.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/homomorphic-encryption.rst b/doc/homomorphic-encryption.rst new file mode 100644 index 000000000..95c922fd8 --- /dev/null +++ b/doc/homomorphic-encryption.rst @@ -0,0 +1,31 @@ +Homomorphic Encryption +---------------------- + +MP-SPDZ uses BGV encryption for triple generation in a number of +protocols. This involves zero-knowledge proofs in some protocols and +considerations about function privacy in all of them. The interface +described below allows directly accessing the basic cryptographic +operations in contexts where these considerations are not relevant. +See ``Utils/he-example.cpp`` for some example code. + + +Reference +~~~~~~~~~ + +.. doxygenclass:: FHE_Params + :members: + +.. doxygenclass:: FHE_KeyPair + :members: + +.. doxygenclass:: FHE_SK + :members: + +.. doxygenclass:: FHE_PK + :members: + +.. doxygenclass:: Plaintext + :members: + +.. doxygenclass:: Ciphertext + :members: diff --git a/doc/index.rst b/doc/index.rst index d2a2c4dcd..59caa58de 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -175,6 +175,7 @@ Reference non-linear preprocessing add-protocol + homomorphic-encryption troubleshooting diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 268084806..8b02fa3ed 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -148,12 +148,14 @@ AVX/AVX2 instructions are deactivated (see e.g. `here `_), which causes a dramatic performance loss. Deactivate Hyper-V/Hypervisor using:: + bcdedit /set hypervisorlaunchtype off DISM /Online /Disable-Feature:Microsoft-Hyper-V Performance can be further increased when compiling MP-SPDZ yourself: :: + sudo apt-get update sudo apt-get install automake build-essential git libboost-dev libboost-thread-dev libntl-dev libsodium-dev libssl-dev libtool m4 python3 texinfo yasm git clone https://github.com/data61/MP-SPDZ.git