diff --git a/CHANGELOG.md b/CHANGELOG.md index 6416d68d2..9c2915120 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.2.2 (Jan 21, 2020) +## 0.2.3 (Feb 23, 2021) + +- Distributed key generation for homomorphic encryption with active security similar to [Rotaru et al.](https://eprint.iacr.org/2019/1300) +- Homomorphic encryption parameters more similar to SCALE-MAMBA +- Fixed security bug: all-zero secret keys in homomorphic encryption +- Fixed security bug: missing check in binary Rep4 +- Fixed security bug: insufficient "blaming" (covert security) in CowGear and HighGear due to low default security parameter + +## 0.2.2 (Jan 21, 2021) - Infrastructure for random element generation - Programs generating as much preprocessing data as required by a particular high-level program diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index c9fd4fad1..dba48edcb 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -490,6 +490,9 @@ class bitb(NonVectorInstruction): code = opcodes['BITB'] arg_format = ['sbw'] + def add_usage(self, req_node): + req_node.increment(('bit', 'bit'), 1) + class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable): """ Reveal secret bit register vectors and copy result to clear bit register vectors. @@ -519,6 +522,10 @@ class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): arg_format = tools.cycle(['p','int','int','sbw']) is_vec = lambda self: True + def add_usage(self, req_node): + for i in range(0, len(self.args), 4): + req_node.increment(('bit', 'input', self.args[0]), self.args[1]) + class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, base.Mergeable): """ Copy private input to secret bit registers bit by bit. The input is @@ -538,17 +545,26 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction, def __init__(self, *args, **kwargs): self.arg_format = [] + for x in self.get_arg_tuples(args): + self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (x[0] - 3) + super(inputbvec, self).__init__(*args, **kwargs) + + @staticmethod + def get_arg_tuples(args): i = 0 while i < len(args): - self.arg_format += ['int', 'int', 'p'] + ['sbw'] * (args[i] - 3) + yield args[i:i+args[i]] i += args[i] assert i == len(args) - super(inputbvec, self).__init__(*args, **kwargs) def merge(self, other): self.args += other.args self.arg_format += other.arg_format + def add_usage(self, req_node): + for x in self.get_arg_tuples(self.args): + req_node.increment(('bit', 'input', x[2]), x[0] - 3) + class print_regb(base.VectorInstruction, base.IOInstruction): """ Debug output of clear bit register. diff --git a/Compiler/allocator.py b/Compiler/allocator.py index f10328f74..3553532db 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -371,6 +371,14 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind): # hack warned_about_mem.append(True) + def strict_mem_access(n, last_this_kind, last_other_kind): + if last_other_kind and last_this_kind and \ + last_other_kind[-1] > last_this_kind[-1]: + last_this_kind[:] = [] + last_this_kind.append(n) + for i in last_other_kind: + add_edge(i, n) + def keep_order(instr, n, t, arg_index=None): if arg_index is None: player = None @@ -444,26 +452,17 @@ def keep_merged_order(instr, n, t): if isinstance(instr, ReadMemoryInstruction): if options.preserve_mem_order: - if last_mem_write and last_mem_read and last_mem_write[-1] > last_mem_read[-1]: - last_mem_read[:] = [] - last_mem_read.append(n) - for i in last_mem_write: - add_edge(i, n) + strict_mem_access(n, last_mem_read, last_mem_write) else: mem_access(n, instr, last_mem_read_of, last_mem_write_of) elif isinstance(instr, WriteMemoryInstruction): if options.preserve_mem_order: - if last_mem_write and last_mem_read and last_mem_write[-1] < last_mem_read[-1]: - last_mem_write[:] = [] - last_mem_write.append(n) - for i in last_mem_read: - add_edge(i, n) + strict_mem_access(n, last_mem_write, last_mem_read) else: mem_access(n, instr, last_mem_write_of, last_mem_read_of) elif isinstance(instr, matmulsm): if options.preserve_mem_order: - for i in last_mem_write: - add_edge(i, n) + strict_mem_access(n, last_mem_read, last_mem_write) else: for i in last_mem_write_of.values(): for j in i: diff --git a/Compiler/comparison.py b/Compiler/comparison.py index d330e3614..7dfa3673a 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -452,6 +452,11 @@ def CarryOutRaw(a, b, c=0): assert len(a) == len(b) k = len(a) from . import types + if program.linear_rounds(): + carry = 0 + for (ai, bi) in zip(a, b): + carry = bi.carry_out(ai, carry) + return carry d = [program.curr_block.new_reg('s') for i in range(k)] s = [program.curr_block.new_reg('s') for i in range(3)] for i in range(k): diff --git a/Compiler/ml.py b/Compiler/ml.py index 735cb781a..1ae4bd4db 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -627,15 +627,14 @@ class Dense(DenseBase): :param d_out: output dimension """ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): - self.activation = activation if activation == 'id': - self.f = lambda x: x + self.activation_layer = None elif activation == 'relu': - self.f = relu - self.f_prime = relu_prime - elif activation == 'sigmoid': - self.f = sigmoid - self.f_prime = sigmoid_prime + self.activation_layer = Relu([N, d, d_out]) + elif activation == 'square': + self.activation_layer = Square([N, d, d_out]) + else: + raise CompilerError('activation not supported: %s', activation) self.N = N self.d_in = d_in @@ -652,10 +651,16 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.nabla_W = sfix.Matrix(d_in, d_out) self.nabla_b = sfix.Array(d_out) - self.f_input = MultiArray([N, d, d_out], sfix) - self.debug = debug + l = self.activation_layer + if l: + self.f_input = l.X + l.Y = self.Y + l.nabla_Y = self.nabla_Y + else: + self.f_input = self.Y + def reset(self): d_in = self.d_in d_out = self.d_out @@ -699,10 +704,8 @@ def _(i): def forward(self, batch=None): self.compute_f_input(batch=batch) - @multithread(self.n_threads, len(batch), 128) - def _(base, size): - self.Y.assign_part_vector(self.f( - self.f_input.get_part_vector(base, size)), base) + if self.activation_layer: + self.activation_layer.forward(batch) if self.debug: limit = self.debug @for_range_opt(len(batch)) @@ -731,26 +734,11 @@ def backward(self, compute_nabla_X=True, batch=None): nabla_W = self.nabla_W nabla_b = self.nabla_b - if self.activation == 'id': - f_schur_Y = nabla_Y + if self.activation_layer: + self.activation_layer.backward(batch) + f_schur_Y = self.activation_layer.nabla_X else: - f_prime_bit = MultiArray([N, d, d_out], sint) - f_schur_Y = MultiArray([N, d, d_out], sfix) - - @multithread(self.n_threads, f_prime_bit.total_size()) - def _(base, size): - f_prime_bit.assign_vector( - self.f_prime(self.f_input.get_vector(base, size)), base) - - progress('f prime') - - @multithread(self.n_threads, f_prime_bit.total_size()) - def _(base, size): - f_schur_Y.assign_vector(nabla_Y.get_vector(base, size) * - f_prime_bit.get_vector(base, size), - base) - - progress('f prime schur Y') + f_schur_Y = nabla_Y if compute_nabla_X: @multithread(self.n_threads, N) @@ -841,35 +829,55 @@ def _(k): def backward(self): self.nabla_X = self.nabla_Y.schur(self.B) -class Relu(NoVariableLayer): - """ Fixed-point ReLU layer. - - :param shape: input/output shape (tuple/list of int) - """ +class ElementWiseLayer(NoVariableLayer): def __init__(self, shape, inputs=None): self.X = Tensor(shape, sfix) self.Y = Tensor(shape, sfix) + self.nabla_X = Tensor(shape, sfix) + self.nabla_Y = Tensor(shape, sfix) self.inputs = inputs def forward(self, batch=[0]): - assert len(batch) == 1 - @multithread(self.n_threads, self.X[batch[0]].total_size()) + @multithread(self.n_threads, len(batch), 128) def _(base, size): - tmp = relu(self.X[batch[0]].get_vector(base, size)) - self.Y[batch[0]].assign_vector(tmp, base) + self.Y.assign_part_vector(self.f( + self.X.get_part_vector(base, size)), base) -class Square(NoVariableLayer): - """ Fixed-point square layer. + def backward(self, batch): + f_prime_bit = MultiArray(self.X.sizes, self.prime_type) + + @multithread(self.n_threads, f_prime_bit.total_size()) + def _(base, size): + f_prime_bit.assign_vector( + self.f_prime(self.X.get_vector(base, size)), base) + + progress('f prime') + + @multithread(self.n_threads, f_prime_bit.total_size()) + def _(base, size): + self.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size) * + f_prime_bit.get_vector(base, size), + base) + + progress('f prime schur Y') + +class Relu(ElementWiseLayer): + """ Fixed-point ReLU layer. :param shape: input/output shape (tuple/list of int) """ - def __init__(self, shape): - self.X = MultiArray(shape, sfix) - self.Y = MultiArray(shape, sfix) + f = staticmethod(relu) + f_prime = staticmethod(relu_prime) + prime_type = sint - def forward(self, batch=[0]): - assert len(batch) == 1 - self.Y.assign_vector(self.X.get_part_vector(batch[0]) ** 2) +class Square(ElementWiseLayer): + """ Fixed-point square layer. + + :param shape: input/output shape (tuple/list of int) + """ + f = staticmethod(lambda x: x ** 2) + f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x) + prime_type = sfix class MaxPool(NoVariableLayer): """ Fixed-point MaxPool layer. diff --git a/Compiler/program.py b/Compiler/program.py index 004f4dc26..68fea852b 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -150,6 +150,7 @@ def __init__(self, args, options=defaults): self.use_split(int(options.split)) self._square = False self._always_raw = False + self._linear_rounds = False self.warn_about_mem = [True] Program.prog = self from . import instructions_base, instructions, types, comparison @@ -461,6 +462,12 @@ def always_raw(self, change=None): else: self._always_raw = change + def linear_rounds(self, change=None): + if change is None: + return self._linear_rounds + else: + self._linear_rounds = change + def options_from_args(self): """ Set a number of options from the command-line arguments. """ if 'trunc_pr' in self.args: @@ -475,6 +482,8 @@ def options_from_args(self): self.always_raw(True) if 'edabit' in self.args: self.use_edabit(True) + if 'linear_rounds' in self.args: + self.linear_rounds(True) def disable_memory_warnings(self): self.warn_about_mem.append(False) @@ -855,7 +864,9 @@ def write_bytes(self, filename=None): filename = self.program.programs_dir + '/Bytecode/' + filename print('Writing to', filename) f = open(filename, 'wb') - f.write(self.get_bytes()) + for i in self._get_instructions(): + if i is not None: + f.write(i.get_bytes()) f.close() def new_reg(self, reg_type, size=None): diff --git a/Compiler/types.py b/Compiler/types.py index c3e495b4d..461f3886e 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -267,7 +267,7 @@ def __mul__(self, other): def __pow__(self, exp): """ Exponentation through square-and-multiply. - :param exp: compile-time (int) """ + :param exp: any type allowing bit decomposition """ if isinstance(exp, int) and exp >= 0: if exp == 0: return self.__class__(1) @@ -425,6 +425,10 @@ def half_adder(self, other): :rtype: depending on inputs (secret if any of them is) """ return self ^ other, self & other + def carry_out(self, a, b): + s = a ^ b + return a ^ (s & (self ^ a)) + class _gf2n(_bit): """ :math:`\mathrm{GF}(2^n)` functionality. """ @@ -5247,6 +5251,7 @@ def reveal(self): bit_decompose = lambda self,*args,**kwargs: self.read().bit_decompose(*args, **kwargs) if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) + bit_and = lambda self,other: self.read().bit_and(other) def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: diff --git a/FHE/DiscreteGauss.cpp b/FHE/DiscreteGauss.cpp index 924e9d0be..415ef25ed 100644 --- a/FHE/DiscreteGauss.cpp +++ b/FHE/DiscreteGauss.cpp @@ -35,13 +35,14 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const void RandomVectors::set(int nn,int hh,double R) { n=nn; - if (hh > 0) - h=hh; + h=hh; DG.set(R); - assert(h > 0); } - +void RandomVectors::set_n(int nn) +{ + n = nn; +} vector RandomVectors::sample_Gauss(PRNG& G, int stretch) const { @@ -54,8 +55,7 @@ vector RandomVectors::sample_Gauss(PRNG& G, int stretch) const vector RandomVectors::sample_Hwt(PRNG& G) const { - assert(h > 0); - if (h>n/2) { return sample_Gauss(G); } + if (h > n/2 or h <= 0) { return sample_Gauss(G); } vector ans(n); for (int i=0; i sample_Gauss(PRNG& G, int stretch = 1) const; diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp index baed86eec..f15145250 100644 --- a/FHE/FFT.cpp +++ b/FHE/FFT.cpp @@ -191,6 +191,13 @@ void FFT2(vector& a, int N, const modp& alpha, const Zp_Data& PrD) } +void FFT_non_power_of_two(vector& res, const vector& input, const FFT_Data& FFTD) +{ + vector tmp(FFTD.m()); + BFFT(tmp, input, FFTD); + for (int i = 0; i < (FFTD).phi_m(); i++) + res[i] = tmp[(FFTD).p(i)]; +} void BFFT(vector& ans,const vector& a,const FFT_Data& FFTD,bool forward) { diff --git a/FHE/FFT.h b/FHE/FFT.h index 70a99e2c0..b0935d65e 100644 --- a/FHE/FFT.h +++ b/FHE/FFT.h @@ -50,6 +50,8 @@ void FFT_Iter2(vector& a,int N,const modp& theta,const Zp_Data& PrD); void BFFT(vector& ans,const vector& a,const FFT_Data& FFTD,bool forward=true); +void FFT_non_power_of_two(vector& res, const vector& input, + const FFT_Data& FFTD); /* Computes the FFT via Horner's Rule theta is assumed to be an Nth root of unity diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index ec9e6e956..564184dd6 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -13,12 +13,16 @@ FHE_SK::FHE_SK(const FHE_PK& pk) : FHE_SK(pk.get_params(), pk.p()) } -void add(FHE_SK& a,const FHE_SK& b,const FHE_SK& c) +FHE_SK& FHE_SK::operator+=(const FHE_SK& c) { - if (a.params!=b.params) { throw params_mismatch(); } + auto& a = *this; + auto& b = *this; + if (a.params!=c.params) { throw params_mismatch(); } - add(a.sk,b.sk,c.sk); + ::add(a.sk,b.sk,c.sk); + + return *this; } @@ -84,7 +88,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) } } -void FHE_PK::check_noise(const FHE_SK& SK) +void FHE_PK::check_noise(const FHE_SK& SK) const { Rq_Element sk = SK.s(); if (params->n_mults() > 0) @@ -92,7 +96,7 @@ void FHE_PK::check_noise(const FHE_SK& SK) check_noise(b0 - a0 * sk); } -void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) +void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) const { assert(pr != 0); vector noise = x.to_vec_bigint(); @@ -162,6 +166,7 @@ void FHE_PK::quasi_encrypt(Ciphertext& c, { if (&c.get_params()!=params) { throw params_mismatch(); } if (&rc.get_params()!=params) { throw params_mismatch(); } + assert(pr != 0); Rq_Element ed,edd,c0,c1,aa; @@ -332,8 +337,11 @@ void FHE_PK::pack(octetStream& o) const o.append((octet*) "PKPKPKPK", 8); a0.pack(o); b0.pack(o); - Sw_a.pack(o); - Sw_b.pack(o); + if (params->n_mults() > 0) + { + Sw_a.pack(o); + Sw_b.pack(o); + } pr.pack(o); } @@ -345,8 +353,11 @@ void FHE_PK::unpack(octetStream& o) throw runtime_error("invalid serialization of public key"); a0.unpack(o); b0.unpack(o); - Sw_a.unpack(o); - Sw_b.unpack(o); + if (params->n_mults() > 0) + { + Sw_a.unpack(o); + Sw_b.unpack(o); + } pr.unpack(o); } @@ -376,14 +387,30 @@ void FHE_SK::check(const FHE_Params& params, const FHE_PK& pk, } +template +void FHE_SK::check(const FHE_PK& pk, const FD& FieldD) +{ + check(*params, pk, pr); + pk.check_noise(*this); + if (decrypt(pk.encrypt(Plaintext_(FieldD)), FieldD) != + Plaintext_(FieldD)) + throw runtime_error("incorrect key pair"); +} + + + void FHE_PK::check(const FHE_Params& params, const bigint& pr) const { if (this->pr != pr) throw pr_mismatch(); a0.check(params); b0.check(params); - Sw_a.check(params); - Sw_b.check(params); + + if (params.n_mults() > 0) + { + Sw_a.check(params); + Sw_b.check(params); + } } @@ -402,3 +429,6 @@ template void FHE_SK::decrypt_any(Plaintext_& res, const Ciphertext& c); template void FHE_SK::decrypt_any(Plaintext_& res, const Ciphertext& c); + +template void FHE_SK::check(const FHE_PK& pk, const FFT_Data&); +template void FHE_SK::check(const FHE_PK& pk, const P2Data&); diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h index 70d9689f7..72a7ddfa8 100644 --- a/FHE/FHE_Keys.h +++ b/FHE/FHE_Keys.h @@ -20,6 +20,8 @@ class FHE_SK public: + static int size() { return 0; } + const FHE_Params& get_params() const { return *params; } bigint p() const { return pr; } @@ -63,18 +65,24 @@ class FHE_SK friend void KeyGen(FHE_PK& PK,FHE_SK& SK,PRNG& G); - /* Add secret key onto the existing one + /* Add secret keys * Used for adding distributed keys together * a,b,c must have same params otherwise an error */ - friend void add(FHE_SK& a,const FHE_SK& b,const FHE_SK& c); - FHE_SK operator+(const FHE_SK& x) { FHE_SK res(*params, pr); add(res, *this, x); return res; } - FHE_SK& operator+=(const FHE_SK& x) { add(*this, *this, x); return *this; } + FHE_SK operator+(const FHE_SK& x) const { FHE_SK res = *this; res += x; return res; } + FHE_SK& operator+=(const FHE_SK& x); + + bool operator!=(const FHE_SK& x) const { return pr != x.pr or sk != x.sk; } - bool operator!=(const FHE_SK& x) { return pr != x.pr or sk != x.sk; } + void add(octetStream& os) { FHE_SK tmp(*this); tmp.unpack(os); *this += tmp; } void check(const FHE_Params& params, const FHE_PK& pk, const bigint& pr) const; + + template + void check(const FHE_PK& pk, const FD& FieldD); + + friend ostream& operator<<(ostream& o, const FHE_SK&) { throw not_implemented(); return o; } }; @@ -92,7 +100,7 @@ class FHE_PK bigint p() const { return pr; } void assign(const Rq_Element& a,const Rq_Element& b, - const Rq_Element& sa,const Rq_Element& sb + const Rq_Element& sa = {},const Rq_Element& sb = {} ) { a0=a; b0=b; Sw_a=sa; Sw_b=sb; } @@ -143,8 +151,8 @@ class FHE_PK Rq_Element sample_secret_key(PRNG& G); void KeyGen(Rq_Element& sk, PRNG& G, int noise_boost = 1); - void check_noise(const FHE_SK& sk); - void check_noise(const Rq_Element& x, bool check_modulo = false); + void check_noise(const FHE_SK& sk) const; + void check_noise(const Rq_Element& x, bool check_modulo = false) const; // params setting is done out of these IO/pack/unpack functions void pack(octetStream& o) const; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 04b9299f7..7f5563390 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -3,9 +3,16 @@ #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" - void FHE_Params::set(const Ring& R, const vector& primes,double r,int hwt) +{ + set(R, primes); + + Chi.set(R.phi_m(),hwt,r); +} + +void FHE_Params::set(const Ring& R, + const vector& primes) { if (primes.size() != FFTData.size()) throw runtime_error("wrong number of primes"); @@ -13,8 +20,7 @@ void FHE_Params::set(const Ring& R, for (size_t i = 0; i < FFTData.size(); i++) FFTData[i].init(R,primes[i]); - Chi.set(R.phi_m(),hwt,r); - + Chi.set_n(R.phi_m()); set_sec(40); } diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index 0591cb44f..d918c9567 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -29,14 +29,16 @@ class FHE_Params public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(64, 0.7), sec_p(-1) {} + FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(-1, 0.7), sec_p(-1) {} int n_mults() const { return FFTData.size() - 1; } // Rely on default copy assignment/constructor (not that they should // ever be needed) - void set(const Ring& R,const vector& primes,double r=-1,int hwt=-1); + void set(const Ring& R,const vector& primes,double r,int hwt); + void set(const Ring& R,const vector& primes); + void set(const vector& primes); void set_sec(int sec); vector sampleGaussian(PRNG& G, int noise_boost = 1) const @@ -57,7 +59,9 @@ class FHE_Params int secp() const { return sec_p; } const bigint& B() const { return Bval; } double get_R() const { return Chi.get_R(); } + void set_R(double R) const { return Chi.get_DG().set(R); } DiscreteGauss get_DG() const { return Chi.get_DG(); } + int get_h() const { return Chi.get_h(); } int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 53980239c..67bb551bd 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -55,13 +55,14 @@ int generate_semi_setup(int plaintext_length, int sec, while (true) { SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, - numBits(NonInteractiveProof::slack(sec, phi_N(m))), true); + numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); bigint p1 = 2 * p * m, p0 = p; while (nb.min_p0(params.n_mults() > 0, p1) > p0) { p0 *= 2; } - if (phi_N(m) < nb.min_phi_m(2 + numBits(p0 * (params.n_mults() > 0 ? p1 : 1)))) + if (phi_N(m) < nb.min_phi_m(2 + numBits(p0 * (params.n_mults() > 0 ? p1 : 1)), + params.get_R())) { m *= 2; generate_prime(p, lgp, m); @@ -91,7 +92,7 @@ int generate_semi_setup(int plaintext_length, int sec, int m; char_2_dimension(m, plaintext_length); SemiHomomorphicNoiseBounds nb(2, phi_N(m), 1, sec, - numBits(NonInteractiveProof::slack(sec, phi_N(m))), true); + numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); int lgp0 = numBits(nb.min_p0(false, 0)); int extra_slack = common_semi_setup(params, m, 2, lgp0, -1, round_up); load_or_generate(P2D, params.get_ring()); @@ -111,7 +112,7 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b int i; for (i = 0; i <= 20; i++) { - if (SemiHomomorphicNoiseBounds::min_phi_m(lgp0 + i) > phi_N(m)) + if (SemiHomomorphicNoiseBounds::min_phi_m(lgp0 + i, params) > phi_N(m)) break; if (not same_word_length(lgp0, lgp0 + i)) break; @@ -138,7 +139,8 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, b return extra_slack; } -int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool round_up) +int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, + bool round_up, FHE_Params& params) { if (n >= 2 and n <= 10) cout << "Difference to suggestion for p0: " << lg2p0 - lg2pi[n - 2] @@ -152,7 +154,7 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool roun int i = 0; for (i = 0; i < 10; i++) { - if (phi_N(m) < NoiseBounds::min_phi_m(lg2p0 + lg2p1 + 2 * i)) + if (phi_N(m) < NoiseBounds::min_phi_m(lg2p0 + lg2p1 + 2 * i, params)) break; if (not same_word_length(lg2p0 + i, lg2p0)) break; @@ -183,7 +185,8 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool roun /* * Subroutine for creating the FHE parameters */ -int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p) +int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, + FHE_Params& params) { int n = n_parties; int lg2pi[5][2][9] @@ -211,7 +214,7 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p) while (sec != -1) { double phi_m_bound = - NoiseBounds(p, phi_N(m), n, sec, slack).optimize(lg2p0, lg2p1); + NoiseBounds(p, phi_N(m), n, sec, slack, params).optimize(lg2p0, lg2p1); cout << "Trying primes of length " << lg2p0 << " and " << lg2p1 << endl; if (phi_N(m) < phi_m_bound) { @@ -226,7 +229,7 @@ int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p) init(R,m); int extra_slack = finalize_lengths(lg2p0, lg2p1, n, m, lg2pi[idx][0], - round_up); + round_up, params); generate_moduli(pr0, pr1, m, p, lg2p0, lg2p1); return extra_slack; } @@ -241,8 +244,6 @@ void generate_moduli(bigint& pr0, bigint& pr1, const int m, const bigint p, void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, const string& i, const bigint& pr0) { - int ex; - if (lg2pr==0) { throw invalid_params(); } bigint step=m; @@ -250,15 +251,20 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr, bigint gc=gcd(step,twop); step=step*twop/gc; - ex=lg2pr-numBits(p)-numBits(step)+1; - if (ex<0) { cout << "Something wrong in lg2p" << i << " = " << lg2pr << endl; abort(); } - pr=1; pr=(pr< void generate_setup(FHE_Params& params, FD& FTD) { - SPDZ_Data_Setup(FTD); + SPDZ_Data_Setup(params, FTD); params.set(R, {pr0, pr1}); } template - void SPDZ_Data_Setup(FD& FTD); + void SPDZ_Data_Setup(FHE_Params& params, FD& FTD); - int SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p); + int SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p, + FHE_Params& params); }; @@ -65,7 +66,7 @@ void init(P2Data& P2D,const Ring& Rg); // For use when we want p to be a specific value void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0, - bigint& pr1, int n, int sec, bigint& p); + bigint& pr1, int n, int sec, bigint& p, FHE_Params& params); // generate moduli according to lengths and other parameters void generate_moduli(bigint& pr0, bigint& pr1, const int m, diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index 1ec1e8714..f343d7f71 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -8,52 +8,51 @@ #include "Protocols/CowGearOptions.h" #include - SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, - int phi_m, int n, int sec, int slack_param, bool extra_h, double sigma, int h) : + int phi_m, int n, int sec, int slack_param, bool extra_h, + const FHE_Params& params) : p(p), phi_m(phi_m), n(n), sec(sec), - slack(numBits(Proof::slack(slack_param, sec, phi_m))), sigma(sigma), h(h) + slack(numBits(Proof::slack(slack_param, sec, phi_m))), + sigma(params.get_R()), h(params.get_h()) { if (sigma <= 0) this->sigma = sigma = FHE_Params().get_R(); #ifdef VERBOSE cerr << "Standard deviation: " << this->sigma << endl; #endif - h += extra_h * sec; - produce_epsilon_constants(); - - if (CowGearOptions::singleton.top_gear()) + if (h > 0) + h += extra_h * sec; + else if (extra_h) { - // according to documentation of SCALE-MAMBA 1.7 - // excluding a factor of n because we don't always add up n ciphertexts - B_clean = (bigint(phi_m) << (sec + 2)) * p - * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * sqrt(h)); - mpf_class V_s; - if (h > 0) - V_s = sqrt(h); - else - V_s = sigma * sqrt(phi_m); - B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); -#ifdef NOISY - cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; - cout << "V_s: " << V_s << endl; - cout << "c1: " << c1 << endl; - cout << "c2: " << c2 << endl; - cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl; - cout << "B_scale: " << B_scale << endl; -#endif + sigma *= 1.4; + params.set_R(params.get_R() * 1.4); } + + produce_epsilon_constants(); + + // according to documentation of SCALE-MAMBA 1.7 + // excluding a factor of n because we don't always add up n ciphertexts + if (h > 0) + V_s = sqrt(h); else - { - B_clean = (phi_m * p / 2 - + p * sigma - * (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m) - + 16 * sqrt(n * h * phi_m))) << slack; - B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3); -#ifdef VERBOSE - cout << "log(slack): " << slack << endl; + V_s = sigma * sqrt(phi_m); + B_clean = (bigint(phi_m) << (sec + 1)) * p + * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); + // unify parameters by taking maximum over TopGear or not + bigint B_clean_top_gear = B_clean * 2; + bigint B_clean_not_top_gear = B_clean << int(ceil(sec / 2.)); + B_clean = max(B_clean_not_top_gear, B_clean_top_gear); + B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0); +#ifdef NOISY + cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl; + cout << "V_s: " << V_s << endl; + cout << "c1: " << c1 << endl; + cout << "c2: " << c2 << endl; + cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl; + cout << "log(slack): " << slack << endl; + cout << "B_clean: " << B_clean << endl; + cout << "B_scale: " << B_scale << endl; #endif - } drown = 1 + n * (bigint(1) << sec); } @@ -77,6 +76,11 @@ double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma) return 37.8 * (log_q - log2(sigma)); } +double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, const FHE_Params& params) +{ + return min_phi_m(log_q, params.get_R()); +} + void SemiHomomorphicNoiseBounds::produce_epsilon_constants() { double C[3]; @@ -104,21 +108,10 @@ void SemiHomomorphicNoiseBounds::produce_epsilon_constants() } NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, - double sigma, int h) : - SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, sigma, h) + const FHE_Params& params) : + SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, params) { - if (CowGearOptions::singleton.top_gear()) - { - B_KS = p * c2 * this->sigma * phi_m / sqrt(12); - } - else - { - B_KS = p * phi_m * mpf_class(this->sigma) - * (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h) - + 2.77 * n * n * sqrt(h) - + pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h)) - + 4.62 * n); - } + B_KS = p * c2 * this->sigma * phi_m / sqrt(12); #ifdef NOISY cout << "p size: " << numBits(p) << endl; cout << "phi(m): " << phi_m << endl; @@ -197,5 +190,5 @@ double NoiseBounds::optimize(int& lg2p0, int& lg2p1) } lg2p1 = numBits(min_p1); lg2p0 = numBits(min_p0); - return min_phi_m(lg2p0 + lg2p1); + return min_phi_m(lg2p0 + lg2p1, sigma.get_d()); } diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index 700790735..466190320 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -9,6 +9,7 @@ #include "Math/bigint.h" int phi_N(int N); +class FHE_Params; class SemiHomomorphicNoiseBounds { @@ -21,25 +22,27 @@ class SemiHomomorphicNoiseBounds const int sec; int slack; mpf_class sigma; - const int h; + int h; bigint B_clean; bigint B_scale; bigint drown; mpf_class c1, c2; + mpf_class V_s; void produce_epsilon_constants(); public: SemiHomomorphicNoiseBounds(const bigint& p, int phi_m, int n, int sec, - int slack, bool extra_h = false, double sigma = -1, int h = 64); + int slack, bool extra_h, const FHE_Params& params); // with scaling bigint min_p0(const bigint& p1); // without scaling bigint min_p0(); bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); } - static double min_phi_m(int log_q, double sigma = -1); + static double min_phi_m(int log_q, double sigma); + static double min_phi_m(int log_q, const FHE_Params& params); }; // as per ePrint 2012:642 for slack = 0 @@ -49,7 +52,7 @@ class NoiseBounds : public SemiHomomorphicNoiseBounds public: NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, - double sigma = -1, int h = 64); + const FHE_Params& params); bigint U1(const bigint& p0, const bigint& p1); bigint U2(const bigint& p0, const bigint& p1); bigint min_p0(const bigint& p0, const bigint& p1); diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index 81f4e0e2b..b9353df80 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -286,15 +286,21 @@ void Plaintext_::randomize(PRNG& G, bigint B, bool Diag, bool binary, template void Plaintext::randomize(PRNG& G, int n_bits, bool Diag, bool binary, PT_Type t) { - if (Diag or binary) + if (binary) throw not_implemented(); allocate(t); switch(t) { case Polynomial: - for (int i = 0; i < n_slots; i++) - b[i].generateUniform(G, n_bits, false); + if (Diag) + { + assign_zero(t); + b[0].generateUniform(G, n_bits, false); + } + else + for (int i = 0; i < n_slots; i++) + b[i].generateUniform(G, n_bits, false); break; default: throw not_implemented(); @@ -638,6 +644,41 @@ bool Plaintext::equals(const Plaintext& x) const return true; } +template<> +bool Plaintext::is_diagonal() const +{ + if (type != Evaluation) + { + for (size_t i = 1; i < b.size(); i++) + if (b[i] != 0) + return false; + } + + if (type != Polynomial) + { + auto first = a[0]; + for (auto& x : a) + if (x != first) + return false; + } + + return true; +} + +template<> +bool Plaintext::is_diagonal() const +{ + if (type == Polynomial) + from_poly(); + + auto first = a[0]; + for (auto& x : a) + if (x != first) + return false; + + return true; +} + template diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index cf0a75818..5781e1951 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -176,7 +176,7 @@ class Plaintext bool equals(const Plaintext& x) const; bool operator!=(const Plaintext& x) { return !equals(x); } - bool is_diagonal() const { throw not_implemented(); } + bool is_diagonal() const; bool is_binary() const { throw not_implemented(); } /* Pack and unpack into an octetStream diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 6bf032558..2acb57ef5 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -222,10 +222,7 @@ void Ring_Element::change_rep(RepType r) } else { // Non m power of two variant and FFT enabled - vector fft((*FFTD).m()); - BFFT(fft,element,*FFTD); - for (int i=0; i<(*FFTD).phi_m(); i++) - { element[i]=fft[(*FFTD).p(i)]; } + FFT_non_power_of_two(element, element, *FFTD); } } else diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index da3e0769e..ba147062e 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -70,6 +70,16 @@ class Ring_Element Ring_Element(const FFT_Data& prd,RepType r=polynomial); + template + Ring_Element(const FFT_Data& prd, RepType r, const vector& other) + { + assert(size_t(prd.num_slots()) == other.size()); + FFTD = &prd; + rep = r; + for (auto& x : other) + element.push_back(x); + } + // Copy Constructor Ring_Element(const Ring_Element& e) { assign(e); } diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index 087d0402c..db3d4649d 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -66,6 +66,9 @@ class Rq_Element Rq_Element(const Ring_Element& b0,const Ring_Element& b1) : a({b0, b1}), lev(n_mults()) {} + Rq_Element(const Ring_Element& b0) : + a({b0}), lev(n_mults()) {} + template Rq_Element(const FHE_Params& params, const Plaintext& plaintext) : Rq_Element(params) @@ -73,6 +76,15 @@ class Rq_Element from(plaintext.get_iterator()); } + template + Rq_Element(const vector& prd, const vector& b0, + const vector& b1, RepType r = evaluation) : + Rq_Element(prd, r, r) + { + a[0] = Ring_Element(prd[0], r, b0); + a[1] = Ring_Element(prd[1], r, b1); + } + // Destructor ~Rq_Element() { ; } @@ -107,6 +119,7 @@ class Rq_Element void Scale(const bigint& p); bool equals(const Rq_Element& a) const; + bool operator==(const Rq_Element& a) const { return equals(a); } bool operator!=(const Rq_Element& a) const { return !equals(a); } int level() const { return lev; } diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp index db54d8095..0f5d1fe86 100644 --- a/FHEOffline/DataSetup.cpp +++ b/FHEOffline/DataSetup.cpp @@ -241,67 +241,12 @@ void PartSetup::covert_mac_generation(Player& P, MachineBase&, int num_runs) } template -void PartSetup::covert_secrets_generation(Player& P, MachineBase& machine, - int num_runs) +void PartSetup::key_and_mac_generation(Player& P, + MachineBase& machine, int num_runs, true_type) { - read_or_generate_covert_secrets(*this, P, machine, num_runs); -} - -template -void read_or_generate_covert_secrets(T& setup, Player& P, U& machine, - int num_runs) -{ - octetStream os; - setup.params.pack(os); - setup.FieldD.pack(os); - string filename = PREP_DIR + setup.covert_name() + "-Secrets-" - + to_string(num_runs) + "-" - + os.check_sum(20).get_str(16) + "-P" + to_string(P.my_num()) + "-" - + to_string(P.num_players()); - - string error; - - try - { - ifstream input(filename); - os.input(input); - setup.unpack(os); - machine.unpack(os); - } - catch (exception& e) - { - error = e.what(); - } - - try - { - setup.check(P, machine); - machine.check(P); - } - catch (mismatch_among_parties& e) - { - error = e.what(); - } - - if (not error.empty()) - { - cerr << "Running secrets generation because " << error << endl; - setup.covert_key_generation(P, machine, num_runs); - setup.covert_mac_generation(P, machine, num_runs); - ofstream output(filename); - octetStream os; - setup.pack(os); - machine.pack(os); - os.output(output); - } + covert_key_generation(P, machine, num_runs); + covert_mac_generation(P, machine, num_runs); } template class PartSetup; template class PartSetup; - -template void read_or_generate_covert_secrets, - PairwiseMachine>(PairwiseSetup&, Player&, PairwiseMachine&, - int); -template void read_or_generate_covert_secrets, - PairwiseMachine>(PairwiseSetup&, Player&, PairwiseMachine&, - int); diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h index 281964889..88c9e05fc 100644 --- a/FHEOffline/DataSetup.h +++ b/FHEOffline/DataSetup.h @@ -16,10 +16,6 @@ class DataSetup; class MachineBase; class MultiplicativeMachine; -template -void read_or_generate_covert_secrets(T& setup, Player& P, U& machine, - int num_runs); - template class PartSetup { @@ -38,9 +34,12 @@ class PartSetup return "GlobalParams-" + T::type_string(); } - static string covert_name() + static string protocol_name(bool covert) { - return "ChaiGear"; + if (covert) + return "ChaiGear"; + else + return "HighGear"; } PartSetup(); @@ -69,7 +68,11 @@ class PartSetup void covert_key_generation(Player& P, MachineBase&, int num_runs); void covert_mac_generation(Player& P, MachineBase&, int num_runs); - void covert_secrets_generation(Player& P, MachineBase& machine, int num_runs); + + void key_and_mac_generation(Player& P, MachineBase& machine, int num_runs, + false_type); + void key_and_mac_generation(Player& P, MachineBase& machine, int num_runs, + true_type); void output(Names& N); }; diff --git a/FHEOffline/DataSetup.hpp b/FHEOffline/DataSetup.hpp new file mode 100644 index 000000000..c1124a8e0 --- /dev/null +++ b/FHEOffline/DataSetup.hpp @@ -0,0 +1,63 @@ +/* + * DataSetup.hpp + * + */ + +#ifndef FHEOFFLINE_DATASETUP_HPP_ +#define FHEOFFLINE_DATASETUP_HPP_ + +#include "Networking/Player.h" +#include "Tools/Bundle.h" + +template +void read_or_generate_secrets(T& setup, Player& P, U& machine, + int num_runs, V) +{ + octetStream os; + setup.params.pack(os); + setup.FieldD.pack(os); + bool covert = num_runs > 0; + assert(covert == V()); + string filename = PREP_DIR + setup.protocol_name(covert) + "-Secrets-" + + (covert ? to_string(num_runs) : to_string(machine.sec)) + "-" + + os.check_sum(20).get_str(16) + "-P" + to_string(P.my_num()) + "-" + + to_string(P.num_players()); + + string error; + + try + { + ifstream input(filename); + os.input(input); + setup.unpack(os); + machine.unpack(os); + } + catch (exception& e) + { + error = e.what(); + } + + try + { + setup.check(P, machine); + machine.check(P); + } + catch (mismatch_among_parties& e) + { + error = e.what(); + } + + if (not error.empty()) + { + cerr << "Running secrets generation because " << error << endl; + setup.key_and_mac_generation(P, machine, num_runs, V()); + + ofstream output(filename); + octetStream os; + setup.pack(os); + machine.pack(os); + os.output(output); + } +} + +#endif /* FHEOFFLINE_DATASETUP_HPP_ */ diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp index 8ef054554..942dd698b 100644 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -120,7 +120,7 @@ void PairwiseMachine::pack(octetStream& os) const void PairwiseMachine::unpack(octetStream& os) { - os.get(other_pks, {setup_p.params, 0}); + os.get_no_resize(other_pks); os.get(enc_alphas, {pk}); sk.unpack(os); } diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp index fe2ec53a5..bba83b5fd 100644 --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -12,6 +12,7 @@ #include "Tools/Commit.h" #include "Tools/Bundle.h" #include "Processor/OnlineOptions.h" +#include "Protocols/LowGearKeyGen.h" #include "Protocols/Share.hpp" #include "Protocols/mac_key.hpp" @@ -189,6 +190,14 @@ void PairwiseSetup::covert_mac_generation(Player& P, alphai = alpha.element(0); } +template +void PairwiseSetup::key_and_mac_generation(Player& P, + PairwiseMachine& machine, int num_runs, true_type) +{ + covert_key_generation(P, machine, num_runs); + covert_mac_generation(P, machine, num_runs); +} + template void PairwiseSetup::set_alphai(T alphai) { diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h index e180ab75d..8e16eaf34 100644 --- a/FHEOffline/PairwiseSetup.h +++ b/FHEOffline/PairwiseSetup.h @@ -33,9 +33,12 @@ class PairwiseSetup return "PairwiseParams-" + FD::T::type_string(); } - static string covert_name() + static string protocol_name(bool covert) { - return "CowGear"; + if (covert) + return "CowGear"; + else + return "LowGear"; } PairwiseSetup() : params(0), alpha(FieldD) {} @@ -48,6 +51,11 @@ class PairwiseSetup void covert_key_generation(Player& P, PairwiseMachine& machine, int num_runs); void covert_mac_generation(Player& P, PairwiseMachine& machine, int num_runs); + void key_and_mac_generation(Player& P, PairwiseMachine& machine, + int num_runs, false_type); + void key_and_mac_generation(Player& P, PairwiseMachine& machine, + int num_runs, true_type); + void pack(octetStream& os) const; void unpack(octetStream& os); diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index f573cc1c9..d121372fb 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -588,8 +588,8 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, P.receive_player(j, ciphertexts); P.receive_player(j, cleartexts); C.resize(personal_EC.machine->sec, pk.get_params()); - Verifier(personal_EC.proof).NIZKPoK(C, ciphertexts, - cleartexts, pk, false, false); + Verifier(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts, + cleartexts, pk, false); } inputs[j].clear(); diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h index e4b3c41fb..f331d493a 100644 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -24,6 +24,8 @@ class Proof { unsigned int sec; + bool diagonal; + Proof(); // Private to avoid default public: @@ -72,7 +74,8 @@ class Proof typedef AddableMatrix X; Proof(int sc, const bigint& Tau, const bigint& Rho, const FHE_PK& pk, - int n_proofs = 1) : + int n_proofs = 1, bool diagonal = false) : + diagonal(diagonal), B_plain_length(0), B_rand_length(0), pk(&pk), n_proofs(n_proofs) { sec=sc; tau=Tau; rho=Rho; @@ -95,19 +98,25 @@ class Proof } } - Proof(int sec, const FHE_PK& pk, int n_proofs = 1) : + Proof(int sec, const FHE_PK& pk, int n_proofs = 1, bool diagonal = false) : Proof(sec, pk.p() / 2, pk.get_params().get_DG().get_NewHopeB(), pk, - n_proofs) {} + n_proofs, diagonal) {} virtual ~Proof() {} public: static bigint slack(int slack, int sec, int phim); - static bool use_top_gear(const FHE_PK& pk) + bool use_top_gear(const FHE_PK& pk) + { + return CowGearOptions::singleton.top_gear() and pk.p() > 2 and + not diagonal; + } + + bool get_diagonal() const { - return CowGearOptions::singleton.top_gear() and pk.p() > 2; + return diagonal; } static int n_ciphertext_per_proof(int sec, const FHE_PK& pk) @@ -147,8 +156,8 @@ class NonInteractiveProof : public Proof { return bigint(phim * sec * sec) << (sec / 2 + 8); } NonInteractiveProof(int sec, const FHE_PK& pk, - int extra_slack) : - Proof(sec, pk, 1) + int extra_slack, bool diagonal = false) : + Proof(sec, pk, 1, diagonal) { bigint B; B=128*sec*sec; @@ -167,8 +176,8 @@ class InteractiveProof : public Proof { (void)phim; return pow(2, 1.5 * sec + 1); } InteractiveProof(int sec, const FHE_PK& pk, - int n_proofs = 1) : - Proof(sec, pk, n_proofs) + int n_proofs = 1, bool diagonal = false) : + Proof(sec, pk, n_proofs, diagonal) { bigint B; B = bigint(1) << sec; diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index 0993d4fe1..230c44fb7 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -1,5 +1,6 @@ #include "Prover.h" +#include "Verifier.h" #include "FHE/P2Data.h" #include "Tools/random.h" @@ -27,7 +28,7 @@ Prover::Prover(Proof& proof, const FD& FieldD) : template void Prover::Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector& c, - const FHE_PK& pk, bool Diag, bool binary) + const FHE_PK& pk, bool binary) { size_t allocate = 3 * c.size() * c[0].report_size(USED); ciphertexts.resize_precise(allocate); @@ -50,7 +51,9 @@ void Prover::Stage_1(const Proof& P, octetStream& ciphertexts, // AE.randomize(Diag,binary); // rd=RandPoly(phim,bd<<1); // y[i]=AE.plaintext()+pr*rd; - y[i].randomize(G, P.B_plain_length, Diag, binary); + y[i].randomize(G, P.B_plain_length, P.get_diagonal(), binary); + if (P.get_diagonal()) + assert(y[i].is_diagonal()); s[i].resize(3, P.phim); s[i].generateUniform(G, P.B_rand_length); rc.assign(s[i][0], s[i][1], s[i][2]); @@ -78,10 +81,14 @@ bool Prover::Stage_2(Proof& P, octetStream& cleartexts, #endif cleartexts.reset_write_head(); cleartexts.store(P.V); + if (P.get_diagonal()) + for (auto& xx : x) + assert(xx.is_diagonal()); for (i=0; i::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl const AddableVector& c, const vector& x, const Proof::Randomness& r, - bool Diag,bool binary) + bool binary) { // AElement AE; // for (i=0; i::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl int cnt=0; while (!ok) { cnt++; - Stage_1(P,ciphertexts,c,pk,Diag,binary); + Stage_1(P,ciphertexts,c,pk,binary); P.set_challenge(ciphertexts); // Check check whether we are OK, or whether we should abort ok = Stage_2(P,cleartexts,x,r,pk); diff --git a/FHEOffline/Prover.h b/FHEOffline/Prover.h index 5e4b28c0b..91bf05cf1 100644 --- a/FHEOffline/Prover.h +++ b/FHEOffline/Prover.h @@ -24,7 +24,7 @@ class Prover Prover(Proof& proof, const FD& FieldD); void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector& c, - const FHE_PK& pk, bool Diag, + const FHE_PK& pk, bool binary = false); bool Stage_2(Proof& P, octetStream& cleartexts, @@ -41,7 +41,7 @@ class Prover const AddableVector& c, const vector& x, const Proof::Randomness& r, - bool Diag,bool binary=false); + bool binary=false); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index 742482500..6dd002df3 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -22,8 +22,9 @@ SimpleEncCommitBase::SimpleEncCommitBase(const MachineBase& machine) : template SimpleEncCommit::SimpleEncCommit(const PlayerBase& P, const FHE_PK& pk, const FD& FTD, map& timers, const MachineBase& machine, - int thread_num) : - NonInteractiveProofSimpleEncCommit(P, pk, FTD, timers, machine), + int thread_num, bool diagonal) : + NonInteractiveProofSimpleEncCommit(P, pk, FTD, timers, machine, + diagonal), SimpleEncCommitFactory(pk, FTD, machine) { (void)thread_num; @@ -32,13 +33,14 @@ SimpleEncCommit::SimpleEncCommit(const PlayerBase& P, const FHE_PK& pk template NonInteractiveProofSimpleEncCommit::NonInteractiveProofSimpleEncCommit( const PlayerBase& P, const FHE_PK& pk, const FD& FTD, - map& timers, const MachineBase& machine) : + map& timers, const MachineBase& machine, + bool diagonal) : SimpleEncCommitBase_(machine), P(P), pk(pk), FTD(FTD), - proof(machine.sec, pk, machine.extra_slack), + proof(machine.sec, pk, machine.extra_slack, diagonal), #ifdef LESS_ALLOC_MORE_MEM r(proof.U, this->pk.get_params()), prover(proof, FTD), - verifier(proof), + verifier(proof, FTD), #endif timers(timers) { @@ -72,7 +74,7 @@ void SimpleEncCommitFactory::next(Plaintext_& mess, Ciphertext& C) mess = m[cnt]; C = c[cnt]; - if (Proof::use_top_gear(pk)) + if (get_proof().use_top_gear(pk)) { mess = mess + mess; C = C + C; @@ -86,7 +88,10 @@ template void SimpleEncCommitFactory::prepare_plaintext(PRNG& G) { for (auto& mess : m) - mess.randomize(G); + if (get_proof().get_diagonal()) + mess.randomize(G, Diagonal); + else + mess.randomize(G); } template @@ -126,7 +131,7 @@ size_t NonInteractiveProofSimpleEncCommit::generate_proof(AddableVector > prover(proof, FTD); #endif size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts, - pk, c, m, r, false, false); + pk, c, m, r, false); timers["Proving"].stop(); if (proof.top_gear) @@ -187,7 +192,7 @@ size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& cipherte #endif timers["Verifying"].start(); verifier.NIZKPoK(others_ciphertexts, ciphertexts, - cleartexts, get_pk_for_verification(i), false, false); + cleartexts, get_pk_for_verification(i), false); timers["Verifying"].stop(); add_ciphertexts(others_ciphertexts, i); this->memory_usage.update("verifier", verifier.report_size(CAPACITY)); @@ -214,12 +219,12 @@ void SimpleEncCommit::add_ciphertexts( template SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD, map& timers, const MachineBase& machine, - int thread_num) : + int thread_num, bool diagonal) : SimpleEncCommitFactory(pk, FTD, machine), SimpleEncCommitBase_( - machine), proof(machine.sec, pk, P.num_players()), pk(pk), FTD( + machine), proof(machine.sec, pk, P.num_players(), diagonal), pk(pk), FTD( FTD), P(P), thread_num(thread_num), #ifdef LESS_ALLOC_MORE_MEM - prover(proof, FTD), verifier(proof), preimages(proof.V, + prover(proof, FTD), verifier(proof, FTD), preimages(proof.V, this->pk, FTD.get_prime(), P.num_players()), #endif timers(timers) @@ -246,7 +251,7 @@ void SummingEncCommit::create_more() #endif this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof); this->timers["Stage 1 of proof"].start(); - prover.Stage_1(proof, ciphertexts, this->c, this->pk, false, false); + prover.Stage_1(proof, ciphertexts, this->c, this->pk, false); this->timers["Stage 1 of proof"].stop(); this->c.unpack(ciphertexts, this->pk); @@ -307,7 +312,7 @@ void SummingEncCommit::create_more() Verifier verifier(proof); #endif verifier.Stage_2(this->c, ciphertexts, cleartexts, - this->pk, false, false); + this->pk, false); this->timers["Verifying"].stop(); this->cnt = proof.U - 1; @@ -362,9 +367,9 @@ size_t SummingEncCommit::report_size(ReportType type) template MultiEncCommit::MultiEncCommit(const Player& P, const vector& pks, const FD& FTD, map& timers, MachineBase& machine, - PairwiseGenerator& generator) : + PairwiseGenerator& generator, bool diagonal) : NonInteractiveProofSimpleEncCommit(P, pks[P.my_real_num()], FTD, - timers, machine), pks(pks), P(P), generator(generator) + timers, machine, diagonal), pks(pks), P(P), generator(generator) { } diff --git a/FHEOffline/SimpleEncCommit.h b/FHEOffline/SimpleEncCommit.h index e3af52ee8..a899fc2ca 100644 --- a/FHEOffline/SimpleEncCommit.h +++ b/FHEOffline/SimpleEncCommit.h @@ -65,13 +65,15 @@ class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_ NonInteractiveProofSimpleEncCommit(const PlayerBase& P, const FHE_PK& pk, const FD& FTD, map& timers, - const MachineBase& machine); + const MachineBase& machine, bool diagonal = false); virtual ~NonInteractiveProofSimpleEncCommit() {} - size_t generate_proof(AddableVector& c, vector >& m, - octetStream& ciphertexts, octetStream& cleartexts); + size_t generate_proof(AddableVector& c, + vector >& m, octetStream& ciphertexts, + octetStream& cleartexts); size_t create_more(octetStream& my_ciphertext, octetStream& my_cleartext); virtual size_t report_size(ReportType type); using SimpleEncCommitBase_::report_size; + Proof& get_proof() { return proof; } }; template @@ -96,6 +98,7 @@ class SimpleEncCommitFactory void next(Plaintext_& mess, Ciphertext& C); virtual size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); + virtual Proof& get_proof() = 0; }; template @@ -111,13 +114,15 @@ class SimpleEncCommit: public NonInteractiveProofSimpleEncCommit, public: SimpleEncCommit(const PlayerBase& P, const FHE_PK& pk, const FD& FTD, - map& timers, const MachineBase& machine, int thread_num); + map& timers, const MachineBase& machine, + int thread_num, bool diagonal = false); void next(Plaintext_& mess, Ciphertext& C) { SimpleEncCommitFactory::next(mess, C); } void create_more(); size_t report_size(ReportType type) { return SimpleEncCommitFactory::report_size(type) + EncCommitBase_::report_size(type); } void report_size(ReportType type, MemoryUsage& res) { SimpleEncCommitFactory::report_size(type, res); SimpleEncCommitBase_::report_size(type, res); } + Proof& get_proof() { return NonInteractiveProofSimpleEncCommit::get_proof(); } }; template @@ -146,13 +151,15 @@ class SummingEncCommit: public SimpleEncCommitFactory, map& timers; SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD, - map& timers, const MachineBase& machine, int thread_num); + map& timers, const MachineBase& machine, + int thread_num, bool diagonal = false); void next(Plaintext_& mess, Ciphertext& C) { SimpleEncCommitFactory::next(mess, C); } void create_more(); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res) { SimpleEncCommitFactory::report_size(type, res); SimpleEncCommitBase_::report_size(type, res); } + Proof& get_proof() { return proof; } }; template @@ -179,7 +186,7 @@ class MultiEncCommit : public NonInteractiveProofSimpleEncCommit MultiEncCommit(const Player& P, const vector& pks, const FD& FTD, map& timers, MachineBase& machine, - PairwiseGenerator& generator); + PairwiseGenerator& generator, bool diagonal = false); }; #endif /* FHEOFFLINE_SIMPLEENCCOMMIT_H_ */ diff --git a/FHEOffline/Verifier.cpp b/FHEOffline/Verifier.cpp index 1b653733d..9c26e94c0 100644 --- a/FHEOffline/Verifier.cpp +++ b/FHEOffline/Verifier.cpp @@ -4,7 +4,8 @@ #include "Math/modp.hpp" template -Verifier::Verifier(Proof& proof) : P(proof) +Verifier::Verifier(Proof& proof, const FD& FieldD) : + P(proof), FieldD(FieldD) { #ifdef LESS_ALLOC_MORE_MEM z.resize(proof.phim); @@ -30,12 +31,28 @@ bool Check_Decoding(const Plaintext& AE,bool Diag) return true; } -template -bool Check_Decoding(const vector& AE, bool Diag) +template <> +bool Check_Decoding(const vector& AE, bool Diag, const FFT_Data&) { - (void)AE; if (Diag) - throw not_implemented(); + { + for (size_t i = 1; i < AE.size(); i++) + if (AE[i] != 0) + return false; + } + return true; +} + +template <> +bool Check_Decoding(const vector& AE, bool Diag, const P2Data& p2d) +{ + if (Diag) + { + Plaintext_ tmp(p2d); + for (size_t i = 0; i < AE.size(); i++) + tmp.set_coeff(i, AE[i].get_limb(0) % 2); + return tmp.is_diagonal(); + } return true; } @@ -45,7 +62,7 @@ template void Verifier::Stage_2( AddableVector& c,octetStream& ciphertexts, octetStream& cleartexts, - const FHE_PK& pk,bool Diag,bool binary) + const FHE_PK& pk,bool binary) { unsigned int i, V; @@ -76,12 +93,7 @@ void Verifier::Stage_2( { cout << "Fail Check 6 " << i << endl; throw runtime_error("ciphertexts don't match"); } - } - - // Now check decoding z[i] - for (i=0; i::Stage_2( template void Verifier::NIZKPoK(AddableVector& c, octetStream& ciphertexts, octetStream& cleartexts, - const FHE_PK& pk,bool Diag, + const FHE_PK& pk, bool binary) { P.set_challenge(ciphertexts); - Stage_2(c,ciphertexts,cleartexts,pk,Diag,binary); + Stage_2(c,ciphertexts,cleartexts,pk,binary); if (P.top_gear) { - assert(not Diag); + assert(not P.get_diagonal()); assert(not binary); c += c; } diff --git a/FHEOffline/Verifier.h b/FHEOffline/Verifier.h index 681a2c667..dd9614488 100644 --- a/FHEOffline/Verifier.h +++ b/FHEOffline/Verifier.h @@ -3,6 +3,9 @@ #include "Proof.h" +template +bool Check_Decoding(const vector& AE, bool Diag, FD& FieldD); + /* Defines the Verifier */ template class Verifier @@ -11,20 +14,21 @@ class Verifier AddableMatrix t; Proof& P; + const FD& FieldD; public: - Verifier(Proof& proof); + Verifier(Proof& proof, const FD& FieldD); void Stage_2( AddableVector& c, octetStream& ciphertexts, - octetStream& cleartexts,const FHE_PK& pk,bool Diag,bool binary=false); + octetStream& cleartexts,const FHE_PK& pk,bool binary=false); /* This is the non-interactive version using the ROM - Creates space for all output values - Diag flag mirrors that in Prover */ void NIZKPoK(AddableVector& c,octetStream& ciphertexts,octetStream& cleartexts, - const FHE_PK& pk,bool Diag,bool binary=false); + const FHE_PK& pk,bool binary=false); size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); } }; diff --git a/GC/instructions.h b/GC/instructions.h index 9b694ba2d..467c1d7c1 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -136,6 +136,7 @@ X(RUN_TAPE, MACH->run_tapes(EXTRA)) \ X(JOIN_TAPE, MACH->join_tape(R0)) \ X(USE, ) \ + X(USE_INP, ) \ #define INSTRUCTIONS BIT_INSTRUCTIONS GC_INSTRUCTIONS diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index a88da3fb0..07764968d 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -9,6 +9,8 @@ #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" #include "Processor/Machine.hpp" +#include "Processor/FieldMachine.hpp" + #include "Protocols/MAC_Check.hpp" #include "Protocols/fake-stuff.hpp" #include "Protocols/Beaver.hpp" @@ -26,4 +28,6 @@ #include "GC/ShareSecret.hpp" #include "GC/TinierSharePrep.hpp" +#include "Math/gfp.hpp" + #endif /* MACHINES_SPDZ_HPP_ */ diff --git a/Machines/highgear-party.cpp b/Machines/highgear-party.cpp new file mode 100644 index 000000000..b1a9a8dd0 --- /dev/null +++ b/Machines/highgear-party.cpp @@ -0,0 +1,16 @@ +/* + * highgear-party.cpp + * + */ + +#include "Protocols/HighGearShare.h" + +#include "SPDZ.hpp" +#include "Protocols/ChaiGearPrep.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + CowGearOptions::singleton = CowGearOptions(opt, argc, argv, false); + DishonestMajorityFieldMachine(argc, argv, opt); +} diff --git a/Machines/lowgear-party.cpp b/Machines/lowgear-party.cpp new file mode 100644 index 000000000..fa9f7d278 --- /dev/null +++ b/Machines/lowgear-party.cpp @@ -0,0 +1,16 @@ +/* + * lowgear-party.cpp + * + */ + +#include "Protocols/LowGearShare.h" + +#include "SPDZ.hpp" +#include "Protocols/CowGearPrep.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + CowGearOptions::singleton = CowGearOptions(opt, argc, argv, false); + DishonestMajorityFieldMachine(argc, argv, opt); +} diff --git a/Makefile b/Makefile index d556856c3..d691ace02 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ binary: rep-bin yao semi-bin-party.x tinier-party.x tiny-party.x ccd-party.x mal ifeq ($(USE_NTL),1) all: overdrive she-offline -gear: cowgear-party.x chaigear-party.x +gear: cowgear-party.x chaigear-party.x lowgear-party.x highgear-party.x arithmetic: hemi-party.x soho-party.x gear endif @@ -186,10 +186,14 @@ hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) +lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o +highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o static/hemi-party.x: $(FHEOFFLINE) static/soho-party.x: $(FHEOFFLINE) static/cowgear-party.x: $(FHEOFFLINE) static/chaigear-party.x: $(FHEOFFLINE) +static/lowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o +static/highgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o mascot-party.x: Machines/SPDZ.o $(OT) static/mascot-party.x: Machines/SPDZ.o Player-Online.x: Machines/SPDZ.o $(OT) diff --git a/Math/FixedVec.h b/Math/FixedVec.h index 3e188513d..df51fa214 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -34,6 +34,11 @@ class FixedVec { return L * T::size(); } + static int size_in_bits() + { + return L * T::size_in_bits(); + } + static string type_string() { return T::type_string() + "^" + to_string(L); diff --git a/Math/Setup.cpp b/Math/Setup.cpp index ea143c0a5..b4800017c 100644 --- a/Math/Setup.cpp +++ b/Math/Setup.cpp @@ -146,8 +146,10 @@ void check_setup(string dir, bigint pr) bigint p; string filename = dir + "Params-Data"; ifstream(filename) >> p; + if (p == 0) + throw runtime_error("no modulus in " + filename); if (p != pr) - throw runtime_error("wrong modulus in " + dir); + throw runtime_error("wrong modulus in " + filename); } string get_prep_sub_dir(const string& prep_dir, int nparties, int log2mod, diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index 4bd36cd2a..d56a94f89 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -24,7 +24,7 @@ class ValueInterface template static void init(bool mont = true) { (void) mont; } static void init_default(int, bool = true) {} - static void init_field() {} + static void init_field(const bigint& = {}) {} static void read_or_generate_setup(const string&, const OnlineOptions&) {} template diff --git a/Math/bigint.h b/Math/bigint.h index efb2c777d..0d1d2ee69 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -67,6 +67,7 @@ class bigint : public mpz_class bigint& operator=(int n); bigint& operator=(long n); bigint& operator=(word n); + bigint& operator=(double f); template bigint& operator=(const gfp_& other); template @@ -138,6 +139,12 @@ inline bigint& bigint::operator=(word n) return *this; } +inline bigint& bigint::operator=(double f) +{ + mpz_class::operator=(f); + return *this; +} + template bigint::bigint(const Z2& x) { diff --git a/Math/modp.h b/Math/modp.h index 0f0bac88b..fef5956c3 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -45,6 +45,14 @@ class modp_ inline_mpn_zero(x + M, L - M); } + template + modp_(const gfp_& other) : + modp_() + { + assert(M <= L); + inline_mpn_copyi(x, other.get().get(), M); + } + const mp_limb_t* get() const { return x; } void assign(const void* buffer, int t) { memcpy(x, buffer, t * sizeof(mp_limb_t)); } diff --git a/Math/modp.hpp b/Math/modp.hpp index 94cc4b398..2ebc1e716 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -11,7 +11,15 @@ template void modp_::randomize(PRNG& G, const Zp_Data& ZpD) { - G.randomBnd(x, ZpD.get_prA(), ZpD.pr_byte_length, ZpD.overhang_mask()); + const int M = sizeof(mp_limb_t) * L; + switch (ZpD.pr_byte_length) + { +#define X(LL) case LL: G.randomBnd(x, ZpD.get_prA(), ZpD.overhang_mask()); break; + X(M) X(M-1) X(M-2) X(M-3) X(M-4) X(M-5) X(M-6) X(M-7) +#undef X + default: + G.randomBnd(x, ZpD.get_prA(), ZpD.pr_byte_length, ZpD.overhang_mask()); + } } template diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index d6ca875d6..6f947cc65 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -155,25 +155,28 @@ void CryptoPlayer::send_receive_all_no_stats(const vector>& channel } } -void CryptoPlayer::partial_broadcast(const vector& senders, +void CryptoPlayer::partial_broadcast(const vector& my_senders, + const vector& my_receivers, vector& os) const { TimeScope ts(comm_stats["Partial broadcasting"].add(os[my_num()])); - sent += os[my_num()].get_length() * (num_players() - 1); for (int offset = 1; offset < num_players(); offset++) { int other = get_player(offset); - bool receive = senders[other]; - if (senders[my_num()]) + bool receive = my_senders[other]; + if (my_receivers[other]) + { this->senders[other]->request(os[my_num()]); + sent += os[my_num()].get_length(); + } if (receive) this->receivers[other]->request(os[other]); } for (int offset = 1; offset < num_players(); offset++) { int other = get_player(offset); - bool receive = senders[other]; - if (senders[my_num()]) + bool receive = my_senders[other]; + if (my_receivers[other]) this->senders[other]->wait(os[my_num()]); if (receive) this->receivers[other]->wait(os[other]); diff --git a/Networking/CryptoPlayer.h b/Networking/CryptoPlayer.h index c1ca89ead..ec488a38a 100644 --- a/Networking/CryptoPlayer.h +++ b/Networking/CryptoPlayer.h @@ -42,8 +42,8 @@ class CryptoPlayer : public MultiPlayer const vector& to_send, vector& to_receive) const; - void partial_broadcast(const vector& senders, - vector& os) const; + void partial_broadcast(const vector& my_senders, + const vector& my_receivers, vector& os) const; void Broadcast_Receive_no_stats(vector& os) const; }; diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 67faf8b38..e92d5219e 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -522,12 +522,24 @@ void Player::send_receive_all(const vector>& channels, size_t data = 0; for (int i = 0; i < num_players(); i++) if (i != my_num() and channels.at(my_num()).at(i)) - data += to_send.at(i).get_length(); + { + data += to_send.at(i).get_length(); +#ifdef VERBOSE_COMM + cerr << "Send " << to_send.at(i).get_length() << " to " << i << endl; +#endif + } TimeScope ts(comm_stats["Sending/receiving"].add(data)); sent += data; send_receive_all_no_stats(channels, to_send, to_receive); } +void Player::partial_broadcast(const vector& senders, + vector& os) const +{ + partial_broadcast(senders, vector(num_players(), senders[my_num()]), + os); +} + template void MultiPlayer::send_receive_all_no_stats( const vector>& channels, const vector& to_send, diff --git a/Networking/Player.h b/Networking/Player.h index 99fa42322..9e9326582 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -218,7 +218,9 @@ class Player : public PlayerBase const vector& to_send, vector& to_receive) const = 0; - virtual void partial_broadcast(const vector&, + virtual void partial_broadcast(const vector& senders, + vector& os) const; + virtual void partial_broadcast(const vector&, const vector&, vector& os) const { unchecked_broadcast(os); } // dummy functions for compatibility diff --git a/Networking/sockets.cpp b/Networking/sockets.cpp index ab3a194aa..311fd527f 100644 --- a/Networking/sockets.cpp +++ b/Networking/sockets.cpp @@ -113,7 +113,8 @@ void set_up_client_socket(int& mysocket,const char* hostname,int Portnum) if (fl < 0) { - cout << attempts << " attempts" << endl; + cout << attempts << " attempts to " << hostname << ":" << Portnum + << endl; error("set_up_socket:connect:", hostname); } diff --git a/OT/NPartyTripleGenerator.hpp b/OT/NPartyTripleGenerator.hpp index 08c2c7b65..59d303cea 100644 --- a/OT/NPartyTripleGenerator.hpp +++ b/OT/NPartyTripleGenerator.hpp @@ -261,7 +261,7 @@ void NPartyTripleGenerator::generateInputs(int player) mac_sum = (ot_multipliers[i_thread])->input_macs[j]; } inputs[j] = {{share, mac_sum}, secrets[j]}; - auto r = G.get(); + auto r = G.get(); check_sum += typename W::input_check_type(r * share, r * mac_sum); } inputs.resize(nTriplesPerLoop); diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 770776dd8..8daaff250 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -60,6 +60,7 @@ class DataPositions map, long long> edabits; DataPositions(int num_players = 0); + DataPositions(const Player& P) : DataPositions(P.num_players()) {} ~DataPositions(); void reset(); diff --git a/Processor/FixInput.cpp b/Processor/FixInput.cpp index 6820697ec..e72115568 100644 --- a/Processor/FixInput.cpp +++ b/Processor/FixInput.cpp @@ -16,7 +16,13 @@ void FixInput_::read(std::istream& in, const int* params) template<> void FixInput_::read(std::istream& in, const int* params) { +#ifdef HIGH_PREC_INPUT mpf_class x; in >> x; items[0] = x << *params; +#else + double x; + in >> x; + items[0] = x * (1 << *params); +#endif } diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index d5cc119e9..c566233ee 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -209,7 +209,7 @@ void Machine::fill_buffers(int thread_number, int tape_number, dynamic_cast&>(tinfo[thread_number].processor->share_thread.DataF); for (int i = 0; i < DIV_CEIL(usage.files[DATA_GF2][DATA_TRIPLE], bit_type::default_length); i++) - dest.push_triple(source.get_triple(bit_type::default_length)); + dest.push_triple(source.get_triple_no_count(bit_type::default_length)); } catch (bad_cast& e) { diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 4ade74df4..28f0d8aac 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -194,7 +194,9 @@ void OnlineMachine::start_networking() if (ipFileName.size() > 0) { if (my_port != Names::DEFAULT_PORT) throw runtime_error("cannot set port number when using IP file"); - playerNames.init(playerno, pnbase, ipFileName); + if (nplayers == 0 and opt.isSet("-N")) + opt.get("-N")->getInt(nplayers); + playerNames.init(playerno, pnbase, ipFileName, nplayers); } else { if (not opt.get("-ext-server")->isSet) { diff --git a/Processor/PrivateOutput.h b/Processor/PrivateOutput.h index a0ac2a50a..498f35fc1 100644 --- a/Processor/PrivateOutput.h +++ b/Processor/PrivateOutput.h @@ -24,6 +24,9 @@ class PrivateOutput void start(int player, int target, int source); void stop(int player, int dest, int source); + + T start(int player, const T& source); + typename T::clear stop(int player, const typename T::clear& masked); }; #endif /* PROCESSOR_PRIVATEOUTPUT_H_ */ diff --git a/Processor/PrivateOutput.hpp b/Processor/PrivateOutput.hpp index 400d1c449..977e7e15d 100644 --- a/Processor/PrivateOutput.hpp +++ b/Processor/PrivateOutput.hpp @@ -8,24 +8,42 @@ template void PrivateOutput::start(int player, int target, int source) +{ + proc.get_S_ref(target) = start(player, proc.get_S_ref(source)); +} + +template +T PrivateOutput::start(int player, const T& source) { assert (player < proc.P.num_players()); open_type mask; - proc.DataF.get_input(proc.get_S_ref(target), mask, player); - proc.get_S_ref(target) += proc.get_S_ref(source); + T res; + proc.DataF.get_input(res, mask, player); + res += source; if (player == proc.P.my_num()) masks.push_back(mask); + + return res; } template void PrivateOutput::stop(int player, int dest, int source) { - if (player == proc.P.my_num() and proc.Proc) - { - auto& value = proc.get_C_ref(dest); - value = (proc.get_C_ref(source) - masks.front()); + auto& value = proc.get_C_ref(dest); + value = stop(player, proc.get_C_ref(source)); + if (proc.Proc) value.output(proc.Proc->private_output, false); +} + +template +typename T::clear PrivateOutput::stop(int player, const typename T::clear& source) +{ + typename T::clear value; + if (player == proc.P.my_num()) + { + value = source - masks.front(); masks.pop_front(); } + return value; } diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index d36fdab79..a139338e5 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -39,6 +39,8 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, template SubProcessor::~SubProcessor() { + protocol.check(); + for (size_t i = 0; i < personal_bit_preps.size(); i++) { auto& x = personal_bit_preps[i]; @@ -225,8 +227,9 @@ void Processor::split(const Instruction& instruction) assert(unit == 64); int n_inputs = instruction.get_size(); int n_bits = instruction.get_start().size() / n; + assert(share_thread.protocol != 0); sint::split(Procb.S, instruction.get_start(), n_bits, - &read_Sp(instruction.get_r(0)), n_inputs, P); + &read_Sp(instruction.get_r(0)), n_inputs, *share_thread.protocol); } diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index c4d9944f6..b40219d14 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -85,13 +85,14 @@ if '2dense' in program.args: program.disable_memory_warnings() -layers[-1].Y.input_from(0) -layers[0].X.input_from(0) - Y = sint.Matrix(n_test, 10) X = sfix.Matrix(n_test, n_features) -Y.input_from(0) -X.input_from(0) + +if not ('no_acc' in program.args and 'no_loss' in program.args): + layers[-1].Y.input_from(0) + layers[0].X.input_from(0) + Y.input_from(0) + X.input_from(0) if 'always_acc' in program.args: n_part_epochs = 1 diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index 4b64b0ceb..71bc96fce 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -35,6 +35,8 @@ class Beaver : public ProtocolBase Beaver(Player& P) : prep(0), MC(0), P(P) {} + Player& branch(); + void init_mul(SubProcessor* proc); void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); typename T::clear prepare_mul(const T& x, const T& y, int n = -1); diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 20a2aad0b..d3a7c0815 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -12,6 +12,11 @@ #include +template +Player& Beaver::branch() +{ + return P; +} template void Beaver::init_mul(SubProcessor* proc) diff --git a/Protocols/BrainPrep.h b/Protocols/BrainPrep.h index 77565811d..ded5b0d53 100644 --- a/Protocols/BrainPrep.h +++ b/Protocols/BrainPrep.h @@ -20,6 +20,7 @@ class BrainPrep : public MaliciousRingPrep BrainPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage) { } diff --git a/Protocols/ChaiGearPrep.h b/Protocols/ChaiGearPrep.h index 12281e74a..21e18eeb6 100644 --- a/Protocols/ChaiGearPrep.h +++ b/Protocols/ChaiGearPrep.h @@ -24,6 +24,11 @@ class ChaiGearPrep : public MaliciousRingPrep Generator& get_generator(); + template + void buffer_bits(true_type); + template + void buffer_bits(false_type); + public: static void basic_setup(Player& P); static void key_setup(Player& P, mac_key_type alphai); @@ -32,6 +37,7 @@ class ChaiGearPrep : public MaliciousRingPrep ChaiGearPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), generator(0), square_producer(0), input_producer(0) { diff --git a/Protocols/ChaiGearPrep.hpp b/Protocols/ChaiGearPrep.hpp index 988b3cdbb..a829a2e58 100644 --- a/Protocols/ChaiGearPrep.hpp +++ b/Protocols/ChaiGearPrep.hpp @@ -10,6 +10,8 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/Producer.h" +#include "FHEOffline/DataSetup.hpp" + template MultiplicativeMachine* ChaiGearPrep::machine = 0; template @@ -65,7 +67,8 @@ void ChaiGearPrep::key_setup(Player& P, mac_key_type alphai) assert(machine); auto& setup = machine->setup.part(); auto& options = CowGearOptions::singleton; - setup.covert_secrets_generation(P, *machine, options.covert_security); + read_or_generate_secrets(setup, P, *machine, options.covert_security, + T::covert); // adjust mac key mac_key_type diff = alphai - setup.alphai; @@ -169,14 +172,22 @@ void ChaiGearPrep::buffer_inputs(int player) template inline void ChaiGearPrep::buffer_bits() +{ + buffer_bits<0>(T::clear::characteristic_two); +} + +template +template +void ChaiGearPrep::buffer_bits(false_type) { buffer_bits_from_squares(*this); } -template<> -inline void ChaiGearPrep>::buffer_bits() +template +template +void ChaiGearPrep::buffer_bits(true_type) { - buffer_bits_without_check(); + this->buffer_bits_without_check(); assert(not this->bits.empty()); for (auto& bit : this->bits) bit.force_to_bit(); diff --git a/Protocols/ChaiGearShare.h b/Protocols/ChaiGearShare.h index 1178bee9d..865a1aeb3 100644 --- a/Protocols/ChaiGearShare.h +++ b/Protocols/ChaiGearShare.h @@ -26,6 +26,8 @@ class ChaiGearShare : public Share const static bool needs_ot = false; + const static true_type covert; + ChaiGearShare() { } @@ -37,5 +39,7 @@ class ChaiGearShare : public Share } }; +template +const true_type ChaiGearShare::covert; #endif /* PROTOCOLS_CHAIGEARSHARE_H_ */ diff --git a/Protocols/CowGearOptions.cpp b/Protocols/CowGearOptions.cpp index 2060679bd..c1fbcaab4 100644 --- a/Protocols/CowGearOptions.cpp +++ b/Protocols/CowGearOptions.cpp @@ -12,38 +12,44 @@ using namespace std; CowGearOptions CowGearOptions::singleton; -CowGearOptions::CowGearOptions() +CowGearOptions::CowGearOptions(bool covert) { - covert_security = 20; - lowgear_from_covert(); - use_top_gear = false; -} + if (covert) + { + covert_security = 20; + } + else + { + covert_security = -1; + } -void CowGearOptions::lowgear_from_covert() -{ - lowgear_security = ceil(log2(covert_security)); + lowgear_security = 40; + use_top_gear = false; } CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, - const char** argv) : CowGearOptions() + const char** argv, bool covert) : CowGearOptions(covert) { + if (covert) + { + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + ("Covert security parameter c. " + "Cheating will be detected with probability 1/c (default: " + + to_string(covert_security) + ")").c_str(), // Help description. + "-c", // Flag token. + "--covert-security" // Flag token. + ); + } opt.add( "", // Default. 0, // Required? 1, // Number of args expected. 0, // Delimiter if expecting multiple args. - ("Covert security parameter c. " - "Cheating will be detected with probability 1/c (default: " - + to_string(covert_security) + ")").c_str(), // Help description. - "-c", // Flag token. - "--covert-security" // Flag token. - ); - opt.add( - "", // Default. - 0, // Required? - 1, // Number of args expected. - 0, // Delimiter if expecting multiple args. - "LowGear security parameter (default: ceil(log2(c))", // Help description. + "LowGear security parameter (default: 40)", // Help description. "-l", // Flag token. "--lowgear-security" // Flag token. ); @@ -71,8 +77,6 @@ CowGearOptions::CowGearOptions(ez::ezOptionParser& opt, int argc, if (covert_security > (1LL << lowgear_security)) insecure(", LowGear security less than key generation security"); } - else - lowgear_from_covert(); use_top_gear = opt.isSet("-T"); opt.resetArgs(); } diff --git a/Protocols/CowGearOptions.h b/Protocols/CowGearOptions.h index 2c8de8ea0..f79bd5212 100644 --- a/Protocols/CowGearOptions.h +++ b/Protocols/CowGearOptions.h @@ -12,21 +12,25 @@ class CowGearOptions { bool use_top_gear; - void lowgear_from_covert(); - public: static CowGearOptions singleton; int covert_security; int lowgear_security; - CowGearOptions(); - CowGearOptions(ez::ezOptionParser& opt, int argc, const char** argv); + CowGearOptions(bool covert = true); + CowGearOptions(ez::ezOptionParser& opt, int argc, const char** argv, + bool covert = true); bool top_gear() { return use_top_gear; } + + void set_top_gear(bool use) + { + use_top_gear = use; + } }; #endif /* PROTOCOLS_COWGEAROPTIONS_H_ */ diff --git a/Protocols/CowGearPrep.h b/Protocols/CowGearPrep.h index a62614247..93c973489 100644 --- a/Protocols/CowGearPrep.h +++ b/Protocols/CowGearPrep.h @@ -24,6 +24,11 @@ class CowGearPrep : public MaliciousRingPrep PairwiseGenerator& get_generator(); + template + void buffer_bits(true_type); + template + void buffer_bits(false_type); + public: static void basic_setup(Player& P); static void key_setup(Player& P, mac_key_type alphai); @@ -33,6 +38,7 @@ class CowGearPrep : public MaliciousRingPrep CowGearPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), pairwise_generator(0) { diff --git a/Protocols/CowGearPrep.hpp b/Protocols/CowGearPrep.hpp index d296f6ca7..866ec9d43 100644 --- a/Protocols/CowGearPrep.hpp +++ b/Protocols/CowGearPrep.hpp @@ -8,6 +8,7 @@ #include "Tools/Bundle.h" #include "Protocols/ReplicatedPrep.hpp" +#include "FHEOffline/DataSetup.hpp" template PairwiseMachine* CowGearPrep::pairwise_machine = 0; @@ -39,8 +40,9 @@ void CowGearPrep::basic_setup(Player& P) auto& setup = machine.setup(); auto& options = CowGearOptions::singleton; #ifdef VERBOSE - cerr << "Covert security parameter for key and MAC generation: " - << options.covert_security << endl; + if (T::covert) + cerr << "Covert security parameter for key and MAC generation: " + << options.covert_security << endl; cerr << "LowGear security parameter: " << options.lowgear_security << endl; #endif setup.secure_init(P, machine, T::clear::length(), options.lowgear_security); @@ -66,8 +68,8 @@ void CowGearPrep::key_setup(Player& P, mac_key_type alphai) auto& machine = *pairwise_machine; auto& setup = machine.setup(); auto& options = CowGearOptions::singleton; - read_or_generate_covert_secrets(setup, P, machine, - options.covert_security); + read_or_generate_secrets(setup, P, machine, + options.covert_security, T::covert); // adjust mac key mac_key_type diff = alphai - setup.alphai; @@ -149,14 +151,22 @@ void CowGearPrep::buffer_inputs(int player) template inline void CowGearPrep::buffer_bits() +{ + buffer_bits<0>(T::clear::characteristic_two); +} + +template +template +void CowGearPrep::buffer_bits(false_type) { buffer_bits_from_squares(*this); } -template<> -inline void CowGearPrep>::buffer_bits() +template +template +void CowGearPrep::buffer_bits(true_type) { - buffer_bits_without_check(); + this->buffer_bits_without_check(); assert(not this->bits.empty()); for (auto& bit : this->bits) bit.force_to_bit(); diff --git a/Protocols/CowGearShare.h b/Protocols/CowGearShare.h index 4f810184b..66dd59ad8 100644 --- a/Protocols/CowGearShare.h +++ b/Protocols/CowGearShare.h @@ -26,6 +26,8 @@ class CowGearShare : public Share const static bool needs_ot = false; + const static true_type covert; + CowGearShare() { } @@ -37,4 +39,7 @@ class CowGearShare : public Share }; +template +const true_type CowGearShare::covert; + #endif /* PROTOCOLS_COWGEARSHARE_H_ */ diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index cce26b222..83379c7c4 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -47,6 +47,11 @@ class FakeProtocol : public ProtocolBase } #endif + FakeProtocol branch() + { + return P; + } + void init_mul(SubProcessor*) { results.clear(); diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index e5e4fd8b9..36b0b4d1b 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -68,8 +68,9 @@ class FakeShare : public T, public ShareInterface *this = a - b; } - static void split(vector& dest, const vector& regs, int n_bits, - const This* source, int n_inputs, Player& P); + static void split(vector& dest, const vector& regs, + int n_bits, const This* source, int n_inputs, + GC::FakeSecret::Protocol& protocol); }; #endif /* PROTOCOLS_FAKESHARE_H_ */ diff --git a/Protocols/FakeShare.hpp b/Protocols/FakeShare.hpp index 29e49db3e..7994f5104 100644 --- a/Protocols/FakeShare.hpp +++ b/Protocols/FakeShare.hpp @@ -10,7 +10,7 @@ template void FakeShare::split(vector& dest, const vector& regs, int n_bits, const This* source, int n_inputs, - Player&) + GC::FakeSecret::Protocol&) { assert(n_bits <= 64); int unit = GC::Clear::N_BITS; diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp new file mode 100644 index 000000000..c3465634e --- /dev/null +++ b/Protocols/HighGearKeyGen.cpp @@ -0,0 +1,39 @@ +/* + * KeyGen.cpp + * + */ + +#include "FHEOffline/DataSetup.h" +#include "Processor/OnlineOptions.h" + +#include "Protocols/HighGearKeyGen.hpp" + +template<> +void PartSetup::key_and_mac_generation(Player& P, + MachineBase& machine, int, false_type) +{ + auto& batch_size = OnlineOptions::singleton.batch_size; + auto backup = batch_size; + batch_size = 100; + bool done = false; + int n_limbs[2]; + for (int i = 0; i < 2; i++) + n_limbs[i] = params.FFTD()[i].get_prD().get_t(); +#define X(L, M) \ + if (n_limbs[0] == L and n_limbs[1] == M) \ + { \ + HighGearKeyGen(P, params).run(*this, machine); \ + done = true; \ + } + X(5, 3) X(4, 3) X(3, 2) + if (not done) + throw runtime_error("not compiled for choice of parameters"); + batch_size = backup; +} + +template<> +void PartSetup::key_and_mac_generation(Player& P, MachineBase& machine, + int, false_type) +{ + HighGearKeyGen<2, 2>(P, params).run(*this, machine); +} diff --git a/Protocols/HighGearKeyGen.h b/Protocols/HighGearKeyGen.h new file mode 100644 index 000000000..6f9a7b665 --- /dev/null +++ b/Protocols/HighGearKeyGen.h @@ -0,0 +1,71 @@ +/* + * HighGearKeyGen.h + * + */ + +#ifndef PROTOCOLS_HIGHGEARKEYGEN_H_ +#define PROTOCOLS_HIGHGEARKEYGEN_H_ + +#include "LowGearKeyGen.h" + +#include +using namespace std; + +template +class KeyGenBitFactory +{ + U& keygen; + deque& buffer; + +public: + KeyGenBitFactory(U& keygen, deque& buffer) : + keygen(keygen), buffer(buffer) + { + } + + T get_bit() + { + if (buffer.empty()) + keygen.buffer_mabits(); + auto res = buffer.front(); + buffer.pop_front(); + return res; + } +}; + +template +class HighGearKeyGen +{ +public: + typedef KeyGenProtocol<5, L> Proto0; + typedef KeyGenProtocol<7, M> Proto1; + + typedef typename Proto0::share_type share_type0; + typedef typename Proto1::share_type share_type1; + + typedef typename share_type0::open_type open_type0; + typedef typename share_type1::open_type open_type1; + + typedef ShareVector vector_type0; + typedef ShareVector vector_type1; + + typedef typename share_type0::bit_type BT; + + Player& P; + const FHE_Params& params; + + Proto0 proto0; + Proto1 proto1; + + deque bits0; + deque bits1; + + HighGearKeyGen(Player& P, const FHE_Params& params); + + void buffer_mabits(); + + template + void run(PartSetup& setup, MachineBase& machine); +}; + +#endif /* PROTOCOLS_HIGHGEARKEYGEN_H_ */ diff --git a/Protocols/HighGearKeyGen.hpp b/Protocols/HighGearKeyGen.hpp new file mode 100644 index 000000000..9645a3319 --- /dev/null +++ b/Protocols/HighGearKeyGen.hpp @@ -0,0 +1,159 @@ +/* + * HighGearKeyGen.cpp + * + */ + +#include "HighGearKeyGen.h" +#include "FHE/Rq_Element.h" + +#include "LowGearKeyGen.hpp" + +template +HighGearKeyGen::HighGearKeyGen(Player& P, const FHE_Params& params) : + P(P), params(params), proto0(P, params, 0), proto1(P, params, 1) +{ +} + +template +void HighGearKeyGen::buffer_mabits() +{ + vector diffs; + vector open_diffs; + vector my_bits0; + vector my_bits1; + int batch_size = 1000; + auto& bmc = *GC::ShareThread::s().MC; + for (int i = 0; i < batch_size; i++) + { + share_type0 a0; + share_type1 a1; + BT b0, b1; + proto0.prep->get_dabit(a0, b0); + proto1.prep->get_dabit(a1, b1); + my_bits0.push_back(a0); + my_bits1.push_back(a1); + diffs.push_back(b0 + b1); + } + bmc.POpen(open_diffs, diffs, P); + bmc.Check(P); + for (int i = 0; i < batch_size; i++) + { + bits0.push_back(my_bits0[i]); + bits1.push_back( + my_bits1[i] + + share_type1::constant(open_diffs.at(i), P.my_num(), + proto1.MC->get_alphai()) + - my_bits1[i] * open_diffs.at(i) * 2); + } +} + +template +template +void HighGearKeyGen::run(PartSetup& setup, MachineBase& machine) +{ + RunningTimer timer; + + GlobalPRNG global_prng(P); + auto& fftd = params.FFTD(); + + AddableVector a0(params.phi_m()), a0_prime(params.phi_m()); + AddableVector a1(params.phi_m()), a1_prime(params.phi_m()); + a0.randomize(global_prng); + a1.randomize(global_prng); + a0_prime.randomize(global_prng); + a1_prime.randomize(global_prng); + + KeyGenBitFactory> factory0(*this, bits0); + KeyGenBitFactory> factory1(*this, bits1); + + vector_type0 sk0; + vector_type1 sk1; + proto0.secret_key(sk0, factory0); + proto1.secret_key(sk1, factory1); + + vector_type0 e0, e0_prime; + vector_type1 e1, e1_prime; + proto0.binomial(e0, factory0); + proto0.binomial(e0_prime, factory0); + proto1.binomial(e1, factory1); + proto1.binomial(e1_prime, factory1); + + auto f0 = sk0; + auto f0_prime = proto0.schur_product(f0, f0); + + Rq_Element a(Ring_Element(fftd[0], evaluation, a0), + Ring_Element(fftd[1], evaluation, a1)); + Rq_Element Sw_a(Ring_Element(fftd[0], evaluation, a0_prime), + Ring_Element(fftd[1], evaluation, a1_prime)); + + bigint p = setup.FieldD.get_prime(); + bigint p1 = fftd[1].get_prime(); + vector b0, b0_prime; + vector b1, b1_prime; + proto0.MC->POpen(b0, sk0 * a0 + e0 * p, P); + proto1.MC->POpen(b1, sk1 * a1 + e1 * p, P); + proto0.MC->POpen(b0_prime, sk0 * a0_prime + e0_prime * p - f0_prime * p1, P); + proto1.MC->POpen(b1_prime, sk1 * a1_prime + e1_prime * p, P); + + Rq_Element b(Ring_Element(fftd[0], evaluation, b0), + Ring_Element(fftd[1], evaluation, b1)); + Rq_Element Sw_b(Ring_Element(fftd[0], evaluation, b0_prime), + Ring_Element(fftd[1], evaluation, b1_prime)); + + setup.pk.assign(a, b, Sw_a, Sw_b); + + vector s0_shares; + vector s1_shares; + for (int i = 0; i < params.phi_m(); i++) + { + s0_shares.push_back(sk0.at(i).get_share()); + s1_shares.push_back(sk1.at(i).get_share()); + } + setup.sk.assign({Ring_Element(fftd[0], evaluation, s0_shares), + Ring_Element(fftd[1], evaluation, s1_shares)}); + + GC::ShareThread::s().MC->Check(P); + +#ifdef DEBUG_HIGHGEAR_KEYGEN + proto0.MC->POpen(s0_shares, sk0, P); + proto1.MC->POpen(s1_shares, sk1, P); + + vector e0_open, e0_prime_open; + vector e1_open, e1_prime_open; + proto0.MC->POpen(e0_open, e0, P); + proto0.MC->POpen(e0_prime_open, e0_prime, P); + proto1.MC->POpen(e1_open, e1, P); + proto1.MC->POpen(e1_prime_open, e1_prime, P); + + Rq_Element s(fftd, s0_shares, s1_shares); + Rq_Element e(fftd, e0_open, e1_open); + assert(b == s * a + e * p); + + Rq_Element e_prime(fftd, e0_prime_open, e1_prime_open); + assert(Sw_b == s * Sw_a + e_prime * p - s * s * p1); + + cerr << "Revealed secret key for check" << endl; +#endif + + cerr << "Key generation took " << timer.elapsed() << " seconds" << endl; + timer.reset(); + + map timers; + SimpleEncCommit_ EC(P, setup.pk, setup.FieldD, timers, machine, 0, true); + Plaintext_ alpha(setup.FieldD); + EC.next(alpha, setup.calpha); + assert(alpha.is_diagonal()); + + setup.alphai = alpha.element(0); + + cerr << "MAC key generation took " << timer.elapsed() << " seconds" << endl; + +#ifdef DEBUG_HIGHGEAR_KEYGEN + auto d = SemiMC>().open(setup.sk, P).decrypt( + setup.calpha, setup.FieldD); + auto dd = SemiMC>().open(setup.alphai, P); + for (unsigned i = 0; i < d.num_slots(); i++) + assert(d.element(i) == dd); + cerr << "Revealed MAC key for check" << endl; +#endif +} diff --git a/Protocols/HighGearShare.h b/Protocols/HighGearShare.h new file mode 100644 index 000000000..faa7fa267 --- /dev/null +++ b/Protocols/HighGearShare.h @@ -0,0 +1,41 @@ +/* + * HighGearShare.h + * + */ + +#ifndef PROTOCOLS_HIGHGEARSHARE_H_ +#define PROTOCOLS_HIGHGEARSHARE_H_ + +#include "ChaiGearShare.h" + +template +class HighGearShare : public ChaiGearShare +{ + typedef HighGearShare This; + typedef ChaiGearShare super; + +public: + typedef MAC_Check_ MAC_Check; + typedef Direct_MAC_Check Direct_MC; + typedef ::Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef ChaiGearPrep LivePrep; + + const static false_type covert; + + HighGearShare() + { + } + + template + HighGearShare(const U& other) : + super(other) + { + } +}; + +template +const false_type HighGearShare::covert; + +#endif /* PROTOCOLS_HIGHGEARSHARE_H_ */ diff --git a/Protocols/LowGearKeyGen.cpp b/Protocols/LowGearKeyGen.cpp new file mode 100644 index 000000000..ffa42eb79 --- /dev/null +++ b/Protocols/LowGearKeyGen.cpp @@ -0,0 +1,34 @@ +/* + * LowGearKeyGen.cpp + * + */ + +#include "FHEOffline/DataSetup.h" +#include "Processor/OnlineOptions.h" + +#include "Protocols/LowGearKeyGen.hpp" + +template<> +void PairwiseSetup::key_and_mac_generation(Player& P, + PairwiseMachine& machine, int, false_type) +{ + int n_limbs = params.FFTD()[0].get_prD().get_t(); + switch (n_limbs) + { +#define X(L) case L: LowGearKeyGen(P, machine, params).run(*this); break; + X(3) X(4) X(5) X(6) +#undef X + default: + throw runtime_error( + "not compiled for choice of parameters, add X(" + + to_string(n_limbs) + ") at " + __FILE__ + ":" + + to_string(__LINE__ - 5)); + } +} + +template<> +void PairwiseSetup::key_and_mac_generation(Player& P, + PairwiseMachine& machine, int, false_type) +{ + LowGearKeyGen<2>(P, machine, params).run(*this); +} diff --git a/Protocols/LowGearKeyGen.h b/Protocols/LowGearKeyGen.h new file mode 100644 index 000000000..396b68e1c --- /dev/null +++ b/Protocols/LowGearKeyGen.h @@ -0,0 +1,78 @@ +/* + * LowGearKeyGen.h + * + */ + +#ifndef PROTOCOLS_LOWGEARKEYGEN_H_ +#define PROTOCOLS_LOWGEARKEYGEN_H_ + +#include "ShareVector.h" +#include "FHEOffline/PairwiseMachine.h" +#include "Protocols/MascotPrep.h" +#include "Processor/Processor.h" +#include "GC/TinierSecret.h" +#include "Math/gfp.h" + +template +class KeyGenProtocol +{ +public: + typedef Share> share_type; + typedef typename share_type::open_type open_type; + typedef ShareVector vector_type; + +protected: + Player& P; + const FHE_Params& params; + const FFT_Data& fftd; + + SeededPRNG G; + DataPositions usage; + + share_type get_bit() + { + return prep->get_bit(); + } + +public: + Preprocessing* prep; + MAC_Check_* MC; + SubProcessor* proc; + + KeyGenProtocol(Player& P, const FHE_Params& params, int level = 0); + ~KeyGenProtocol(); + + void input(vector& shares, const Rq_Element& secret); + template + void binomial(vector_type& shares, T& prep); + template + void hamming(vector_type& shares, T& prep); + template + void secret_key(vector_type& shares, T& prep); + vector_type schur_product(const vector_type& x, const vector_type& y); + void output_to(int player, vector& opened, + vector& shares); +}; + +template +class LowGearKeyGen : public KeyGenProtocol<5, L> +{ + typedef KeyGenProtocol<5, L> super; + + typedef typename super::share_type share_type; + typedef typename super::open_type open_type; + typedef typename super::vector_type vector_type; + + Player& P; + PairwiseMachine& machine; + + void generate_keys(FHE_Params& params); + +public: + LowGearKeyGen(Player& P, PairwiseMachine& machine, FHE_Params& params); + + template + void run(PairwiseSetup& setup); +}; + +#endif /* PROTOCOLS_LOWGEARKEYGEN_H_ */ diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp new file mode 100644 index 000000000..e820e89c4 --- /dev/null +++ b/Protocols/LowGearKeyGen.hpp @@ -0,0 +1,284 @@ +/* + * LowGearKeyGen.cpp + * + */ + +#include "LowGearKeyGen.h" +#include "FHE/Rq_Element.h" + +#include "Tools/benchmarking.h" + +#include "Machines/SPDZ.hpp" +#include "ShareVector.hpp" + +template +LowGearKeyGen::LowGearKeyGen(Player& P, PairwiseMachine& machine, + FHE_Params& params) : + KeyGenProtocol<5, L>(P, params), P(P), machine(machine) +{ +} + +template +KeyGenProtocol::KeyGenProtocol(Player& P, const FHE_Params& params, + int level) : + P(P), params(params), fftd(params.FFTD().at(level)), usage(P) +{ + open_type::init_field(params.FFTD().at(level).get_prD().pr); + typename share_type::mac_key_type alphai; + + if (OnlineOptions::singleton.live_prep) + { + prep = new MascotDabitOnlyPrep(0, usage); + alphai.randomize(G); + } + else + { + prep = new Sub_Data_Files(P.N, + get_prep_sub_dir(P.num_players()), usage); + read_mac_key(get_prep_sub_dir(P.num_players()), P.N, alphai); + } + + MC = new MAC_Check_(alphai); + proc = new SubProcessor(*MC, *prep, P); +} + +template +KeyGenProtocol::~KeyGenProtocol() +{ + MC->Check(P); + + usage.print_cost(); + + delete proc; + delete prep; + delete MC; +} + +template +void KeyGenProtocol::input(vector& shares, const Rq_Element& secret) +{ + assert(secret.level() == 0); + auto s = secret.get(0); + s.change_rep(evaluation); + auto& FFTD = s.get_FFTD(); + auto& inputter = this->proc->input; + inputter.reset_all(P); + for (int i = 0; i < FFTD.num_slots(); i++) + inputter.add_from_all(s.get_element(i)); + inputter.exchange(); + shares.clear(); + shares.resize(P.num_players()); + for (int i = 0; i < FFTD.num_slots(); i++) + for (int j = 0; j < P.num_players(); j++) + shares[j].push_back(inputter.finalize(j)); +} + +template +template +void KeyGenProtocol::binomial(vector_type& shares, T& prep) +{ + shares.resize(params.phi_m()); + RunningTimer timer, total; + for (int i = 0; i < params.phi_m(); i++) + { +#ifdef VERBOSE + if (timer.elapsed() > 10) + { + cerr << i << "/" << params.phi_m() << ", throughput " << + i / total.elapsed() << endl; + timer.reset(); + } +#endif + + auto& share = shares[i]; + share = {}; + for (int i = 0; i < params.get_DG().get_NewHopeB(); i++) + { + share += prep.get_bit(); + share -= prep.get_bit(); + } + } + shares.fft(fftd); +} + +template +template +void KeyGenProtocol::hamming(vector_type& shares, T& prep) +{ + shares.resize(params.phi_m()); + int h = params.get_h(); + assert(h > 0); + assert(shares.size() / h * h == shares.size()); + int n_bits = log(shares.size() / h) / log(2); +// assert(size_t(h << n_bits) == shares.size()); + + for (auto& share : shares) + share = prep.get_bit(); + + auto& protocol = proc->protocol; + for (int i = 0; i < n_bits - 1; i++) + { + protocol.init_mul(proc); + for (auto& share : shares) + protocol.prepare_mul(share, prep.get_bit()); + protocol.exchange(); + for (auto& share : shares) + share = protocol.finalize_mul(); + } + + protocol.init_mul(proc); + auto one = share_type::constant(1, P.my_num(), MC->get_alphai()); + for (auto& share : shares) + protocol.prepare_mul(share, prep.get_bit() * 2 - one); + protocol.exchange(); + for (auto& share : shares) + share = protocol.finalize_mul(); + + shares.fft(fftd); +} + +template +template +void KeyGenProtocol::secret_key(vector_type& shares, T& prep) +{ + assert(params.get_h() != 0); + cerr << "Generate secret key by "; + if (params.get_h() > 0) + { + cerr << "Hamming weight" << endl; + hamming(shares, prep); + } + else + { + cerr << "binomial" << endl; + binomial(shares, prep); + } +} + +template +typename KeyGenProtocol::vector_type KeyGenProtocol::schur_product( + const vector_type& x, const vector_type& y) +{ + vector_type res; + assert(x.size() == y.size()); + auto& protocol = proc->protocol; + protocol.init_mul(proc); + for (size_t i = 0; i < x.size(); i++) + protocol.prepare_mul(x[i], y[i]); + protocol.exchange(); + for (size_t i = 0; i < x.size(); i++) + res.push_back(protocol.finalize_mul()); + return res; +} + +template +void KeyGenProtocol::output_to(int player, vector& opened, + vector& shares) +{ + PrivateOutput po(*proc); + vector masked; + for (auto& share : shares) + masked.push_back(po.start(player, share)); + MC->POpen(opened, masked, P); + for (auto& x : opened) + x = po.stop(player, x); +} + +template +void LowGearKeyGen::generate_keys(FHE_Params& params) +{ + RunningTimer timer; + auto& pk = machine.pk; + + GlobalPRNG global_prng(P); + auto& FFTD = pk.get_params().FFTD()[0]; + + for (int i = 0; i < P.num_players(); i++) + { + vector_type sk; + this->secret_key(sk, *this); + vector open_sk; + this->output_to(i, open_sk, sk); + if (P.my_num() == i) + machine.sk.assign(Ring_Element(FFTD, evaluation, open_sk)); + vector_type e0; + this->binomial(e0, *this); + AddableVector a0(pk.get_params().phi_m()); + a0.randomize(global_prng); + vector b0; + assert(machine.sk.p() != 0); + this->MC->POpen(b0, sk * a0 + e0 * machine.sk.p(), P); + machine.other_pks[i] = FHE_PK(params, machine.sk.p()); + machine.other_pks[i].assign(Ring_Element(FFTD, evaluation, a0), + Ring_Element(FFTD, evaluation, b0)); + } + + this->MC->Check(P); + + cerr << "Key generation took " << timer.elapsed() << " seconds" << endl; +} + +template +template +void LowGearKeyGen::run(PairwiseSetup& setup) +{ + generate_keys(setup.params); + machine.sk.check(machine.pk, setup.FieldD); + + RunningTimer timer; + + auto mac_key = SeededPRNG().get(); + + PairwiseGenerator generator(0, machine, &P); + map timers; + MultiEncCommit EC(P, machine.other_pks, setup.FieldD, + timers, machine, generator, true); + assert(EC.proof.get_diagonal()); + vector> m(EC.proof.U, setup.FieldD); + for (auto& mm : m) + mm.assign_constant(mac_key); + + AddableVector C; + octetStream ciphertexts, cleartexts; + EC.generate_proof(C, m, ciphertexts, cleartexts); + + AddableVector others_ciphertexts; + others_ciphertexts.resize(EC.proof.U, machine.pk.get_params()); + Verifier verifier(EC.proof, setup.FieldD); + verifier.NIZKPoK(others_ciphertexts, ciphertexts, + cleartexts, machine.pk, false); + + machine.enc_alphas.clear(); + for (int i = 0; i < P.num_players(); i++) + machine.enc_alphas.push_back(machine.other_pks[i]); + + for (int i = 1; i < P.num_players(); i++) + { + int player = P.get_player(-i); +#ifdef VERBOSE_HE + cerr << "Sending proof with " << 1e-9 * ciphertexts.get_length() << "+" + << 1e-9 * cleartexts.get_length() << " GB" << endl; +#endif + timers["Sending"].start(); + P.pass_around(ciphertexts); + P.pass_around(cleartexts); + timers["Sending"].stop(); +#ifdef VERBOSE_HE + cerr << "Checking proof of player " << i << endl; +#endif + timers["Verifying"].start(); + verifier.NIZKPoK(others_ciphertexts, ciphertexts, + cleartexts, machine.other_pks[player], false); + timers["Verifying"].stop(); + machine.enc_alphas.at(player) = others_ciphertexts.at(0); + } + + setup.set_alphai(mac_key); + machine.enc_alphas.at(P.my_num()) = C.at(0); + + auto test = machine.sk.decrypt(C[0], setup.FieldD); + for (int i = 0; i < setup.FieldD.num_slots(); i++) + assert(test.element(i) == mac_key); + + cerr << "MAC key generation took " << timer.elapsed() << " seconds" << endl; +} diff --git a/Protocols/LowGearShare.h b/Protocols/LowGearShare.h new file mode 100644 index 000000000..afb2b7dac --- /dev/null +++ b/Protocols/LowGearShare.h @@ -0,0 +1,41 @@ +/* + * LowGearShare.h + * + */ + +#ifndef PROTOCOLS_LOWGEARSHARE_H_ +#define PROTOCOLS_LOWGEARSHARE_H_ + +#include "CowGearShare.h" + +template +class LowGearShare : public CowGearShare +{ + typedef LowGearShare This; + typedef CowGearShare super; + +public: + typedef MAC_Check_ MAC_Check; + typedef Direct_MAC_Check Direct_MC; + typedef ::Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + typedef CowGearPrep LivePrep; + + const static false_type covert; + + LowGearShare() + { + } + + template + LowGearShare(const U& other) : + super(other) + { + } +}; + +template +const false_type LowGearShare::covert; + +#endif /* PROTOCOLS_LOWGEARSHARE_H_ */ diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index 4a784cdd5..2dc93cbcb 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -40,6 +40,7 @@ MalRepRingPrepWithBits::MalRepRingPrepWithBits(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), MalRepRingPrep(proc, usage), RingOnlyBitsFromSquaresPrep(proc, usage), SimplerMalRepRingPrep(proc, usage) diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index f2bf9d984..e968ae0b0 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -32,6 +32,7 @@ MaliciousRepPrepWithBits::MaliciousRepPrepWithBits(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), MaliciousRepPrep(proc, usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage) { } @@ -173,6 +174,7 @@ void MaliciousBitOnlyRepPrep::buffer_bits() auto buffer_size = this->buffer_size; assert(honest_proc); Player& P = honest_proc->P; + honest_prep.buffer_size = buffer_size; bits.clear(); for (int i = 0; i < buffer_size; i++) { diff --git a/Protocols/MaliciousRingPrep.hpp b/Protocols/MaliciousRingPrep.hpp index 4caf78fe1..ba45184ff 100644 --- a/Protocols/MaliciousRingPrep.hpp +++ b/Protocols/MaliciousRingPrep.hpp @@ -12,21 +12,21 @@ #include "Spdz2kPrep.hpp" template -void MaliciousRingPrep::buffer_dabits(ThreadQueues* queues) +void MaliciousDabitOnlyPrep::buffer_dabits(ThreadQueues* queues) { buffer_dabits<0>(queues, T::clear::characteristic_two); } template template -void MaliciousRingPrep::buffer_dabits(ThreadQueues*, true_type) +void MaliciousDabitOnlyPrep::buffer_dabits(ThreadQueues*, true_type) { throw runtime_error("only implemented for integer-like domains"); } template template -void MaliciousRingPrep::buffer_dabits(ThreadQueues* queues, false_type) +void MaliciousDabitOnlyPrep::buffer_dabits(ThreadQueues* queues, false_type) { assert(this->proc != 0); vector> check_dabits; diff --git a/Protocols/MaliciousShamirMC.hpp b/Protocols/MaliciousShamirMC.hpp index 0309852e8..41c6f208a 100644 --- a/Protocols/MaliciousShamirMC.hpp +++ b/Protocols/MaliciousShamirMC.hpp @@ -24,7 +24,8 @@ void MaliciousShamirMC::init_open(const Player& P, int n) reconstructions[i].resize(i); for (int j = 0; j < i; j++) reconstructions[i][j] = - Shamir::get_rec_factor(j, i); + Shamir::get_rec_factor(P.get_player(j), + P.num_players(), P.my_num(), i); } } @@ -37,7 +38,7 @@ typename T::open_type MaliciousShamirMC::finalize_open() int threshold = ShamirMachine::s().threshold; shares.resize(2 * threshold + 1); for (size_t j = 0; j < shares.size(); j++) - shares[j].unpack((*this->os)[j]); + shares[j].unpack((*this->os)[this->player->get_player(j)]); return reconstruct(shares); } diff --git a/Protocols/MaliciousShamirPO.hpp b/Protocols/MaliciousShamirPO.hpp index 6c6f13bec..1e867d176 100644 --- a/Protocols/MaliciousShamirPO.hpp +++ b/Protocols/MaliciousShamirPO.hpp @@ -9,6 +9,7 @@ template MaliciousShamirPO::MaliciousShamirPO(Player& P) : P(P), shares(P.num_players()) { + MC.init_open(P); } template diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index ff91836a4..ef61ec7b9 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -10,7 +10,9 @@ template MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), OTPrep(proc, usage), + RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), + OTPrep(proc, usage), MaliciousRingPrep(proc, usage) { this->params.amplify = true; diff --git a/Protocols/MamaShare.h b/Protocols/MamaShare.h index b875bd42c..fa3bc9f03 100644 --- a/Protocols/MamaShare.h +++ b/Protocols/MamaShare.h @@ -58,6 +58,7 @@ class MamaShare : public Share_, MamaMac> typedef MascotMultiplier Multiplier; typedef FixedVec sacri_type; typedef This input_type; + typedef This input_check_type; typedef MamaRectangle Square; typedef typename T::Square Rectangle; @@ -95,6 +96,12 @@ class MamaShare : public Share_, MamaMac> super(other.get_share(), other.get_mac()) { } + + template + MamaShare(const U& share, const V& mac) : + super(share, mac) + { + } }; #endif /* PROTOCOLS_MAMASHARE_H_ */ diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 50d08735f..3972b91a7 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -27,50 +27,88 @@ class OTPrep : public virtual BitPrep }; template -class MascotTriplePrep : public OTPrep +class MascotInputPrep : public OTPrep { + void buffer_inputs(int player); + public: - MascotTriplePrep(SubProcessor *proc, DataPositions &usage) : + MascotInputPrep(SubProcessor *proc, DataPositions &usage) : BufferPrep(usage), BitPrep(proc, usage), OTPrep(proc, usage) { } +}; + +template +class MascotTriplePrep : public MascotInputPrep +{ +public: + MascotTriplePrep(SubProcessor *proc, DataPositions &usage) : + BufferPrep(usage), BitPrep(proc, usage), + MascotInputPrep(proc, usage) + { + } void buffer_triples(); - void buffer_inputs(int player); }; template -class MascotPrep: public virtual MaliciousRingPrep, +class MascotDabitOnlyPrep : public virtual MaliciousDabitOnlyPrep, public virtual MascotTriplePrep { +public: + MascotDabitOnlyPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), + MascotTriplePrep(proc, usage) + { + } + virtual ~MascotDabitOnlyPrep() + { + } + + virtual void buffer_bits(); +}; + +template +class MascotPrep : public virtual MaliciousRingPrep, + public virtual MascotDabitOnlyPrep +{ public: MascotPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), - MascotTriplePrep(proc, usage) + MascotTriplePrep(proc, usage), + MascotDabitOnlyPrep(proc, usage) { } virtual ~MascotPrep() { } - void buffer_bits() { throw runtime_error("use subclass"); } + virtual void buffer_bits() + { + MascotDabitOnlyPrep::buffer_bits(); + } + void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues); }; template -class MascotFieldPrep : public MascotPrep +class MascotFieldPrep : public virtual MascotPrep { - void buffer_bits(); - public: MascotFieldPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), - MascotTriplePrep(proc, usage), MascotPrep(proc, usage) + MascotTriplePrep(proc, usage), + MascotDabitOnlyPrep(proc, usage), + MascotPrep(proc, usage) { } }; diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index 50d786405..317469f95 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -36,6 +36,10 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) BitPrep::set_protocol(protocol); SubProcessor* proc = this->proc; assert(proc != 0); + + // make sure not to use Montgomery multiplication + T::open_type::next::template init(false); + triple_generator = new typename T::TripleGenerator( BaseMachine::s().fresh_ot_setup(), proc->P.N, -1, @@ -66,7 +70,7 @@ void MascotTriplePrep::buffer_triples() } template -void MascotFieldPrep::buffer_bits() +void MascotDabitOnlyPrep::buffer_bits() { this->params.generateBits = true; auto& triple_generator = this->triple_generator; @@ -78,7 +82,7 @@ void MascotFieldPrep::buffer_bits() } template -void MascotTriplePrep::buffer_inputs(int player) +void MascotInputPrep::buffer_inputs(int player) { auto& triple_generator = this->triple_generator; assert(triple_generator); diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 9a90f7720..70371744d 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -56,13 +56,14 @@ class PostSacriRepRingShare : public Rep3Share2 } template - static void split(vector& dest, const vector& regs, - int n_bits, const super* source, int n_inputs, Player& P) + static void split(vector& dest, const vector& regs, int n_bits, + const super* source, int n_inputs, + typename bit_type::Protocol& protocol) { if (regs.size() / n_bits != 3) throw runtime_error("only secure with three-way split"); - super::split(dest, regs, n_bits, source, n_inputs, P); + super::split(dest, regs, n_bits, source, n_inputs, protocol); } }; diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index 0ce53c7d4..c9ed65b6f 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -25,6 +25,8 @@ class PostSacrifice : public ProtocolBase PostSacrifice(Player& P); ~PostSacrifice(); + Player& branch(); + void init_mul(SubProcessor* proc); typename T::clear prepare_mul(const T& x, const T& y, int n = -1); void exchange() { internal.exchange(); } diff --git a/Protocols/PostSacrifice.hpp b/Protocols/PostSacrifice.hpp index c141ef185..4db3b73b4 100644 --- a/Protocols/PostSacrifice.hpp +++ b/Protocols/PostSacrifice.hpp @@ -18,6 +18,12 @@ PostSacrifice::~PostSacrifice() check(); } +template +Player& PostSacrifice::branch() +{ + return P; +} + template void PostSacrifice::init_mul(SubProcessor* proc) { diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 7c2bc8666..7cdf66235 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -41,9 +41,11 @@ class Rep3Share2 : public Rep3Share> } template - static void split(vector& dest, const vector& regs, - int n_bits, const Rep3Share2* source, int n_inputs, Player& P) + static void split(vector& dest, const vector& regs, int n_bits, + const Rep3Share2* source, int n_inputs, + typename U::Protocol& protocol) { + auto& P = protocol.P; int my_num = P.my_num(); assert(n_bits <= 64); int unit = GC::Clear::N_BITS; diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index f9bfcd5a2..f2dbaf7a6 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -14,6 +14,7 @@ class Rep4 : public ProtocolBase friend class Rep4RingPrep; typedef typename T::open_type open_type; + typedef array, 3> prngs_type; octetStreams send_os; octetStreams receive_os; @@ -21,6 +22,7 @@ class Rep4 : public ProtocolBase array, 4> send_hashes, receive_hashes; array, 5> add_shares; + array dotprod_shares; vector bit_lengths; class ResTuple @@ -56,10 +58,14 @@ class Rep4 : public ProtocolBase false_type); public: - array, 3> rep_prngs; + prngs_type rep_prngs; Player& P; Rep4(Player& P); + Rep4(Player& P, prngs_type& prngs); + ~Rep4(); + + Rep4 branch(); void init_mul(SubProcessor* proc = 0); void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); @@ -78,6 +84,10 @@ class Rep4 : public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc); + template + void split(vector& dest, const vector& regs, int n_bits, + const U* source, int n_inputs); + int get_n_relevant_players() { return 2; } }; diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index dc4781f4c..e77b4e6f5 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -24,6 +24,40 @@ Rep4::Rep4(Player& P) : rep_prngs[i].SetSeed(to_receive[P.get_player(i)].get_data()); } +template +Rep4::Rep4(Player& P, prngs_type& prngs) : + my_num(P.my_num()), P(P) +{ + for (int i = 0; i < 3; i++) + rep_prngs[i].SetSeed(prngs[i]); +} + +template +Rep4::~Rep4() +{ + for (auto& x : receive_hashes) + for (auto& y : x) + if (y.size > 0) + { + check(); + return; + } + + for (auto& x : send_hashes) + for (auto& y : x) + if (y.size > 0) + { + check(); + return; + } +} + +template +Rep4 Rep4::branch() +{ + return {P, rep_prngs}; +} + template void Rep4::init_mul(SubProcessor*) { @@ -184,7 +218,7 @@ template void Rep4::init_dotprod(SubProcessor*) { init_mul(); - next_dotprod(); + dotprod_shares = {}; } template @@ -192,15 +226,16 @@ void Rep4::prepare_dotprod(const T& x, const T& y) { auto a = get_addshares(x, y); for (int i = 0; i < 5; i++) - add_shares[i].back() += a[i]; + dotprod_shares[i] += a[i]; } template void Rep4::next_dotprod() { - for (auto& a : add_shares) - a.push_back({}); + for (int i = 0; i < 5; i++) + add_shares[i].push_back(dotprod_shares[i]); bit_lengths.push_back(-1); + dotprod_shares = {}; } template @@ -234,7 +269,6 @@ T Rep4::finalize_mul(int) template T Rep4::finalize_dotprod(int) { - this->counter++; return finalize_mul(); } @@ -297,6 +331,7 @@ void Rep4::trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type) { assert(regs.size() % 4 == 0); + this->trunc_pr_counter += size * regs.size() / 4; typedef typename T::open_type open_type; vector> infos; @@ -419,3 +454,76 @@ void Rep4::trunc_pr(const vector& regs, int size, } } } + +template +template +void Rep4::split(vector& dest, const vector& regs, int n_bits, + const U* source, int n_inputs) +{ + assert(regs.size() / n_bits == 2); + assert(n_bits <= 64); + int unit = GC::Clear::N_BITS; + int my_num = P.my_num(); + int i0 = -1; + + switch (my_num) + { + case 0: + i0 = 1; + break; + case 1: + i0 = 0; + break; + case 2: + i0 = 1; + break; + case 3: + i0 = 0; + break; + } + + vector to_share; + init_mul(); + + for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) + { + int start = k * unit; + int m = min(unit, n_inputs - start); + + square64 square; + + for (int j = 0; j < m; j++) + { + auto& input_share = source[j + start]; + auto input_value = input_share[i0] + input_share[i0 + 1]; + square.rows[j] = Integer(input_value).get(); + } + + square.transpose(m, n_bits); + + for (int j = 0; j < n_bits; j++) + { + to_share.push_back(square.rows[j]); + bit_lengths.push_back(m); + } + } + + array, 2> results; + for (auto& x : results) + x.resize(to_share.size()); + prepare_joint_input(0, 1, 3, 2, to_share, results[0]); + prepare_joint_input(2, 3, 1, 0, to_share, results[1]); + P.send_receive_all(channels, send_os, receive_os); + finalize_joint_input(0, 1, 3, 2, results[0]); + finalize_joint_input(2, 3, 1, 0, results[1]); + + for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) + { + for (int j = 0; j < n_bits; j++) + for (int i = 0; i < 2; i++) + { + auto res = results[i].next().res; + dest.at(regs.at(2 * j + i) + k) = res; + } + } +} diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp index 48b843e0c..17915e43d 100644 --- a/Protocols/Rep4Prep.hpp +++ b/Protocols/Rep4Prep.hpp @@ -11,7 +11,9 @@ template Rep4RingPrep::Rep4RingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), MaliciousRingPrep(proc, usage) + RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), + MaliciousRingPrep(proc, usage) { } @@ -19,7 +21,9 @@ template Rep4RingOnlyPrep::Rep4RingOnlyPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage), Rep4RingPrep(proc, usage), + RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), + Rep4RingPrep(proc, usage), RepRingOnlyEdabitPrep(proc, usage) { } diff --git a/Protocols/Rep4Share2k.h b/Protocols/Rep4Share2k.h index 9e27c3ea6..a394b1ae7 100644 --- a/Protocols/Rep4Share2k.h +++ b/Protocols/Rep4Share2k.h @@ -41,8 +41,19 @@ class Rep4Share2 : public Rep4Share> } template - static void split(vector& dest, const vector& regs, - int n_bits, const Rep4Share2* source, int n_inputs, Player& P) + static void split(vector& dest, const vector& regs, int n_bits, + const Rep4Share2* source, int n_inputs, Rep4& protocol) + { + int n_split = regs.size() / n_bits; + if (n_split == 2) + protocol.split(dest, regs, n_bits, source, n_inputs); + else + split(dest, regs, n_bits, source, n_inputs, protocol.P); + } + + template + static void split(vector& dest, const vector& regs, int n_bits, + const Rep4Share2* source, int n_inputs, Player& P) { int my_num = P.my_num(); assert(n_bits <= 64); diff --git a/Protocols/RepRingOnlyEdabitPrep.hpp b/Protocols/RepRingOnlyEdabitPrep.hpp index 78721e0d9..c2650c27b 100644 --- a/Protocols/RepRingOnlyEdabitPrep.hpp +++ b/Protocols/RepRingOnlyEdabitPrep.hpp @@ -27,7 +27,8 @@ void RepRingOnlyEdabitPrep::buffer_edabits(int n_bits, ThreadQueues*) regs[i] = i * buffer_size / dl; typedef typename T::bit_type bt; vector bits(n_bits * P.num_players() * buffer_size); - T::split(bits, regs, n_bits, wholes.data(), wholes.size(), this->proc->P); + T::split(bits, regs, n_bits, wholes.data(), wholes.size(), + *GC::ShareThread < bt > ::s().protocol); BitAdder bit_adder; vector>> summands; diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 846a269b0..83729d793 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -49,6 +49,8 @@ class ProtocolBase protected: vector random; + int trunc_pr_counter; + public: typedef T share_type; diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index abc8c82a3..3f604afdb 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -20,7 +20,8 @@ #include "Math/Z2k.hpp" template -ProtocolBase::ProtocolBase() : counter(0) +ProtocolBase::ProtocolBase() : + trunc_pr_counter(0), counter(0) { } @@ -68,6 +69,8 @@ ProtocolBase::~ProtocolBase() #ifdef VERBOSE_COUNT if (counter) cerr << "Number of " << T::type_string() << " multiplications: " << counter << endl; + if (trunc_pr_counter) + cerr << "Number of probabilistic truncations: " << trunc_pr_counter << endl; #endif } diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index e49864985..357677182 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -126,6 +126,7 @@ class BitPrep : public virtual BufferPrep public: BitPrep(SubProcessor* proc, DataPositions& usage); + ~BitPrep(); void set_protocol(typename T::Protocol& protocol); @@ -224,7 +225,26 @@ class SemiHonestRingPrep : public virtual RingPrep }; template -class MaliciousRingPrep : public virtual RingPrep +class MaliciousDabitOnlyPrep : public virtual RingPrep +{ + template + void buffer_dabits(ThreadQueues* queues, true_type); + template + void buffer_dabits(ThreadQueues* queues, false_type); + +public: + MaliciousDabitOnlyPrep(SubProcessor* proc, DataPositions& usage) : + BufferPrep(usage), BitPrep(proc, usage), + RingPrep(proc, usage) + { + } + virtual ~MaliciousDabitOnlyPrep() {} + + virtual void buffer_dabits(ThreadQueues* queues); +}; + +template +class MaliciousRingPrep : public virtual MaliciousDabitOnlyPrep { typedef typename T::bit_type::part_type BT; @@ -246,11 +266,6 @@ class MaliciousRingPrep : public virtual RingPrep void buffer_personal_dabits(int input_player, true_type); void buffer_personal_dabits(int input_player, false_type); - template - void buffer_dabits(ThreadQueues* queues, true_type); - template - void buffer_dabits(ThreadQueues* queues, false_type); - public: static void edabit_sacrifice_buckets(vector>& to_check, size_t n_bits, bool strict, int player, SubProcessor& proc, int begin, int end, @@ -262,13 +277,12 @@ class MaliciousRingPrep : public virtual RingPrep MaliciousRingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), - RingPrep(proc, usage) + RingPrep(proc, usage), MaliciousDabitOnlyPrep(proc, usage) { } virtual ~MaliciousRingPrep() {} virtual void buffer_bits(); - virtual void buffer_dabits(ThreadQueues* queues); virtual void buffer_edabits(bool strict, int n_bits, ThreadQueues* queues); }; diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index e373aa6b9..f08e10eae 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -54,9 +54,12 @@ BufferPrep::~BufferPrep() cerr << n_bit_rounds << " rounds of random " << type_string << " bit generation" << endl; + this->print_left("triples", triples.size() * T::default_length, + type_string); + #define X(KIND) \ this->print_left(#KIND, KIND.size(), type_string); - X(triples) X(squares) X(inverses) X(bits) X(dabits) + X(squares) X(inverses) X(bits) X(dabits) #undef X for (auto& x : this->edabits) @@ -83,12 +86,22 @@ RingPrep::RingPrep(SubProcessor* proc, DataPositions& usage) : template void BitPrep::set_protocol(typename T::Protocol& protocol) { - this->protocol = &protocol; + this->protocol = new typename T::Protocol(protocol.branch()); auto proc = this->proc; if (proc and proc->Proc) this->base_player = proc->Proc->thread_num; } +template +BitPrep::~BitPrep() +{ + if (protocol) + { + protocol->check(); + delete protocol; + } +} + template void BufferPrep::clear() { @@ -323,8 +336,7 @@ template void BitPrep::buffer_bits_without_check() { SeededPRNG G; - buffer_ring_bits_without_check(this->bits, G, - OnlineOptions::singleton.batch_size); + buffer_ring_bits_without_check(this->bits, G, this->buffer_size); } template diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index bbb48534f..20ecd2682 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -51,9 +51,11 @@ class Semi2kShare : public SemiShare> } template - static void split(vector& dest, const vector& regs, - int n_bits, const Semi2kShare* source, int n_inputs, Player& P) + static void split(vector& dest, const vector& regs, int n_bits, + const Semi2kShare* source, int n_inputs, + typename U::Protocol& protocol) { + auto& P = protocol.P; int my_num = P.my_num(); assert(n_bits <= 64); int unit = GC::Clear::N_BITS; diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 96b77173c..595957c18 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -27,8 +27,24 @@ template class OTTripleGenerator; namespace GC { class SemiSecret; +class NoValue; } +template +class BasicSemiShare : public T +{ +public: + typedef T open_type; + typedef T clear; + + typedef GC::NoValue mac_key_type; + + template + BasicSemiShare(const U& other) : T(other) + { + } +}; + template class SemiShare : public T, public ShareInterface { diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index b435d687d..e3a6e2881 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -46,6 +46,7 @@ class Shamir : public ProtocolBase Player& P; static U get_rec_factor(int i, int n); + static U get_rec_factor(int i, int n_total, int start, int threshold); Shamir(Player& P); ~Shamir(); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index b35d8151c..0127be502 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -13,11 +13,21 @@ template typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n) +{ + return get_rec_factor(i, n, 0, n); +} + +template +typename T::open_type::Scalar Shamir::get_rec_factor(int i, int n_total, + int start, int n_points) { U res = 1; - for (int j = 0; j < n; j++) - if (i != j) - res *= U(j + 1) / (U(j + 1) - U(i + 1)); + for (int j = 0; j < n_points; j++) + { + int other = positive_modulo(start + j, n_total); + if (i != other) + res *= U(other + 1) / (U(other + 1) - U(i + 1)); + } return res; } diff --git a/Protocols/ShamirMC.h b/Protocols/ShamirMC.h index 694c1467a..8ac858a17 100644 --- a/Protocols/ShamirMC.h +++ b/Protocols/ShamirMC.h @@ -16,18 +16,17 @@ class ShamirMC : public MAC_Check_Base { vector reconstruction; - bool send; - void finalize(vector& values, const vector& S); protected: Bundle* os; + const Player* player; int threshold; void prepare(const vector& S, const Player& P); public: - ShamirMC() : send(false), os(0), threshold(ShamirMachine::s().threshold) {} + ShamirMC() : os(0), player(0), threshold(ShamirMachine::s().threshold) {} // emulate MAC_Check ShamirMC(const typename T::mac_key_type& _, int __ = 0, int ___ = 0) : ShamirMC() diff --git a/Protocols/ShamirMC.hpp b/Protocols/ShamirMC.hpp index 3b483542d..4e1d241d2 100644 --- a/Protocols/ShamirMC.hpp +++ b/Protocols/ShamirMC.hpp @@ -32,8 +32,8 @@ void ShamirMC::init_open(const Player& P, int n) { reconstruction.resize(n_relevant_players); for (int i = 0; i < n_relevant_players; i++) - reconstruction[i] = Shamir::get_rec_factor(i, - n_relevant_players); + reconstruction[i] = Shamir::get_rec_factor(P.get_player(i), + P.num_players(), P.my_num(), n_relevant_players); } if (not os) @@ -41,9 +41,8 @@ void ShamirMC::init_open(const Player& P, int n) for (auto& o : *os) o.reset_write_head(); - send = P.my_num() <= threshold; - if (send) - os->mine.reserve(n * T::size()); + os->mine.reserve(n * T::size()); + this->player = &P; } template @@ -57,8 +56,7 @@ void ShamirMC::prepare(const vector& S, const Player& P) template void ShamirMC::prepare_open(const T& share) { - if (send) - share.pack(os->mine); + share.pack(os->mine); } template @@ -73,10 +71,13 @@ void ShamirMC::POpen(vector& values, const vector& template void ShamirMC::exchange(const Player& P) { - vector senders(P.num_players()); + vector my_senders(P.num_players()), my_receivers(P.num_players()); for (int i = 0; i < P.num_players(); i++) - senders[i] = i <= threshold; - P.partial_broadcast(senders, *os); + { + my_senders[i] = P.get_offset(i) <= threshold; + my_receivers[i] = P.get_offset(i) >= P.num_players() - threshold; + } + P.partial_broadcast(my_senders, my_receivers, *os); } template @@ -103,7 +104,9 @@ typename T::open_type ShamirMC::finalize_open() typename T::open_type res; for (size_t j = 0; j < reconstruction.size(); j++) { - res += (*os)[j].template get() * reconstruction[j]; + res += + (*os)[player->get_player(j)].template get() + * reconstruction[j]; } return res; diff --git a/Protocols/Share.h b/Protocols/Share.h index 21fc63b58..d47da0d6e 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -189,6 +189,28 @@ class Share : public Share_, SemiShare> super(share, mac) {} }; +template +class ArithmeticOnlyMascotShare : public Share +{ + typedef ArithmeticOnlyMascotShare This; + typedef Share super; + +public: + typedef GC::NoShare bit_type; + + typedef MAC_Check_ MAC_Check; + typedef ::Input Input; + typedef ::PrivateOutput PrivateOutput; + typedef SPDZ Protocol; + + ArithmeticOnlyMascotShare() {} + template + ArithmeticOnlyMascotShare(const U& other) : super(other) {} + ArithmeticOnlyMascotShare(const SemiShare& share, + const SemiShare& mac) : + super(share, mac) {} +}; + // specialized mul by bit for gf2n template <> void Share_, SemiShare>::mul_by_bit(const Share_, SemiShare>& S,const gf2n& aa); diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index 58bb7085b..a30d15d2e 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -1,3 +1,5 @@ +#ifndef PROTOCOLS_SHARE_H_ +#define PROTOCOLS_SHARE_H_ #include "Share.h" @@ -53,3 +55,5 @@ inline void Share_::unpack(octetStream& os, bool full) if (full) mac.unpack(os); } + +#endif diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index fcce6e0d6..08d5f2067 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -31,7 +31,8 @@ class ShareInterface static string type_short() { return "undef"; } template - static void split(vector, vector, int, T*, int, Player&) + static void split(vector, vector, int, T*, int, + typename U::Protocol&) { throw runtime_error("split not implemented"); } template diff --git a/Protocols/ShareVector.h b/Protocols/ShareVector.h new file mode 100644 index 000000000..41ccec882 --- /dev/null +++ b/Protocols/ShareVector.h @@ -0,0 +1,57 @@ +/* + * ShareVector.h + * + */ + +#ifndef PROTOCOLS_SHAREVECTOR_H_ +#define PROTOCOLS_SHAREVECTOR_H_ + +#include "FHE/AddableVector.h" +#include "Protocols/Share.h" +#include "Math/gfp.h" + +template +class ShareVector : public AddableVector +{ +public: + ShareVector operator+(const ShareVector& other) const + { + assert(this->size() == other.size()); + ShareVector res; + for (size_t i = 0; i < other.size(); i++) + res.push_back((*this)[i] + other[i]); + return res; + } + + ShareVector operator-(const ShareVector& other) const + { + assert(this->size() == other.size()); + ShareVector res; + for (size_t i = 0; i < other.size(); i++) + res.push_back((*this)[i] - other[i]); + return res; + } + + template + ShareVector operator*(const AddableVector& other) const + { + assert(this->size() == other.size()); + ShareVector res; + for (size_t i = 0; i < other.size(); i++) + res.push_back((*this)[i] * other[i]); + return res; + } + + template + ShareVector operator*(const T& other) const + { + ShareVector res; + for (size_t i = 0; i < this->size(); i++) + res.push_back((*this)[i] * other); + return res; + } + + void fft(const FFT_Data& fftd); +}; + +#endif /* PROTOCOLS_SHAREVECTOR_H_ */ diff --git a/Protocols/ShareVector.hpp b/Protocols/ShareVector.hpp new file mode 100644 index 000000000..b9a4a9e2d --- /dev/null +++ b/Protocols/ShareVector.hpp @@ -0,0 +1,31 @@ +/* + * ShareVector.cpp + * + */ + +#include "ShareVector.h" +#include "FHE/FFT.h" + +template +void ShareVector::fft(const FFT_Data& fftd) +{ + array, 2> data; + for (auto& share : *this) + { + data[0].push_back(share.get_share()); + data[1].push_back(share.get_mac()); + } + + for (auto& x : data) + { + if (fftd.get_twop() == 0) + FFT_Iter2(x, fftd.phi_m(), fftd.get_root(0), fftd.get_prD()); + else + FFT_non_power_of_two(x, x, fftd); + } + + for (int i = 0; i < fftd.phi_m(); i++) + { + (*this)[i] = {data[0][i], data[1][i]}; + } +} diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index e2b1a2e30..b1223df17 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -21,7 +21,7 @@ void SohoPrep::basic_setup(Player& P) setup = new PartSetup; MachineBase machine; setup->secure_init(P, machine, T::clear::length(), 0); - setup->covert_secrets_generation(P, machine, 1); + setup->key_and_mac_generation(P, machine, 1, true_type()); T::clear::template init(); } diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 3907512ee..5bcd3c6c0 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -14,7 +14,9 @@ template void bits_from_square_in_ring(vector& bits, int buffer_size, U* bit_prep); template -class Spdz2kPrep : public virtual MascotPrep, public virtual RingOnlyPrep +class Spdz2kPrep : public virtual MaliciousRingPrep, + public virtual MascotTriplePrep, + public virtual RingOnlyPrep { typedef Spdz2kShare BitShare; DataPositions bit_pos; diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index 4a2617b2e..7899d26f0 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -16,9 +16,9 @@ template Spdz2kPrep::Spdz2kPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage), MascotTriplePrep(proc, usage), - MascotPrep(proc, usage), RingOnlyPrep(proc, usage) { this->params.amplify = false; @@ -43,7 +43,7 @@ Spdz2kPrep::~Spdz2kPrep() template void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) { - MascotPrep::set_protocol(protocol); + OTPrep::set_protocol(protocol); assert(this->proc != 0); auto& proc = this->proc; bit_MC = new typename BitShare::MAC_Check(proc->MC.get_alphai()); @@ -240,7 +240,7 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, template size_t Spdz2kPrep::data_sent() { - size_t res = MascotPrep::data_sent(); + size_t res = OTPrep::data_sent(); if (bit_prep) res += bit_prep->data_sent(); return res; diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index ebbd0352a..cb049fceb 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -35,6 +35,8 @@ class SpdzWise : public ProtocolBase SpdzWise(Player& P); virtual ~SpdzWise(); + Player& branch(); + void init(SubProcessor* proc); void init_mul(SubProcessor* proc); diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 8378a5df1..40f3cee71 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -18,6 +18,12 @@ SpdzWise::~SpdzWise() check(); } +template +Player& SpdzWise::branch() +{ + return P; +} + template void SpdzWise::init(SubProcessor* proc) { diff --git a/Protocols/SpdzWisePrep.h b/Protocols/SpdzWisePrep.h index ad2c2cd89..35be8cb94 100644 --- a/Protocols/SpdzWisePrep.h +++ b/Protocols/SpdzWisePrep.h @@ -31,6 +31,7 @@ class SpdzWisePrep : public MaliciousRingPrep SpdzWisePrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), MaliciousRingPrep(proc, usage) { } diff --git a/Protocols/SpdzWiseRingPrep.h b/Protocols/SpdzWiseRingPrep.h index 3bef84259..c201c7a6b 100644 --- a/Protocols/SpdzWiseRingPrep.h +++ b/Protocols/SpdzWiseRingPrep.h @@ -40,6 +40,7 @@ class SpdzWiseRingPrep : public virtual SpdzWisePrep, SpdzWiseRingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), RingPrep(proc, usage), + MaliciousDabitOnlyPrep(proc, usage), SpdzWisePrep(proc, usage), RepRingOnlyEdabitPrep(proc, usage) { } diff --git a/Protocols/SpdzWiseRingShare.h b/Protocols/SpdzWiseRingShare.h index e49d5733b..bda30f86e 100644 --- a/Protocols/SpdzWiseRingShare.h +++ b/Protocols/SpdzWiseRingShare.h @@ -55,13 +55,14 @@ class SpdzWiseRingShare : public SpdzWiseShare>> } template - static void split(vector& dest, const vector& regs, - int n_bits, const SpdzWiseRingShare* source, int n_inputs, Player& P) + static void split(vector& dest, const vector& regs, int n_bits, + const SpdzWiseRingShare* source, int n_inputs, + typename U::Protocol& protocol) { vector> shares(n_inputs); for (int i = 0; i < n_inputs; i++) shares[i] = source[i].get_share(); - Rep3Share2::split(dest, regs, n_bits, shares.data(), n_inputs, P); + Rep3Share2::split(dest, regs, n_bits, shares.data(), n_inputs, protocol); } static void shrsi(SubProcessor& proc, const Instruction& inst) diff --git a/README.md b/README.md index 61e01b040..ec057af12 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ The following table lists all protocols that are fully supported. | Security model | Mod prime / GF(2^n) | Mod 2^k | Bin. SS | Garbling | | --- | --- | --- | --- | --- | -| Malicious, dishonest majority | [MASCOT](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | +| Malicious, dishonest majority | [MASCOT / LowGear / HighGear](#secret-sharing) | [SPDZ2k](#secret-sharing) | [Tiny / Tinier](#secret-sharing) | [BMR](#bmr) | | Covert, dishonest majority | [CowGear / ChaiGear](#secret-sharing) | N/A | N/A | N/A | | Semi-honest, dishonest majority | [Semi / Hemi / Soho](#secret-sharing) | [Semi2k](#secret-sharing) | [SemiBin](#secret-sharing) | [Yao's GC](#yaos-garbled-circuits) / [BMR](#bmr) | | Malicious, honest majority | [Shamir / Rep3 / PS / SY](#honest-majority) | [Brain / Rep[34] / PS / SY](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | @@ -399,6 +399,8 @@ The following table shows all programs for dishonest-majority computation using | `spdz2k-party.x` | [SPDZ2k](https://eprint.iacr.org/2018/482) | Mod 2^k | Malicious | `spdz2k.sh` | | `semi-party.x` | OT-based | Mod prime | Semi-honest | `semi.sh` | | `semi2k-party.x` | OT-based | Mod 2^k | Semi-honest | `semi2k.sh` | +| `lowgear-party.x` | [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Malicious | `lowgear.sh` | +| `highgear-party.x` | [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Malicious | `highgear.sh` | | `cowgear-party.x` | Adapted [LowGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `cowgear.sh` | | `chaigear-party.x` | Adapted [HighGear](https://eprint.iacr.org/2017/1230) | Mod prime | Covert | `chaigear.sh` | | `hemi-party.x` | Semi-homomorphic encryption | Mod prime | Semi-honest | `hemi.sh` | @@ -427,13 +429,13 @@ particular, the SPDZ2k sacrifice does not work for bits, so we replace it by cut-and-choose according to [Furukawa et al.](https://eprint.iacr.org/2016/944) -CowGear denotes a covertly secure version of LowGear. The reason for -this is the key generation that only achieves covert security. It is -possible however to run full LowGear for the offline phase by using -`-s` with the desired security parameter. The same holds for ChaiGear, -an adapted version of HighGear. Option `-T` activates -[TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs in -both. +The virtual machines for LowGear and HighGear run a key generation +similar to the one by [Rotaru et +al.](https://eprint.iacr.org/2019/1300). The main difference is using +daBits to generate maBits. CowGear and ChaiGear denote covertly +secure versions of LowGear and HighGear. In all relevant programs, +option `-T` activates [TopGear](https://eprint.iacr.org/2019/035) +zero-knowledge proofs in both. Hemi and Soho denote the stripped version version of LowGear and HighGear, respectively, for semi-honest diff --git a/Scripts/emulate.sh b/Scripts/emulate.sh index d38fc37f0..fe840c7d2 100755 --- a/Scripts/emulate.sh +++ b/Scripts/emulate.sh @@ -4,4 +4,4 @@ test -e logs || mkdir logs prog=${1%.sch} prog=${prog##*/} shift -./emulate.x $prog $* 2>&1 | tee -a logs/emulate-$prog +$prefix ./emulate.x $prog $* 2>&1 | tee -a logs/emulate-$prog diff --git a/Scripts/highgear.sh b/Scripts/highgear.sh new file mode 100755 index 000000000..5c0b5b091 --- /dev/null +++ b/Scripts/highgear.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player highgear-party.x $* || exit 1 diff --git a/Scripts/lowgear.sh b/Scripts/lowgear.sh new file mode 100755 index 000000000..9502a47cb --- /dev/null +++ b/Scripts/lowgear.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash + +HERE=$(cd `dirname $0`; pwd) +SPDZROOT=$HERE/.. + +. $HERE/run-common.sh + +run_player lowgear-party.x $* || exit 1 diff --git a/Scripts/test_tutorial.sh b/Scripts/test_tutorial.sh index 90d74b9b8..8ac1ef4d7 100755 --- a/Scripts/test_tutorial.sh +++ b/Scripts/test_tutorial.sh @@ -56,11 +56,13 @@ for dabit in ${dabit:-0 1 2}; do for i in rep-field shamir mal-rep-field ps-rep-field sy-rep-field \ mal-shamir sy-shamir hemi semi \ - soho cowgear mascot; do + soho mascot; do test_vm $i $run_opts done - test_vm chaigear $run_opts -l 3 -c 2 + for i in cowgear chaigear; do + test_vm $i $run_opts -l 3 -c 2 + done done if test $dabit != 0; then @@ -76,8 +78,9 @@ fi ./compile.py tutorial -test_vm cowgear $run_opts -T -test_vm chaigear $run_opts -T -l 3 -c 2 +for i in cowgear chaigear; do + test_vm $i $run_opts -l 3 -c 2 -T +done if test $skip_binary; then exit diff --git a/Tools/PointerVector.h b/Tools/PointerVector.h index 68f830f91..32d1b46ee 100644 --- a/Tools/PointerVector.h +++ b/Tools/PointerVector.h @@ -6,17 +6,17 @@ #ifndef TOOLS_POINTERVECTOR_H_ #define TOOLS_POINTERVECTOR_H_ -#include -using namespace std; +#include "CheckVector.h" template -class PointerVector : public vector +class PointerVector : public CheckVector { int i; public: - PointerVector(size_t size = 0) : vector(size), i(0) {} - PointerVector(const vector& other) : vector(other), i(0) {} + PointerVector() : i(0) {} + PointerVector(size_t size) : CheckVector(size), i(0) {} + PointerVector(const vector& other) : CheckVector(other), i(0) {} void clear() { vector::clear(); diff --git a/Tools/octetStream.h b/Tools/octetStream.h index 24de79aa3..e4ab090c1 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -148,6 +148,8 @@ class octetStream void store(const vector& v); template void get(vector& v, const T& init = {}); + template + void get_no_resize(vector& v); void consume(octetStream& s,size_t l) { s.resize(l); @@ -350,5 +352,16 @@ void octetStream::get(vector& v, const T& init) get(x); } +template +void octetStream::get_no_resize(vector& v) +{ + size_t size; + get(size); + if (size != v.size()) + throw runtime_error("wrong vector length"); + for (auto& x : v) + get(x); +} + #endif diff --git a/Tools/random.cpp b/Tools/random.cpp index 7acc91418..ff24c94fa 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -172,18 +172,6 @@ void PRNG::get_octetStream(octetStream& ans,int len) } -template -void PRNG::randomBnd(mp_limb_t* res, const mp_limb_t* B, mp_limb_t mask) -{ - size_t n_limbs = (N_BYTES + sizeof(mp_limb_t) - 1) / sizeof(mp_limb_t); - do - { - get_octets((octet*) res); - res[n_limbs - 1] &= mask; - } - while (mpn_cmp(res, B, n_limbs) >= 0); -} - void PRNG::randomBnd(mp_limb_t* res, const mp_limb_t* B, size_t n_bytes, mp_limb_t mask) { switch (n_bytes) diff --git a/Tools/random.h b/Tools/random.h index 33008dd94..67314db8e 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -203,4 +203,16 @@ inline void PRNG::get_octets(octet* ans) get_octets(ans, L); } +template +inline void PRNG::randomBnd(mp_limb_t* res, const mp_limb_t* B, mp_limb_t mask) +{ + size_t n_limbs = (N_BYTES + sizeof(mp_limb_t) - 1) / sizeof(mp_limb_t); + do + { + get_octets((octet*) res); + res[n_limbs - 1] &= mask; + } + while (mpn_cmp(res, B, n_limbs) >= 0); +} + #endif diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index f164055be..d406e6f22 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -480,6 +480,15 @@ int main(int argc, const char** argv) "-lgp", // Flag token. "--lgp" // Flag token. ); + opt.add( + "", // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Prime for GF(p) field (default: generated from -lgp argument)", // Help description. + "-P", // Flag token. + "--prime" // Flag token. + ); opt.add( to_string(gf2n::default_degree()).c_str(), // Default. 0, // Required? @@ -724,8 +733,18 @@ int FakeParams::generate() G.ReSeed(); prep_data_prefix = PREP_DIR; // Set up the fields - T::clear::template generate_setup(prep_data_prefix, nplayers, lgp); - T::clear::init_default(lgp); + if (opt.isSet("--prime")) + { + string p; + opt.get("--prime")->getString(p); + T::clear::init_field(p); + T::clear::template write_setup(nplayers); + } + else + { + T::clear::template generate_setup(prep_data_prefix, nplayers, lgp); + T::clear::init_default(lgp); + } /* Find number players and MAC keys etc*/ typename T::mac_type::Scalar keyp;