From 99c0549e7205f4a4550cff836abc417227193fa0 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Fri, 2 Jul 2021 15:49:23 +1000 Subject: [PATCH] Convolutional neural network training. --- BMR/network/Node.cpp | 2 +- BMR/network/Node.h | 2 - CHANGELOG.md | 12 + Compiler/GC/instructions.py | 4 +- Compiler/GC/types.py | 81 +- Compiler/allocator.py | 4 +- Compiler/circuit.py | 2 +- Compiler/comparison.py | 1 + Compiler/floatingpoint.py | 73 +- Compiler/graph.py | 11 +- Compiler/instructions.py | 60 +- Compiler/instructions_base.py | 51 +- Compiler/library.py | 104 ++- Compiler/ml.py | 751 +++++++++++++++---- Compiler/mpc_math.py | 51 ++ Compiler/non_linear.py | 26 +- Compiler/program.py | 42 +- Compiler/types.py | 275 ++++++- Compiler/util.py | 2 +- ECDSA/fake-spdz-ecdsa-party.cpp | 3 + ECDSA/mascot-ecdsa-party.cpp | 1 + ECDSA/ot-ecdsa-party.hpp | 2 + ECDSA/preprocessing.hpp | 13 - FHE/AddableVector.h | 5 - FHE/{AddableVector.cpp => AddableVector.hpp} | 5 +- FHE/Ciphertext.h | 4 - FHE/DiscreteGauss.cpp | 76 -- FHE/DiscreteGauss.h | 47 +- FHE/FFT.cpp | 69 +- FHE/FFT.h | 21 + FHE/FFT_Data.cpp | 31 +- FHE/FFT_Data.h | 12 +- FHE/FHE_Keys.cpp | 80 +- FHE/FHE_Params.cpp | 9 - FHE/FHE_Params.h | 19 +- FHE/NTL-Subs.cpp | 5 +- FHE/NoiseBounds.cpp | 17 +- FHE/NoiseBounds.h | 1 - FHE/PPData.cpp | 8 - FHE/PPData.h | 8 - FHE/Plaintext.cpp | 38 +- FHE/Plaintext.h | 14 +- FHE/Random_Coins.h | 2 - FHE/Ring_Element.cpp | 236 ++++-- FHE/Ring_Element.h | 61 +- FHE/Rq_Element.cpp | 23 - FHE/Rq_Element.h | 7 - FHEOffline/DataSetup.hpp | 3 +- FHEOffline/DistKeyGen.cpp | 4 +- FHEOffline/Multiplier.cpp | 2 +- FHEOffline/Multiplier.h | 2 + FHEOffline/Player-Offline.h | 32 - FHEOffline/Producer.cpp | 2 +- FHEOffline/Proof.cpp | 5 + FHEOffline/Prover.cpp | 12 +- FHEOffline/Prover.h | 6 +- FHEOffline/SimpleEncCommit.cpp | 16 +- FHEOffline/Verifier.cpp | 28 +- FHEOffline/Verifier.h | 4 +- GC/CcdPrep.h | 46 +- GC/CcdPrep.hpp | 33 + GC/CcdShare.h | 5 - GC/FakeSecret.h | 2 - GC/MaliciousCcdShare.h | 5 - GC/NoShare.h | 32 +- GC/PersonalPrep.hpp | 2 +- GC/Processor.h | 1 + GC/Processor.hpp | 21 +- GC/Rep4Secret.h | 1 - GC/SemiHonestRepPrep.cpp | 11 - GC/ShareParty.hpp | 4 +- GC/ShareSecret.h | 2 + GC/ShareSecret.hpp | 1 + GC/ShareThread.hpp | 3 +- GC/TinierPrep.h | 37 - GC/TinierSecret.h | 9 +- GC/TinierShare.h | 11 +- GC/TinierSharePrep.h | 12 +- GC/TinierSharePrep.hpp | 40 +- GC/TinyMC.cpp | 11 - GC/TinyMC.h | 2 +- GC/TinyPrep.h | 71 -- GC/TinyPrep.hpp | 182 +---- GC/TinySecret.cpp | 11 - GC/TinySecret.h | 18 +- GC/TinyShare.cpp | 11 - GC/TinyShare.h | 11 +- GC/VectorInput.h | 10 +- GC/VectorProtocol.h | 2 + GC/VectorProtocol.hpp | 7 +- GC/instructions.h | 2 +- Machines/SPDZ.hpp | 3 +- Machines/SPDZ2k.hpp | 4 +- Machines/ShamirMachine.hpp | 1 + Machines/ccd-party.cpp | 1 + Machines/cowgear-party.cpp | 1 - Machines/emulate.cpp | 42 +- Machines/malicious-ccd-party.cpp | 1 + Machines/no-party.cpp | 23 + Machines/sy-rep-field-party.cpp | 1 - Machines/tinier-party.cpp | 4 +- Machines/tiny-party.cpp | 3 + Makefile | 4 +- Math/BitVec.cpp | 7 - Math/FixedVec.h | 4 - Math/Integer.h | 1 + Math/Z2k.h | 17 +- Math/bigint.h | 12 + Math/field_types.h | 2 - Math/gf2n.h | 2 - Math/gf2nlong.h | 2 - Math/gfp.h | 7 +- Math/gfp.hpp | 11 +- Math/gfpvar.cpp | 1 + Math/modp.h | 4 + Math/modp.hpp | 24 + Math/mpn_fixed.h | 2 +- Networking/Player.cpp | 17 +- Networking/Player.h | 22 +- Processor/DataPositions.cpp | 2 +- Processor/Data_Files.h | 16 +- Processor/Data_Files.hpp | 20 +- Processor/DummyProtocol.h | 4 - Processor/ExternalClients.cpp | 2 + Processor/Input.h | 4 +- Processor/Input.hpp | 2 +- Processor/Instruction.h | 2 + Processor/Instruction.hpp | 59 +- Processor/IntInput.hpp | 5 + Processor/Machine.h | 3 - Processor/Machine.hpp | 7 +- Processor/NoLivePrep.h | 41 - Processor/NoProtocol.h | 42 -- Processor/OfflineMachine.hpp | 14 +- Processor/Online-Thread.h | 2 +- Processor/Online-Thread.hpp | 63 +- Processor/OnlineMachine.h | 4 +- Processor/OnlineMachine.hpp | 5 +- Processor/OnlineOptions.h | 12 + Processor/Processor.h | 2 + Processor/Processor.hpp | 112 ++- Processor/Program.h | 2 +- Processor/ThreadJob.h | 25 + Programs/Source/mnist_B.mpc | 73 ++ Programs/Source/mnist_D.mpc | 60 ++ Programs/Source/mnist_full_A.mpc | 56 +- Programs/Source/mnist_full_B.mpc | 73 ++ Programs/Source/mnist_full_C.mpc | 88 +++ Programs/Source/mnist_full_D.mpc | 105 +++ Programs/Source/prep_aes.mpc | 10 +- Programs/Source/test_gc.mpc | 30 +- Programs/Source/tutorial.mpc | 6 +- Programs/Source/vickrey.mpc | 2 +- Protocols/FakeInput.h | 2 +- Protocols/FakeProtocol.h | 44 +- Protocols/FakeShare.h | 1 - Protocols/HighGearKeyGen.cpp | 5 +- Protocols/LowGearKeyGen.h | 2 - Protocols/LowGearKeyGen.hpp | 53 +- Protocols/MAC_Check.h | 4 + Protocols/MaliciousRepPrep.h | 2 + Protocols/MaliciousRepPrep.hpp | 8 + Protocols/MaliciousShamirPO.hpp | 4 +- Protocols/NoLivePrep.h | 53 ++ Protocols/NoProtocol.h | 116 +++ Protocols/NoShare.h | 187 +++++ Protocols/Rep3Share.h | 11 +- Protocols/Rep3Share2k.h | 39 +- Protocols/Rep4Input.h | 2 +- Protocols/Rep4Input.hpp | 2 +- Protocols/Rep4Share.h | 1 - Protocols/Rep4Share2k.h | 1 - Protocols/Replicated.h | 4 + Protocols/Replicated.hpp | 13 +- Protocols/ReplicatedInput.h | 4 +- Protocols/ReplicatedInput.hpp | 2 +- Protocols/ReplicatedPrep.h | 11 + Protocols/Semi2k.h | 32 + Protocols/SemiShare.h | 1 - Protocols/ShamirInput.h | 7 +- Protocols/ShamirInput.hpp | 2 +- Protocols/Share.h | 11 +- Protocols/Share.hpp | 21 +- Protocols/ShareInterface.cpp | 8 + Protocols/ShareInterface.h | 6 + Protocols/SohoPrep.hpp | 3 +- Protocols/SpdzWiseInput.h | 5 +- Protocols/SpdzWiseInput.hpp | 4 +- Protocols/SpdzWiseShare.h | 13 +- Protocols/SpdzWiseShare.hpp | 6 + README.md | 18 +- Scripts/emulate.sh | 4 +- Scripts/run-common.sh | 12 +- Tools/Buffer.cpp | 4 + Tools/names.cpp | 2 +- Utils/Fake-Offline.cpp | 49 +- azure-pipelines.yml | 1 + compile.py | 2 + doc/Compiler.rst | 4 +- doc/add-protocol.rst | 82 ++ doc/conf.py | 2 +- doc/gen-instructions.py | 2 + doc/index.rst | 10 + doc/io.rst | 82 ++ doc/networking.rst | 30 + doc/non-linear.rst | 72 ++ doc/preprocessing.rst | 29 + doc/troubleshooting.rst | 100 +++ 208 files changed, 3612 insertions(+), 1862 deletions(-) rename FHE/{AddableVector.cpp => AddableVector.hpp} (81%) delete mode 100644 FHEOffline/Player-Offline.h create mode 100644 GC/CcdPrep.hpp delete mode 100644 GC/SemiHonestRepPrep.cpp delete mode 100644 GC/TinierPrep.h delete mode 100644 GC/TinyMC.cpp delete mode 100644 GC/TinyPrep.h delete mode 100644 GC/TinySecret.cpp delete mode 100644 GC/TinyShare.cpp create mode 100644 Machines/no-party.cpp delete mode 100644 Math/BitVec.cpp delete mode 100644 Processor/NoLivePrep.h delete mode 100644 Processor/NoProtocol.h create mode 100644 Programs/Source/mnist_B.mpc create mode 100644 Programs/Source/mnist_D.mpc create mode 100644 Programs/Source/mnist_full_B.mpc create mode 100644 Programs/Source/mnist_full_C.mpc create mode 100644 Programs/Source/mnist_full_D.mpc create mode 100644 Protocols/NoLivePrep.h create mode 100644 Protocols/NoProtocol.h create mode 100644 Protocols/NoShare.h create mode 100644 Protocols/ShareInterface.cpp create mode 100644 doc/add-protocol.rst create mode 100644 doc/io.rst create mode 100644 doc/non-linear.rst create mode 100644 doc/preprocessing.rst create mode 100644 doc/troubleshooting.rst diff --git a/BMR/network/Node.cpp b/BMR/network/Node.cpp index 3d32b2457..41e55d1ad 100644 --- a/BMR/network/Node.cpp +++ b/BMR/network/Node.cpp @@ -167,7 +167,7 @@ void Node::Broadcast2(SendBuffer& msg) { } void Node::_identify() { - char* msg = id_msg; + char msg[strlen(ID_HDR)+sizeof(_id)]; memcpy(msg, ID_HDR, strlen(ID_HDR)); memcpy(msg+strlen(ID_HDR), (const char *)&_id, sizeof(_id)); //printf("Node:: identifying myself:\n"); diff --git a/BMR/network/Node.h b/BMR/network/Node.h index cd4839396..da22c27d3 100644 --- a/BMR/network/Node.h +++ b/BMR/network/Node.h @@ -78,8 +78,6 @@ class Node : public ServerUpdatable, public ClientUpdatable { std::map _clientsmap; bool* _clients_connected; NodeUpdatable* _updatable; - - char id_msg[strlen(ID_HDR)+sizeof(_id)]; }; #endif /* NETWORK_NODE_H_ */ diff --git a/CHANGELOG.md b/CHANGELOG.md index fe65cc49e..cf5f1a967 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.5 (Jul 2, 2021) + +- Training of convolutional neural networks +- Bit decomposition using edaBits +- Ability to force MAC checks from high-level code +- Ability to close client connection from high-level code +- Binary operators for comparison results +- Faster compilation for emulation +- More documentation +- Fixed security bug: insufficient LowGear secret key randomness +- Fixed security bug: skewed random bit generation + ## 0.2.4 (Apr 19, 2021) - ARM support diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py index 7507fddd1..ac207347e 100644 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -117,7 +117,7 @@ class xorm(NonVectorInstruction): code = opcodes['XORM'] arg_format = ['int','sbw','sb','cb'] -class xorcb(NonVectorInstruction): +class xorcb(BinaryVectorInstruction): """ Bitwise XOR of two single clear bit registers. :param: result (cbit) @@ -125,7 +125,7 @@ class xorcb(NonVectorInstruction): :param: operand (cbit) """ code = opcodes['XORCB'] - arg_format = ['cbw','cb','cb'] + arg_format = ['int','cbw','cb','cb'] class xorcbi(NonVectorInstruction): """ Bitwise XOR of single clear bit register and immediate. diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 563426ec0..0e767c420 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -36,6 +36,7 @@ def get_type(cls, length): class bitsn(cls): n = length cls.types[length] = bitsn + bitsn.clear_type = cbits.get_type(length) bitsn.__name__ = cls.__name__ + str(length) return cls.types[length] @classmethod @@ -115,7 +116,11 @@ def load_mem(cls, address, mem_type=None, size=None): return res def store_in_mem(self, address): self.store_inst[isinstance(address, int)](self, address) + @classmethod + def new(cls, value=None, n=None): + return cls.get_type(n)(value) def __init__(self, value=None, n=None, size=None): + assert n == self.n or n is None if size != 1 and size is not None: raise Exception('invalid size for bit type: %s' % size) self.n = n or self.n @@ -125,7 +130,7 @@ def __init__(self, value=None, n=None, size=None): if value is not None: self.load_other(value) def copy(self): - return type(self)(n=instructions_base.get_global_vector_size()) + return type(self).new(n=instructions_base.get_global_vector_size()) def set_length(self, n): if n > self.n: raise Exception('too long: %d/%d' % (n, self.n)) @@ -154,6 +159,8 @@ def load_other(self, other): bits = other.bit_decompose() bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits)) other = self.bit_compose(bits) + assert(isinstance(other, type(self))) + assert(other.n == self.n) self.load_other(other) except: raise CompilerError('cannot convert %s/%s from %s to %s' % \ @@ -176,6 +183,16 @@ def _new_by_number(self, i, size=1): res.i = i res.program = self.program return res + def if_else(self, x, y): + """ + Vectorized oblivious selection:: + + sb32 = sbits.get_type(32) + print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal()) + + This will output 1. + """ + return result_conv(x, y)(self & (x ^ y) ^ y) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -202,14 +219,16 @@ def store_in_dynamic_mem(self, address): inst.stmsdci(self, cbits.conv(address)) def clear_op(self, other, c_inst, ci_inst, op): if isinstance(other, cbits): - res = cbits(n=max(self.n, other.n)) + res = cbits.get_type(max(self.n, other.n))() c_inst(res, self, other) return res + elif isinstance(other, sbits): + return NotImplemented else: if util.is_constant(other): if other >= 2**31 or other < -2**31: return op(self, cbits(other)) - res = cbits(n=max(self.n, len(bin(other)) - 2)) + res = cbits.get_type(max(self.n, len(bin(other)) - 2))() ci_inst(res, self, other) return res else: @@ -221,8 +240,14 @@ def clear_op(self, other, c_inst, ci_inst, op): def __xor__(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented + elif isinstance(other, cbits): + res = cbits.get_type(max(self.n, other.n))() + assert res.size == self.size + assert res.size == other.size + inst.xorcb(res.n, res, self, other) + return res else: - self.clear_op(other, inst.xorcb, inst.xorcbi, operator.xor) + return self.clear_op(other, None, inst.xorcbi, operator.xor) __radd__ = __add__ __rxor__ = __xor__ def __mul__(self, other): @@ -230,17 +255,18 @@ def __mul__(self, other): return NotImplemented else: try: - res = cbits(n=min(self.max_length, self.n+util.int_len(other))) + res = cbits.get_type(min(self.max_length, + self.n+util.int_len(other)))() inst.mulcbi(res, self, other) return res except TypeError: return NotImplemented def __rshift__(self, other): - res = cbits(n=self.n-other) + res = cbits.new(n=self.n-other) inst.shrcbi(res, self, other) return res def __lshift__(self, other): - res = cbits(n=self.n+other) + res = cbits.get_type(self.n+other)() inst.shlcbi(res, self, other) return res def print_reg(self, desc=''): @@ -504,16 +530,6 @@ def trans(cls, rows): res = [cls.new(n=len(rows)) for i in range(n_columns)] inst.trans(len(res), *(res + rows)) return res - def if_else(self, x, y): - """ - Vectorized oblivious selection:: - - sb32 = sbits.get_type(32) - print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal()) - - This will output 1. - """ - return result_conv(x, y)(self & (x ^ y) ^ y) @staticmethod def bit_adder(*args, **kwargs): return sbitint.bit_adder(*args, **kwargs) @@ -610,7 +626,7 @@ def __init__(self, other=None, size=None): elif isinstance(other, (list, tuple)): self.v = self.bit_extend(sbitvec(other).v, n) else: - self.v = sbits(other, n=n).bit_decompose(n) + self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n @classmethod def load_mem(cls, address): @@ -630,6 +646,8 @@ def store_in_mem(self, address): for i in range(n): v[i].store_in_mem(address + i) def reveal(self): + if len(self) > cbits.unit: + return self.elements()[0].reveal() revealed = [cbit() for i in range(len(self))] for i in range(len(self)): try: @@ -784,15 +802,23 @@ class bit(object): def result_conv(x, y): try: + def f(res): + try: + return t.conv(res) + except: + return res if util.is_constant(x): if util.is_constant(y): return lambda x: x else: - return type(y).conv + t = type(y) + return f if util.is_constant(y): - return type(x).conv + t = type(x) + return f if type(x) is type(y): - return type(x).conv + t = type(x) + return f except AttributeError: pass return lambda x: x @@ -807,13 +833,19 @@ def if_else(self, x, y): This will output 5. """ - return result_conv(x, y)(self * (x ^ y) ^ y) + assert self.n == 1 + diff = x ^ y + if isinstance(diff, cbits): + return result_conv(x, y)(self & (diff) ^ y) + else: + return result_conv(x, y)(self * (diff) ^ y) class cbit(bit, cbits): pass sbits.bit_type = sbit cbits.bit_type = cbit +sbit.clear_type = cbit class bitsBlock(oram.Block): value_type = sbits @@ -881,7 +913,7 @@ def round(self, k, m, kappa=None, nearest=None, signed=None): return self.get_type(k - m).compose(res_bits) def int_div(self, other, bit_length=None): k = bit_length or max(self.n, other.n) - return (library.IntDiv(self.extend(k), other.extend(k), k) >> k).cast(k) + return (library.IntDiv(self.cast(k), other.cast(k), k) >> k).cast(k) def Norm(self, k, f, kappa=None, simplex_flag=False): absolute_val = abs(self) #next 2 lines actually compute the SufOR for little indian encoding @@ -1100,7 +1132,8 @@ def output(self): bits = self.v.bit_decompose(self.k) sign = bits[-1] v += (sign << (self.k)) * -1 - inst.print_float_plainb(v, cbits(-self.f, n=32), cbits(0), cbits(0), cbits(0)) + inst.print_float_plainb(v, cbits.get_type(32)(-self.f), cbits(0), + cbits(0), cbits(0)) class sbitfix(_fix): """ Secret signed integer in one binary register. diff --git a/Compiler/allocator.py b/Compiler/allocator.py index 3553532db..5df49035c 100644 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -123,7 +123,7 @@ def dealloc_reg(self, reg, inst, free): for x in itertools.chain(dup.duplicates, base.duplicates): to_check.add(x) - free[reg.reg_type, base.size].add(self.alloc[base]) + free[reg.reg_type, base.size].append(self.alloc[base]) if inst.is_vec() and base.vector: self.defined[base] = inst for i in base.vector: @@ -604,4 +604,4 @@ def run(self, instructions): elif op == 1: instructions[i] = None inst.args[0].link(inst.args[1]) - instructions[:] = filter(lambda x: x is not None, instructions) + instructions[:] = list(filter(lambda x: x is not None, instructions)) diff --git a/Compiler/circuit.py b/Compiler/circuit.py index 4182abf0b..9c4187f75 100644 --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -127,7 +127,7 @@ def sha3_256(x): from circuit import sha3_256 a = sbitvec.from_vec([]) - b = sbitvec(sint(0xcc), 8) + b = sbitvec(sint(0xcc), 8, 8) for x in a, b: sha3_256(x).elements()[0].reveal().print_reg() diff --git a/Compiler/comparison.py b/Compiler/comparison.py index 977bb4b22..0fe62101d 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -73,6 +73,7 @@ def require_ring_size(k, op): if int(program.options.ring) < k: raise CompilerError('ring size too small for %s, compile ' 'with \'-R %d\' or more' % (op, k)) + program.curr_tape.require_bit_length(k) @instructions_base.cisc def LTZ(s, a, k, kappa): diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index cc5aefb2a..16858acb9 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -55,7 +55,7 @@ def EQZ(a, k, kappa): from GC.types import sbitvec v = sbitvec(a, k).v bit = util.tree_reduce(operator.and_, (~b for b in v)) - return types.sint.conv(bit) + return types.sintbit.conv(bit) prog.non_linear.check_security(kappa) return prog.non_linear.eqz(a, k) @@ -263,16 +263,17 @@ def BitAdd(a, b, bits_to_compute=None): def BitDec(a, k, m, kappa, bits_to_compute=None): return program.Program.prog.non_linear.bit_dec(a, k, m) -def BitDecRing(a, k, m): +def BitDecRingRaw(a, k, m): n_shift = int(program.Program.prog.options.ring) - m assert(n_shift >= 0) if program.Program.prog.use_split(): x = a.split_to_two_summands(m) bits = types._bitint.carry_lookahead_adder(x[0], x[1], fewer_inv=False) - # reversing to reduce number of rounds - return [types.sint.conv(bit) for bit in reversed(bits)][::-1] + return bits[:m] else: - if program.Program.prog.use_dabit: + if program.Program.prog.use_edabit(): + r, r_bits = types.sint.get_edabit(m, strict=False) + elif program.Program.prog.use_dabit: r, r_bits = zip(*(types.sint.get_dabit() for i in range(m))) r = types.sint.bit_compose(r) else: @@ -281,7 +282,12 @@ def BitDecRing(a, k, m): shifted = ((a - r) << n_shift).reveal() masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) - return [types.sint.conv(bit) for bit in bits] + return bits + +def BitDecRing(a, k, m): + bits = BitDecRingRaw(a, k, m) + # reversing to reduce number of rounds + return [types.sint.conv(bit) for bit in reversed(bits)][::-1] def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r_dprime = types.sint() @@ -429,7 +435,7 @@ def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa): s = (1 - overflow) * t + overflow * t / 2 return s, overflow -def Int2FL(a, gamma, l, kappa): +def Int2FL(a, gamma, l, kappa=None): lam = gamma - 1 s = a.less_than(0, gamma, security=kappa) z = a.equal(0, gamma, security=kappa) @@ -598,13 +604,13 @@ def SDiv_mono(a, b, l, kappa): # Unconditionally Secure Constant-Rounds Multi-party Computation # for Equality, Comparison, Bits and Exponentiation def BITLT(a, b, bit_length): - sint = types.sint - e = [sint(0)]*bit_length - g = [sint(0)]*bit_length - h = [sint(0)]*bit_length + from .types import sint, regint, longint, cint + e = [None]*bit_length + g = [None]*bit_length + h = [None]*bit_length for i in range(bit_length): # Compute the XOR (reverse order of e for PreOpL) - e[bit_length-i-1] = a[i].bit_xor(b[i]) + e[bit_length-i-1] = util.bit_xor(a[i], b[i]) f = PreOpL(or_op, e) g[bit_length-1] = f[0] for i in range(bit_length-1): @@ -612,7 +618,7 @@ def BITLT(a, b, bit_length): g[i] = f[bit_length-i-1]-f[bit_length-i-2] ans = 0 for i in range(bit_length): - h[i] = g[i]*b[i] + h[i] = g[i].bit_and(b[i]) ans = ans + h[i] return ans @@ -620,9 +626,9 @@ def BITLT(a, b, bit_length): # - From the paper # Multiparty Computation for Interval, Equality, and Comparison without # Bit-Decomposition Protocol -def BitDecFull(a): +def BitDecFull(a, maybe_mixed=False): from .library import get_program, do_while, if_, break_point - from .types import sint, regint, longint + from .types import sint, regint, longint, cint p = get_program().prime assert p bit_length = p.bit_length() @@ -631,9 +637,16 @@ def BitDecFull(a): # inspired by Rabbit (https://eprint.iacr.org/2021/119) # no need for exact randomness generation # if modulo a power of two is close enough - bbits = [sint.get_random_bit(size=a.size) for i in range(logp)] - if logp != bit_length: - bbits += [sint(0, size=a.size)] + if get_program().use_edabit(): + b, bbits = sint.get_edabit(logp, True, size=a.size) + if logp != bit_length: + from .GC.types import sbits + bbits += [sbits.get_type(a.size)(0)] + else: + bbits = [sint.get_random_bit(size=a.size) for i in range(logp)] + b = sint.bit_compose(bbits) + if logp != bit_length: + bbits += [sint(0, size=a.size)] else: bbits = [sint(size=a.size) for i in range(bit_length)] tbits = [[sint(size=1) for i in range(bit_length)] for j in range(a.size)] @@ -653,15 +666,21 @@ def _(): for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) - b = sint.bit_compose(bbits) + b = sint.bit_compose(bbits) c = (a-b).reveal() - t = (p-c).bit_decompose(bit_length) + cmodp = c + t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) czero = (c==0) - q = 1-BITLT( bbits, t, bit_length) - fbar=((1< 1: + @multithread(n_threads, left // 2) + def _(base, size): + outputs.assign_vector( + function(inputs.get_vector(2 * base, size), + inputs.get_vector(2 * base + size, size)), base) + inputs.assign_vector(outputs.get_vector(0, left // 2)) + if left % 2 == 1: + inputs[left // 2] = inputs[left - 1] + left = (left + 1) // 2 + return inputs[0] + def foreach_enumerate(a): """ Run-time loop over public data. This uses ``Player-Data/Public-Input/``. Example: @@ -1511,6 +1549,15 @@ def break_point(name=''): """ get_tape().start_new_basicblock(name=name) +def check_point(): + """ + Force MAC checks in current thread and all idle threads if the + current thread is the main thread. This implies a break point. + """ + break_point('pre-check') + check() + break_point('post-check') + # Fixed point ops from math import ceil, log @@ -1566,6 +1613,9 @@ def cint_cint_division(a, b, k, f): # theta can be replaced with something smaller # for safety we assume that is the same theta from previous GS method + if get_program().options.ring: + assert 2 * f < int(get_program().options.ring) + theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) @@ -1579,9 +1629,11 @@ def cint_cint_division(a, b, k, f): B = absolute_b W = w0 - for i in range(1, theta): - A = (A * W) >> f - B = (B * W) >> f + corr = cint(1) << (f - 1) + + for i in range(theta): + A = (A * W + corr) >> f + B = (B * W + corr) >> f W = two - B return (sign_a * sign_b) * A @@ -1592,7 +1644,7 @@ def sint_cint_division(a, b, k, f, kappa): """ theta = int(ceil(log(k/3.5) / log(2))) two = cint(2) * two_power(f) - sign_b = cint(1) - 2 * cint(b < 0) + sign_b = cint(1) - 2 * cint(b.less_than(0, k)) sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa) absolute_b = b * sign_b absolute_a = a * sign_a @@ -1652,7 +1704,8 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): y = y.extend(2 * k) * (alpha + x).extend(2 * k) y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True) return y -def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False): + +def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False): """ Approximate reciprocal of [b]: Given [b], compute [1/b] @@ -1662,7 +1715,7 @@ def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False): #v should be 2**{k - m} where m is the length of the bitwise repr of [b] d = alpha - 2 * c w = d * v - w = w.round(2 * k, 2 * (k - f), kappa, nearest, signed=True) + w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True) # now w * 2 ^ {-f} should be an initial approximation of 1/b return w @@ -1674,7 +1727,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): # For simplex, we can get rid of computing abs(b) temp = None if simplex_flag == False: - temp = comparison.LessThanZero(b, 2 * k, kappa) + temp = comparison.LessThanZero(b, k, kappa) elif simplex_flag == True: temp = cint(0) @@ -1682,7 +1735,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): absolute_val = sign * b #next 2 lines actually compute the SufOR for little indian encoding - bits = absolute_val.bit_decompose(k, kappa)[::-1] + bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1] suffixes = PreOR(bits, kappa)[::-1] z = [0] * k @@ -1690,10 +1743,7 @@ def Norm(b, k, f, kappa, simplex_flag=False): z[i] = suffixes[i] - suffixes[i+1] z[k - 1] = suffixes[k-1] - #doing complicated stuff to compute v = 2^{k-m} - acc = cint(0) - for i in range(k): - acc += two_power(k-i-1) * z[i] + acc = sint.bit_compose(reversed(z)) part_reciprocal = absolute_val * acc signed_acc = sign * acc diff --git a/Compiler/ml.py b/Compiler/ml.py index 2fb115d92..e58eb4bf5 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1,8 +1,8 @@ """ This module contains machine learning functionality. It is work in progress, so you must expect things to change. The only tested -functionality for training is using consective dense/fully-connected -layers. This includes logistic regression. It can be run as +functionality for training is using consective layers. +This includes logistic regression. It can be run as follows:: sgd = ml.SGD([ml.Dense(n_examples, n_features, 1), @@ -18,6 +18,22 @@ :py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines whether to use an approximate sigmoid function. Setting it to 5 uses a five-piece approximation instead of a three-piece one. + +A simple network for MNIST using two dense layers can be trained as +follows:: + + sgd = ml.SGD([ml.Dense(60000, 784, 128, activation='relu'), + ml.Dense(60000, 128, 10), + ml.MultiOutput(60000, 10)], n_epochs, + report_loss=True) + sgd.layers[0].X.input_from(0) + sgd.layers[1].Y.input_from(1) + sgd.reset() + sgd.run() + +See `this repository `_ +for scripts importing MNIST training data and further examples. + Inference can be run as follows:: data = sfix.Matrix(n_test, n_features) @@ -40,9 +56,6 @@ See the `readme `_ for an example of how to run MP-SPDZ on TensorFlow graphs. - -See also `this repository `_ -for an example of how to train a model for MNIST. """ import math @@ -178,11 +191,17 @@ def assign_vector(self, *args): self.alloc() return super(Tensor, self).assign_vector(*args) + def assign_vector_by_indices(self, *args): + self.alloc() + return super(Tensor, self).assign_vector_by_indices(*args) + class Layer: n_threads = 1 inputs = [] input_bias = True thetas = lambda self: () + debug_output = False + back_batch_size = 128 @property def shape(self): @@ -206,8 +225,17 @@ def Y(self): def Y(self, value): self._Y = value + def forward(self, batch=None, training=None): + if batch is None: + batch = Array.create_from(regint(0)) + self._forward(batch) + + def __str__(self): + return type(self).__name__ + str(self._Y.sizes) + class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None + output_weights = lambda *args: None nablas = lambda self: () reset = lambda self: None @@ -241,7 +269,7 @@ def __init__(self, N, debug=False, approx=False): def divisor(self, divisor, size): return cfix(1.0 / divisor, size=size) - def forward(self, batch): + def _forward(self, batch): if self.approx == 5: self.l.write(999) return @@ -284,14 +312,10 @@ def _(base, size): # @for_range_opt(len(diff)) # def _(i): # self.nabla_X[i] = self.nabla_X[i] * self.weights[i] - if self.debug: - a = cfix.Array(len(diff)) - a.assign(diff.reveal()) - @for_range(len(diff)) - def _(i): - x = a[i] - print_ln_if((x < -1.001) + (x > 1.001), 'sigmoid') - #print_ln('%s', x) + if self.debug_output: + print_ln('sigmoid X %s', self.X.reveal_nested()) + print_ln('sigmoid nabla %s', self.nabla_X.reveal_nested()) + print_ln('batch %s', batch.reveal_nested()) def set_weights(self, weights): self.weights = cfix.Array(len(weights)) @@ -350,9 +374,11 @@ def reveal_correctness(self, n=None, Y=None, debug=False): n = self.X.sizes[0] if Y is None: Y = self.Y - n_correct = MemValue(0) n_printed = MemValue(0) - @for_range_opt(n) + assert n <= len(self.X) + assert n <= len(Y) + Y.address = MemValue.if_necessary(Y.address) + @map_sum(None if debug else self.n_threads, None, n, 1, regint) def _(i): a = Y[i].reveal_list() b = self.X[i].reveal_list() @@ -363,13 +389,13 @@ def _(i): truth = argmax(a) guess = argmax(b) correct = truth == guess - n_correct.iadd(correct) if debug: to_print = (1 - correct) * (n_printed < 10) n_printed.iadd(to_print) print_ln_if(to_print, '%s: %s %s %s %s %s %s', i, truth, guess, loss, b, exp, nabla) - return n_correct + return correct + return _() @property def n_outputs(self): @@ -410,7 +436,7 @@ def __init__(self, N, d_out, approx=False, debug=False): self.debug = debug self.true_X = sfix.Array(N) - def forward(self, batch): + def _forward(self, batch): N = len(batch) d_out = self.X.sizes[1] tmp = self.losses @@ -458,7 +484,9 @@ def _(i): return res @for_range_opt_multithread(self.n_threads, N) def _(i): - e = exp(self.X[i].get_vector()) + x = self.X[i].get_vector() - \ + util.max(self.X[i].get_vector()).expand_to_vector(d_out) + e = exp(x) res[i].assign_vector(e / sum(e).expand_to_vector(d_out)) return res @@ -529,7 +557,7 @@ class ReluMultiOutput(MultiOutputBase): :param N: number of examples :param d_out: number of classes """ - def forward(self, batch): + def forward(self, batch, training=None): self.l.write(999) def backward(self, batch): @@ -550,6 +578,10 @@ class DenseBase(Layer): thetas = lambda self: (self.W, self.b) nablas = lambda self: (self.nabla_W, self.nabla_b) + def output_weights(self): + print_ln('%s', self.W.reveal_nested()) + print_ln('%s', self.b.reveal_nested()) + def backward_params(self, f_schur_Y, batch): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) @@ -586,6 +618,10 @@ def _(i): progress('nabla b') + if self.debug_output: + print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested()) + print_ln('dense W %s', self.W.reveal_nested()) + print_ln('dense nabla X %s', self.nabla_X.reveal_nested()) if self.debug: limit = N * self.debug @for_range_opt(self.d_in) @@ -646,8 +682,9 @@ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): self.W = Tensor([d_in, d_out], sfix) self.b = sfix.Array(d_out) - self.nabla_Y = MultiArray([N, d, d_out], sfix) - self.nabla_X = MultiArray([N, d, d_in], sfix) + back_N = min(N, self.back_batch_size) + self.nabla_Y = MultiArray([back_N, d, d_out], sfix) + self.nabla_X = MultiArray([back_N, d, d_in], sfix) self.nabla_W = sfix.Matrix(d_in, d_out) self.nabla_b = sfix.Array(d_out) @@ -665,6 +702,7 @@ def reset(self): d_in = self.d_in d_out = self.d_out r = math.sqrt(6.0 / (d_in + d_out)) + print('Initializing dense weights in [%f,%f]' % (-r, r)) self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size())) self.b.assign_all(0) @@ -702,13 +740,18 @@ def _(i): self.f_input[i].assign_vector(v) progress('f input') - def forward(self, batch=None): + def _forward(self, batch=None): if batch is None: batch = regint.Array(self.N) batch.assign(regint.inc(self.N)) self.compute_f_input(batch=batch) if self.activation_layer: self.activation_layer.forward(batch) + if self.debug_output: + print_ln('dense X %s', self.X.reveal_nested()) + print_ln('dense W %s', self.W.reveal_nested()) + print_ln('dense b %s', self.b.reveal_nested()) + print_ln('dense Y %s', self.Y.reveal_nested()) if self.debug: limit = self.debug @for_range_opt(len(batch)) @@ -784,7 +827,7 @@ def _(j): self.W[i][j] = sfix.get_random(-1, 1) self.b.assign_all(0) - def forward(self): + def _forward(self): @for_range_opt(self.d_in) def _(i): @for_range_opt(self.d_out) @@ -806,8 +849,14 @@ def backward(self, compute_nabla_X=False): assert not compute_nabla_X self.backward_params(self.nabla_Y) -class Dropout: - def __init__(self, N, d1, d2=1): +class Dropout(NoVariableLayer): + """ Dropout layer. + + :param N: number of examples + :param d1: total dimension + :param alpha: probability (power of two) + """ + def __init__(self, N, d1, d2=1, alpha=0.5): self.N = N self.d1 = d1 self.d2 = d2 @@ -815,48 +864,76 @@ def __init__(self, N, d1, d2=1): self.Y = MultiArray([N, d1, d2], sfix) self.nabla_Y = MultiArray([N, d1, d2], sfix) self.nabla_X = MultiArray([N, d1, d2], sfix) - self.alpha = 0.5 + self.alpha = alpha self.B = MultiArray([N, d1, d2], sint) - def forward(self): - assert self.alpha == 0.5 - @for_range(self.N) - def _(i): - @for_range(self.d1) - def _(j): - @for_range(self.d2) - def _(k): - self.B[i][j][k] = sint.get_random_bit() - self.Y = self.X.schur(self.B) + def forward(self, batch, training=False): + if training: + n_bits = -math.log(self.alpha, 2) + assert n_bits == int(n_bits) + n_bits = int(n_bits) + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + size = self.d1 * self.d2 + self.B[i].assign_vector(util.tree_reduce( + util.or_op, (sint.get_random_bit(size=size) + for i in range(n_bits)))) + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + self.Y[i].assign_vector(1 / (1 - self.alpha) * + self.X[batch[i]].get_vector() * self.B[i].get_vector()) + else: + @for_range(len(batch)) + def _(i): + self.Y[i] = self.X[batch[i]] + if self.debug_output: + print_ln('dropout X %s', self.X.reveal_nested()) + print_ln('dropout Y %s', self.Y.reveal_nested()) - def backward(self): - self.nabla_X = self.nabla_Y.schur(self.B) + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + self.nabla_X[batch[i]].assign_vector( + self.nabla_Y[i].get_vector() * self.B[i].get_vector()) + if self.debug_output: + print_ln('dropout nabla_Y %s', self.nabla_Y.reveal_nested()) + print_ln('dropout nabla_X %s', self.nabla_X.reveal_nested()) class ElementWiseLayer(NoVariableLayer): def __init__(self, shape, inputs=None): self.X = Tensor(shape, sfix) self.Y = Tensor(shape, sfix) - self.nabla_X = Tensor(shape, sfix) - self.nabla_Y = Tensor(shape, sfix) + backward_shape = list(shape) + backward_shape[0] = min(shape[0], self.back_batch_size) + self.nabla_X = Tensor(backward_shape, sfix) + self.nabla_Y = Tensor(backward_shape, sfix) self.inputs = inputs - def forward(self, batch=[0]): - @multithread(self.n_threads, len(batch), 128) + def _forward(self, batch=[0]): + n_per_item = reduce(operator.mul, self.X.sizes[1:]) + @multithread(self.n_threads, len(batch), max(1, 1000 // n_per_item)) def _(base, size): - self.Y.assign_part_vector(self.f( - self.X.get_part_vector(base, size)), base) + self.Y.assign_part_vector(self.f_part(base, size), base) + + if self.debug_output: + name = self + @for_range(len(batch)) + def _(i): + print_ln('%s X %s %s', name, i, self.X[i].reveal_nested()) + print_ln('%s Y %s %s', name, i, self.Y[i].reveal_nested()) def backward(self, batch): f_prime_bit = MultiArray(self.X.sizes, self.prime_type) + n_elements = len(batch) * reduce(operator.mul, f_prime_bit.sizes[1:]) - @multithread(self.n_threads, f_prime_bit.total_size()) + @multithread(self.n_threads, n_elements) def _(base, size): - f_prime_bit.assign_vector( - self.f_prime(self.X.get_vector(base, size)), base) + f_prime_bit.assign_vector(self.f_prime_part(base, size), base) progress('f prime') - @multithread(self.n_threads, f_prime_bit.total_size()) + @multithread(self.n_threads, n_elements) def _(base, size): self.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size) * f_prime_bit.get_vector(base, size), @@ -864,6 +941,15 @@ def _(base, size): progress('f prime schur Y') + if self.debug_output: + name = self + @for_range(len(batch)) + def _(i): + print_ln('%s X %s %s', name, i, self.X[i].reveal_nested()) + print_ln('%s f_prime %s %s', name, i, f_prime_bit[i].reveal_nested()) + print_ln('%s nabla Y %s %s', name, i, self.nabla_Y[i].reveal_nested()) + print_ln('%s nabla X %s %s', name, i, self.nabla_X[i].reveal_nested()) + class Relu(ElementWiseLayer): """ Fixed-point ReLU layer. @@ -872,6 +958,20 @@ class Relu(ElementWiseLayer): f = staticmethod(relu) f_prime = staticmethod(relu_prime) prime_type = sint + comparisons = None + + def __init__(self, shape, inputs=None): + super(Relu, self).__init__(shape) + self.comparisons = MultiArray(shape, sint) + + def f_part(self, base, size): + x = self.X.get_part_vector(base, size) + c = x > 0 + self.comparisons.assign_part_vector(c, base) + return c.if_else(x, 0) + + def f_prime_part(self, base, size): + return self.comparisons.get_vector(base, size) class Square(ElementWiseLayer): """ Fixed-point square layer. @@ -904,14 +1004,45 @@ def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), self.Y = Tensor(output_shape, sfix) self.strides = strides self.ksize = ksize + self.nabla_X = Tensor(shape, sfix) + self.nabla_Y = Tensor(output_shape, sfix) + self.N = shape[0] + self.comparisons = MultiArray([self.N, self.X.sizes[3], + ksize[1] * ksize[2]], sint) + + def _forward(self, batch): + def process(pool, bi, k, i, j): + def m(a, b): + c = a[0] > b[0] + l = [c * x for x in a[1]] + l += [(1 - c) * x for x in b[1]] + return c.if_else(a[0], b[0]), l + red = util.tree_reduce(m, [(x[0], [1]) for x in pool]) + self.Y[bi][i][j][k] = red[0] + for i, x in enumerate(red[1]): + self.comparisons[bi][k][i] = x + self.traverse(batch, process) - def forward(self, batch=[0]): - assert len(batch) == 1 - bi = MemValue(batch[0]) + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + self.nabla_X.alloc() + def process(pool, bi, k, i, j): + for (x, h_in, w_in, h, w), c in zip(pool, + self.comparisons[bi][k]): + hh = h * h_in + ww = w * w_in + self.nabla_X[bi][hh][ww][k] = \ + util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k], + self.nabla_X[bi][hh][ww][k]) + self.traverse(batch, process) + + def traverse(self, batch, process): need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] > self.X.sizes[i] for i in range(4)] - @for_range_opt_multithread(self.n_threads, self.X.sizes[3]) - def _(k): + @for_range_opt_multithread(self.n_threads, + [len(batch), self.X.sizes[3]]) + def _(l, k): + bi = batch[l] @for_range_opt(self.Y.sizes[1]) def _(i): h_base = self.strides[1] * i @@ -932,10 +1063,10 @@ def _(j): else: w_in = True if not is_zero(h_in * w_in): - pool.append(h_in * w_in * self.X[bi][h_in * h] - [w_in * w][k]) - self.Y[bi][i][j][k] = util.tree_reduce( - lambda a, b: a.max(b), pool) + pool.append([h_in * w_in * self.X[bi][h_in * h] + [w_in * w][k], h_in, w_in, h, w]) + process(pool, bi, k, i, j) + class Argmax(NoVariableLayer): """ Fixed-point Argmax layer. @@ -947,7 +1078,7 @@ def __init__(self, shape): self.X = MultiArray(shape, sfix) self.Y = Array(shape[0], sint) - def forward(self, batch=[0]): + def _forward(self, batch=[0]): assert len(batch) == 1 self.Y[batch[0]] = argmax(self.X[batch[0]]) @@ -973,7 +1104,7 @@ def __init__(self, inputs, dimension): shape.append(shapes[0][i]) self.Y = Tensor(shape, sfix) - def forward(self, batch=[0]): + def _forward(self, batch=[0]): assert len(batch) == 1 @for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3]) def _(i, j): @@ -996,7 +1127,7 @@ def __init__(self, inputs): self.Y = Tensor(shape, sfix) self.inputs = inputs - def forward(self, batch=[0]): + def _forward(self, batch=[0]): assert len(batch) == 1 @multithread(self.n_threads, self.Y[0].total_size()) def _(base, size): @@ -1024,7 +1155,7 @@ def input_from(self, player, raw=False): tmp.input_from(player, raw=raw) tmp.input_from(player, raw=raw) - def forward(self, batch=[0]): + def _forward(self, batch=[0]): assert len(batch) == 1 @for_range_opt_multithread(self.n_threads, self.X.sizes[1:3]) def _(i, j): @@ -1081,17 +1212,29 @@ def __init__(self, input_shape, output_shape, inputs=None): self.X = Tensor(input_shape, self.input_squant) self.Y = Tensor(output_shape, self.output_squant) + + back_shapes = list(input_shape), list(output_shape) + for x in back_shapes: + x[0] = min(x[0], self.back_batch_size) + + self.nabla_X = MultiArray(back_shapes[0], self.input_squant) + self.nabla_Y = MultiArray(back_shapes[1], self.output_squant) self.inputs = inputs def temp_shape(self): return [0] + @property + def N(self): + return self.input_shape[0] + class ConvBase(BaseLayer): fewer_rounds = True - use_conv2ds = False + use_conv2ds = True temp_weights = None temp_inputs = None thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) @classmethod def init_temp(cls, layers): @@ -1114,6 +1257,7 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, self.padding = [] for i in 1, 2: s = stride[i - 1] + assert output_shape[i] >= input_shape[i] // s if tf_weight_format: w = weight_shape[i - 1] else: @@ -1134,7 +1278,10 @@ def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, self.weights = Tensor(weight_shape, self.weight_squant) self.bias = Array(output_shape[-1], self.bias_squant) - self.unreduced = Tensor(self.output_shape, sint) + self.nabla_weights = Tensor(weight_shape, self.weight_squant) + self.nabla_bias = Array(output_shape[-1], self.bias_squant) + + self.unreduced = Tensor(self.output_shape, sint, address=self.Y.address) if tf_weight_format: weight_in = weight_shape[2] @@ -1153,6 +1300,10 @@ def input_from(self, player, raw=False): if self.input_bias: self.bias.input_from(player, raw=raw) + def output_weights(self): + print_ln('%s', self.weights.reveal_nested()) + print_ln('%s', self.bias.reveal_nested()) + def dot_product(self, iv, wv, out_y, out_x, out_c): bias = self.bias[out_c] acc = self.output_squant.unreduced_dot_product(iv, wv) @@ -1161,12 +1312,13 @@ def dot_product(self, iv, wv, out_y, out_x, out_c): #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() self.unreduced[0][out_y][out_x][out_c] = acc.v - def reduction(self): + def reduction(self, batch_length=1): unreduced = self.unreduced n_summands = self.n_summands() - start_timer(2) - n_outputs = reduce(operator.mul, self.output_shape) - @multithread(self.n_threads, n_outputs) + #start_timer(2) + n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:]) + @multithread(self.n_threads, n_outputs, + 1000 if sfix.round_nearest else 10 ** 6) def _(base, n_per_thread): res = self.input_squant().unreduced( sint.load_mem(unreduced.address + base, @@ -1175,8 +1327,7 @@ def _(base, n_per_thread): self.output_squant.params, n_summands).reduce_after_mul() res.store_in_mem(self.Y.address + base) - stop_timer(2) - unreduced.delete() + #stop_timer(2) def temp_shape(self): return list(self.output_shape[1:]) + [self.n_summands()] @@ -1195,9 +1346,7 @@ def n_summands(self): _, inputs_h, inputs_w, n_channels_in = self.input_shape return weights_h * weights_w * n_channels_in - def forward(self, batch=[None]): - assert len(batch) == 1 - + def _forward(self, batch): if self.tf_weight_format: assert(self.weight_shape[3] == self.output_shape[-1]) weights_h, weights_w, _, _ = self.weight_shape @@ -1210,35 +1359,47 @@ def forward(self, batch=[None]): stride_h, stride_w = self.stride padding_h, padding_w = self.padding - self.unreduced.alloc() - if self.use_conv2ds: - @for_range_opt_multithread(self.n_threads, n_channels_out) - def _(j): - inputs = self.X.get_part_vector(0) + n_parts = max(1, round(self.n_threads / n_channels_out)) + while len(batch) % n_parts != 0: + n_parts -= 1 + print('Convolution in %d parts' % n_parts) + part_size = len(batch) // n_parts + @for_range_multithread(self.n_threads, 1, [n_parts, n_channels_out]) + def _(i, j): + inputs = self.X.get_slice_vector( + batch.get_part(i * part_size, part_size)) if self.tf_weight_format: weights = self.weights.get_vector_by_indices(None, None, None, j) else: weights = self.weights.get_part_vector(j) inputs = inputs.pre_mul() weights = weights.pre_mul() - res = sint(size = output_h * output_w) + res = sint(size = output_h * output_w * part_size) conv2ds(res, inputs, weights, output_h, output_w, inputs_h, inputs_w, weights_h, weights_w, - stride_h, stride_w, n_channels_in, padding_h, padding_w) + stride_h, stride_w, n_channels_in, padding_h, padding_w, + part_size) if self.bias_before_reduction: res += self.bias.expand_to_vector(j, res.size).v - self.unreduced.assign_vector_by_indices(res, 0, None, None, j) - self.reduction() - if not self.bias_before_reduction: - @for_range_multithread(self.n_threads, 1, - [self.output_shape[1], - self.output_shape[2]]) - def _(i, j): - self.Y[0][i][j].assign_vector(self.Y[0][i][j].get_vector() + - self.bias.get_vector()) + else: + res += self.bias.expand_to_vector(j, res.size).v << \ + self.input_squant.f + addresses = regint.inc(res.size, + self.unreduced[i * part_size].address + j, + n_channels_out) + res.store_in_mem(addresses) + self.reduction(len(batch)) + if self.debug_output: + print_ln('%s weights %s', self, self.weights.reveal_nested()) + print_ln('%s bias %s', self, self.bias.reveal_nested()) + @for_range(len(batch)) + def _(i): + print_ln('%s X %s %s', self, i, self.X[batch[i]].reveal_nested()) + print_ln('%s Y %s %s', self, i, self.Y[i].reveal_nested()) return else: + assert len(batch) == 1 if self.fewer_rounds: inputs, weights = self.prepare_temp() @@ -1298,15 +1459,125 @@ class FixConv2d(Conv2d, FixBase): :param output_shape: output shape (tuple/list of four int) :param stride: stride (tuple/list of two int) :param padding: :py:obj:`'SAME'` (default), :py:obj:`'VALID'`, or tuple/list of two int - :param tf_weight_format: weight shape format is (height, width, input channels, output channels) instead of the default (output channels, height, widght, input channels) + :param tf_weight_format: weight shape format is (height, width, input channels, output channels) instead of the default (output channels, height, width, input channels) """ + def reset(self): + assert not self.tf_weight_format + kernel_size = self.weight_shape[1] * self.weight_shape[2] + r = math.sqrt(6.0 / (kernel_size * sum(self.weight_shape[::3]))) + print('Initializing convolution weights in [%f,%f]' % (-r, r)) + self.weights.assign_vector( + sfix.get_random(-r, r, size=self.weights.total_size())) + self.bias.assign_all(0) + + def backward(self, compute_nabla_X=True, batch=None): + assert self.use_conv2ds + + assert not self.tf_weight_format + _, weights_h, weights_w, _ = self.weight_shape + _, inputs_h, inputs_w, n_channels_in = self.input_shape + _, output_h, output_w, n_channels_out = self.output_shape + + stride_h, stride_w = self.stride + padding_h, padding_w = self.padding + + N = len(batch) + + self.nabla_bias.assign_all(0) + + @for_range(N) + def _(i): + self.nabla_bias.assign_vector( + self.nabla_bias.get_vector() + sum(sum( + self.nabla_Y[i][j][k].get_vector() for k in range(output_w)) + for j in range(output_h))) + + input_size = inputs_h * inputs_w * N + batch_repeat = regint.Matrix(N, inputs_h * inputs_w) + batch_repeat.assign_vector(batch.get( + regint.inc(input_size, 0, 1, 1, N)) * + reduce(operator.mul, self.input_shape[1:])) + + @for_range_opt_multithread(self.n_threads, [n_channels_in, n_channels_out]) + def _(i, j): + a = regint.inc(input_size, self.X.address + i, n_channels_in, N, + inputs_h * inputs_w) + inputs = sfix.load_mem(batch_repeat.get_vector() + a).pre_mul() + b = regint.inc(N * output_w * output_h, self.nabla_Y.address + j, n_channels_out, N) + rep_out = regint.inc(output_h * output_w * N, 0, 1, 1, N) * \ + reduce(operator.mul, self.output_shape[1:]) + nabla_outputs = sfix.load_mem(rep_out + b).pre_mul() + res = sint(size = weights_h * weights_w) + conv2ds(res, inputs, nabla_outputs, weights_h, weights_w, inputs_h, + inputs_w, output_h, output_w, -stride_h, -stride_w, N, + padding_h, padding_w, 1) + reduced = unreduced_sfix._new(res).reduce_after_mul() + self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i) + + if compute_nabla_X: + assert tuple(self.padding) == (0, 0) + assert tuple(self.stride) == (1, 1) + reverse_weights = MultiArray( + [n_channels_in, weights_h, weights_w, n_channels_out], sfix) + @for_range(n_channels_out) + def _(i): + @for_range(weights_h) + def _(j): + @for_range(weights_w) + def _(k): + @for_range(n_channels_in) + def _(l): + reverse_weights[l][weights_h-j-1][k][i] = \ + self.weights[i][j][weights_w-k-1][l] + padded_w = inputs_w + 2 * padding_w + padded_h = inputs_h + 2 * padding_h + if padding_h or padding_w: + output = MultiArray( + [N, padded_h, padded_w, n_channels_in], sfix) + else: + output = self.nabla_X + @for_range_opt_multithread(self.n_threads, + [N, n_channels_in]) + def _(i, j): + res = sint(size = (padded_w * padded_h)) + conv2ds(res, self.nabla_Y[i].get_vector().pre_mul(), + reverse_weights[j].get_vector().pre_mul(), + padded_h, padded_w, output_h, output_w, + weights_h, weights_w, 1, 1, n_channels_out, + weights_h - 1, weights_w - 1, 1) + output.assign_vector_by_indices( + unreduced_sfix._new(res).reduce_after_mul(), + i, None, None, j) + if padding_h or padding_w: + @for_range(N) + def _(i): + @for_range(inputs_h) + def _(j): + @for_range(inputs_w) + def _(k): + self.nabla_X[i][j][k].assign_vector( + output[i][j][k].get_vector()) + + if self.debug_output: + @for_range(len(batch)) + def _(i): + print_ln('%s X %s %s', self, i, list(self.X[i].reveal_nested())) + print_ln('%s nabla Y %s %s', self, i, list(self.nabla_Y[i].reveal_nested())) + if compute_nabla_X: + print_ln('%s nabla X %s %s', self, i, self.nabla_X[batch[i]].reveal_nested()) + print_ln('%s nabla weights %s', self, + (self.nabla_weights.reveal_nested())) + print_ln('%s weights %s', self, (self.weights.reveal_nested())) + print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested())) + print_ln('%s bias %s', self, (self.bias.reveal_nested())) + class QuantDepthwiseConv2d(QuantConvBase, Conv2d): def n_summands(self): _, weights_h, weights_w, _ = self.weight_shape return weights_h * weights_w - def forward(self, batch): + def _forward(self, batch): assert len(batch) == 1 assert(self.weight_shape[-1] == self.output_shape[-1]) assert(self.input_shape[-1] == self.output_shape[-1]) @@ -1320,8 +1591,6 @@ def forward(self, batch): depth_multiplier = 1 - self.unreduced.alloc() - if self.use_conv2ds: assert depth_multiplier == 1 assert self.weight_shape[0] == 1 @@ -1336,7 +1605,7 @@ def _(j): res = sint(size = output_h * output_w) conv2ds(res, inputs, weights, output_h, output_w, inputs_h, inputs_w, weights_h, weights_w, - stride_h, stride_w, 1, padding_h, padding_w) + stride_h, stride_w, 1, padding_h, padding_w, 1) res += self.bias.expand_to_vector(j, res.size).v self.unreduced.assign_vector_by_indices(res, 0, None, None, j) self.reduction() @@ -1397,7 +1666,7 @@ def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1)): def input_from(self, player, raw=False): self.input_params_from(player) - def forward(self, batch=[0]): + def _forward(self, batch=[0]): assert len(batch) == 1 _, input_h, input_w, n_channels_in = self.input_shape @@ -1460,7 +1729,7 @@ def input_from(self, player): for i in range(2): sint.get_input_from(player) - def forward(self, batch): + def _forward(self, batch): assert len(batch) == 1 # reshaping is implicit self.Y.assign(self.X) @@ -1471,7 +1740,7 @@ def input_from(self, player): for s in self.input_squant, self.output_squant: s.set_params(sfloat.get_input_from(player), sint.get_input_from(player)) - def forward(self, batch): + def _forward(self, batch): assert len(batch) == 1 assert(len(self.input_shape) == 2) @@ -1486,6 +1755,32 @@ class Optimizer: """ Base class for graphs of layers. """ n_threads = Layer.n_threads always_shuffle = True + time_layers = False + revealing_correctness = False + + @staticmethod + def from_args(program, layers): + if 'adam' in program.args or 'adamapprox' in program.args: + return Adam(layers, 1, approx='adamapprox' in program.args) + elif 'amsgrad' in program.args: + return Adam(layers, approx=True, amsgrad=True) + elif 'quotient' in program.args: + return Adam(layers, approx=True, amsgrad=True, normalize=True) + else: + return SGD(layers, 1) + + def __init__(self, report_loss=None): + self.tol = 0.000 + if report_loss is None: + self.report_loss = self.layers[-1].compute_loss + else: + self.report_loss = report_loss + self.X_by_label = None + self.print_update_average = False + self.print_losses = False + self.print_loss_reduction = False + self.i_epoch = MemValue(0) + self.stopped_on_loss = MemValue(0) @property def layers(self): @@ -1510,6 +1805,13 @@ def set_layers_with_inputs(self, layers): layer.last_used = list(filter(lambda x: x not in used, layer.inputs)) used.update(layer.inputs) + def reset(self): + """ Initialize weights. """ + for layer in self.layers: + layer.reset() + self.i_epoch.write(0) + self.stopped_on_loss.write(0) + def batch_for(self, layer, batch): if layer in (self.layers[0], self.layers[-1]): return batch @@ -1520,7 +1822,7 @@ def batch_for(self, layer, batch): @_no_mem_warnings def forward(self, N=None, batch=None, keep_intermediate=True, - model_from=None): + model_from=None, training=False): """ Compute graph. :param N: batch size (used if batch not given) @@ -1530,14 +1832,18 @@ def forward(self, N=None, batch=None, keep_intermediate=True, if batch is None: batch = regint.Array(N) batch.assign(regint.inc(N)) - for layer in self.layers: + for i, layer in enumerate(self.layers): if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None: layer._X.address = layer.inputs[0].Y.address layer.Y.alloc() if model_from is not None: layer.input_from(model_from) break_point() - layer.forward(batch=self.batch_for(layer, batch)) + if self.time_layers: + start_timer(100 + i) + layer.forward(batch=self.batch_for(layer, batch), training=training) + if self.time_layers: + stop_timer(100 + i) break_point() if not keep_intermediate: for l in layer.last_used: @@ -1556,16 +1862,20 @@ def eval(self, data): @_no_mem_warnings def backward(self, batch): """ Compute backward propagation. """ - for layer in reversed(self.layers): - if len(layer.inputs) == 0: + for i, layer in reversed(list(enumerate(self.layers))): + assert len(batch) <= layer.back_batch_size + if self.time_layers: + start_timer(200 + i) + if not layer.inputs: layer.backward(compute_nabla_X=False, batch=self.batch_for(layer, batch)) else: layer.backward(batch=self.batch_for(layer, batch)) if len(layer.inputs) == 1: - layer.inputs[0].nabla_Y.alloc() - layer.inputs[0].nabla_Y.assign_vector( - layer.nabla_X.get_part_vector(0, len(batch))) + layer.inputs[0].nabla_Y.address = \ + layer.nabla_X.address + if self.time_layers: + stop_timer(200 + i) @_no_mem_warnings def run(self, batch_size=None, stop_on_loss=0): @@ -1581,6 +1891,7 @@ def run(self, batch_size=None, stop_on_loss=0): N = self.layers[0].N i = self.i_epoch n_iterations = MemValue(0) + self.n_correct = MemValue(0) @for_range(self.n_epochs) def _(_): if self.X_by_label is None: @@ -1599,6 +1910,7 @@ def _(_): if self.always_shuffle or n_per_epoch > 1: indices.shuffle() loss_sum = MemValue(sfix(0)) + self.n_correct.write(0) @for_range(n_per_epoch) def _(j): n_iterations.iadd(1) @@ -1608,7 +1920,7 @@ def _(j): batch.assign(indices.get_vector(j * n, n) + regint(label * len(self.X_by_label[0]), size=n), label * n) - self.forward(batch=batch) + self.forward(batch=batch, training=True) self.backward(batch=batch) self.update(i, batch=batch) loss_sum.iadd(self.layers[-1].l) @@ -1619,13 +1931,23 @@ def _(j): print_ln('loss reduction in batch %s: %s (%s - %s)', j, before - after, before, after) elif self.print_losses: - print_ln('loss in batch %s: %s', j, self.layers[-1].average_loss(N)) + print_str('\rloss in batch %s: %s/%s', j, + self.layers[-1].average_loss(N), + loss_sum.reveal() / (j + 1)) + if self.revealing_correctness: + part_truth = self.layers[-1].Y.same_shape() + part_truth.assign_vector( + self.layers[-1].Y.get_slice_vector(batch)) + self.n_correct.iadd( + self.layers[-1].reveal_correctness(batch_size, part_truth)) if stop_on_loss: loss = self.layers[-1].average_loss(N) - res = (loss < stop_on_loss) * (loss >= 0) + res = (loss < stop_on_loss) * (loss >= -1) self.stopped_on_loss.write(1 - res) return res - if self.report_loss and self.layers[-1].approx != 5: + if self.print_losses: + print_ln() + if self.report_loss and self.layers[-1].compute_loss and self.layers[-1].approx != 5: print_ln('loss in epoch %s: %s', i, (loss_sum.reveal() * cfix(1 / n_per_epoch))) else: @@ -1636,98 +1958,210 @@ def _(j): if self.tol > 0: res *= (1 - (loss >= 0) * (loss < self.tol)).reveal() return res - print_ln('finished after %s epochs and %s iterations', i, n_iterations) + + def reveal_correctness(self, data, truth, batch_size): + training_data = self.layers[0].X.address + training_truth = self.layers[-1].Y.address + self.layers[0].X.address = data.address + self.layers[-1].Y.address = truth.address + N = data.sizes[0] + batch = regint.Array(batch_size) + n_correct = MemValue(0) + loss = MemValue(sfix(0)) + def f(start, batch_size): + batch.assign_vector(regint.inc(batch_size, start)) + self.forward(batch=batch) + part_truth = truth.get_part(start, batch_size) + n_correct.iadd( + self.layers[-1].reveal_correctness(batch_size, part_truth)) + loss.iadd(self.layers[-1].l * batch_size) + @for_range(N // batch_size) + def _(i): + start = i * batch_size + f(start, batch_size) + batch_size = N % batch_size + if batch_size: + start = N - batch_size + f(start, batch_size) + self.layers[0].X.address = training_data + self.layers[-1].Y.address = training_truth + loss = loss.reveal() + if cfix.f < 31: + loss = cfix._new(loss.v << (31 - cfix.f), k=63, f=31) + return n_correct, loss / N @_no_mem_warnings - def run_by_args(self, program, n_runs, batch_size, test_X, test_Y): + def run_by_args(self, program, n_runs, batch_size, test_X, test_Y, + acc_batch_size=None): + if acc_batch_size is None: + acc_batch_size = batch_size + depreciation = None for arg in program.args: m = re.match('rate(.*)', arg) if m: self.gamma = MemValue(cfix(float(m.group(1)))) + m = re.match('dep(.*)', arg) + if m: + depreciation = float(m.group(1)) if 'nomom' in program.args: self.momentum = 0 + self.print_losses = 'print_losses' in program.args + self.time_layers = 'time_layers' in program.args + self.revealing_correctness = not 'no_acc' in program.args + self.layers[-1].compute_loss = not 'no_loss' in program.args model_input = 'model_input' in program.args + acc_first = model_input and not 'train_first' in program.args if model_input: for layer in self.layers: layer.input_from(0) else: self.reset() + if 'one_iter' in program.args: + self.output_weights() + print_ln('loss') + print_ln('%s', self.eval( + self.layers[0].X.get_part(0, batch_size)).reveal_nested()) + for layer in self.layers: + print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested()) + print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested()) + batch = Array.create_from(regint.inc(batch_size)) + self.forward(batch=batch, training=True) + self.backward(batch=batch) + self.update(0, batch=batch) + print_ln('loss %s', self.layers[-1].l.reveal()) + self.output_weights() + return @for_range(n_runs) def _(i): - if not model_input: + if not acc_first: start_timer(1) - self.run(batch_size, stop_on_loss=100) + self.run(batch_size, + stop_on_loss=0 if 'no_loss' in program.args else 100) stop_timer(1) if 'no_acc' in program.args: return N = self.layers[0].X.sizes[0] - self.forward(N) - batch = regint.Array(N) - batch.assign_vector(regint.inc(N)) - self.layers[-1].backward(batch) - n_correct = self.layers[-1].reveal_correctness(N, debug=True) - print_ln('train_acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / N, - n_correct, N) - training_address = self.layers[0].X.address - self.layers[0].X.address = test_X.address + n_trained = (N + batch_size - 1) // batch_size * batch_size + print_ln('train_acc: %s (%s/%s)', + cfix(self.n_correct, k=63, f=31) / n_trained, + self.n_correct, n_trained) n_test = len(test_Y) - self.forward(n_test) - self.layers[0].X.address = training_address - n_correct = self.layers[-1].reveal_correctness(n_test, test_Y) - print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=32) / n_test, + n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size) + print_ln('test loss: %s', loss) + print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=31) / n_test, n_correct, n_test) - if model_input: + if acc_first: start_timer(1) self.run(batch_size) stop_timer(1) else: @if_(util.or_op(self.stopped_on_loss, n_correct < - int(n_test // self.layers[-1].n_outputs * 1.1))) + int(n_test // self.layers[-1].n_outputs * 1.2))) def _(): self.gamma.imul(.5) self.reset() print_ln('reset after reducing learning rate to %s', self.gamma) + if depreciation: + self.gamma.imul(depreciation) + print_ln('reducing learning rate to %s', self.gamma) + if 'model_output' in program.args: + self.output_weights() + + def output_weights(self): + print_float_precision(max(6, sfix.f // 3)) + for layer in self.layers: + layer.output_weights() class Adam(Optimizer): - def __init__(self, layers, n_epochs): - self.alpha = .001 + """ Adam/AMSgrad optimizer. + + :param layers: layers of linear graph + :param approx: use approximation for inverse square root (bool) + :param amsgrad: use AMSgrad (bool) + """ + def __init__(self, layers, n_epochs=1, approx=False, amsgrad=False, + normalize=False): + self.gamma = MemValue(cfix(.001)) self.beta1 = 0.9 self.beta2 = 0.999 - self.epsilon = 10 ** -8 + self.beta1_power = MemValue(cfix(1)) + self.beta2_power = MemValue(cfix(1)) + self.epsilon = max(2 ** -((sfix.k - sfix.f - 8) / (1 + approx)), 10 ** -8) self.n_epochs = n_epochs + self.approx = approx + self.amsgrad = amsgrad + self.normalize = normalize + if amsgrad: + print_str('Using AMSgrad ') + else: + print_str('Using Adam ') + if approx: + print_ln('with inverse square root approximation') + else: + print_ln('with more precise inverse square root') + if normalize: + print_ln('Normalize gradient') self.layers = layers self.ms = [] self.vs = [] self.gs = [] self.thetas = [] + self.vhats = [] for layer in layers: for nabla in layer.nablas(): self.gs.append(nabla) for x in self.ms, self.vs: x.append(nabla.same_shape()) + if amsgrad: + self.vhats.append(nabla.same_shape()) for theta in layer.thetas(): self.thetas.append(theta) - self.mhat_factors = Array(n_epochs, sfix) - self.vhat_factors = Array(n_epochs, sfix) - - for i in range(n_epochs): - for factors, beta in ((self.mhat_factors, self.beta1), - (self.vhat_factors, self.beta2)): - factors[i] = 1. / (1 - beta ** (i + 1)) - - def update(self, i_epoch): - for m, v, g, theta in zip(self.ms, self.vs, self.gs, self.thetas): - @for_range_opt(len(m)) - def _(k): - m[k] = self.beta1 * m[k] + (1 - self.beta1) * g[k] - v[k] = self.beta2 * v[k] + (1 - self.beta2) * g[k] ** 2 - mhat = m[k] * self.mhat_factors[i_epoch] - vhat = v[k] * self.vhat_factors[i_epoch] - theta[k] = theta[k] - self.alpha * mhat / \ - mpc_math.sqrt(vhat) + self.epsilon + super(Adam, self).__init__() + + def update(self, i_epoch, batch): + self.beta1_power *= self.beta1 + self.beta2_power *= self.beta2 + m_factor = MemValue(1 / (1 - self.beta1_power)) + v_factor = MemValue(1 / (1 - self.beta2_power)) + for i_layer, (m, v, g, theta) in enumerate(zip(self.ms, self.vs, + self.gs, self.thetas)): + if self.normalize: + abs_g = g.same_shape() + @multithread(self.n_threads, g.total_size()) + def _(base, size): + abs_g.assign_vector(abs(g.get_vector(base, size)), base) + max_g = tree_reduce_multithread(self.n_threads, + util.max, abs_g.get_vector()) + scale = MemValue(sfix._new(library.AppRcr( + max_g.v, max_g.k, max_g.f, simplex_flag=True))) + @multithread(self.n_threads, m.total_size()) + def _(base, size): + m_part = m.get_vector(base, size) + v_part = v.get_vector(base, size) + g_part = g.get_vector(base, size) + if self.normalize: + g_part *= scale.expand_to_vector(size) + m_part = self.beta1 * m_part + (1 - self.beta1) * g_part + v_part = self.beta2 * v_part + (1 - self.beta2) * g_part ** 2 + m.assign_vector(m_part, base) + v.assign_vector(v_part, base) + if self.amsgrad: + vhat = self.vhats [i_layer].get_vector(base, size) + vhat = util.max(vhat, v_part) + self.vhats[i_layer].assign_vector(vhat, base) + diff = self.gamma.expand_to_vector(size) * m_part + else: + mhat = m_part * m_factor.expand_to_vector(size) + vhat = v_part * v_factor.expand_to_vector(size) + diff = self.gamma.expand_to_vector(size) * mhat + if self.approx: + diff *= mpc_math.InvertSqrt(vhat + self.epsilon ** 2) + else: + diff /= mpc_math.sqrt(vhat) + self.epsilon + theta.assign_vector(theta.get_vector(base, size) - diff, base) class SGD(Optimizer): """ Stochastic gradient descent. @@ -1750,17 +2184,7 @@ def __init__(self, layers, n_epochs, debug=False, report_loss=None): self.delta_thetas.append(theta.same_shape()) self.gamma = MemValue(cfix(0.01)) self.debug = debug - if report_loss is None: - self.report_loss = layers[-1].compute_loss - else: - self.report_loss = report_loss - self.tol = 0.000 - self.X_by_label = None - self.print_update_average = False - self.print_losses = False - self.print_loss_reduction = False - self.i_epoch = MemValue(0) - self.stopped_on_loss = MemValue(0) + super(SGD, self).__init__(report_loss) @_no_mem_warnings def reset(self, X_by_label=None): @@ -1778,10 +2202,7 @@ def _(i): self.layers[-1].Y[j] = label for y in self.delta_thetas: y.assign_all(0) - for layer in self.layers: - layer.reset() - self.i_epoch.write(0) - self.stopped_on_loss.write(0) + super(SGD, self).reset() def update(self, i_epoch, batch): for nabla, theta, delta_theta in zip(self.nablas, self.thetas, diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py index 50da62c27..87d11def8 100644 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -846,3 +846,54 @@ def acos(x): """ y = asin(x) return pi_over_2 - y + + +def tanh(x): + """ + Hyperbolic tangent. For efficiency, accuracy is diminished + around :math:`\pm \log(k - f - 2) / 2` where :math:`k` and + :math:`f` denote the fixed-point parameters. + """ + limit = math.log(2 ** (x.k - x.f - 2)) / 2 + s = x < -limit + t = x > limit + y = pow_fx(math.e, 2 * x) + return s.if_else(-1, t.if_else(1, (y - 1) / (y + 1))) + + +# next functions due to https://dl.acm.org/doi/10.1145/3411501.3419427 + +def Sep(x): + b = floatingpoint.PreOR(list(reversed(x.v.bit_decompose(x.k, maybe_mixed=True)))) + t = x.v * (1 + x.v.bit_compose(b_i.bit_not() for b_i in b[-2 * x.f + 1:])) + u = types.sfix._new(t.right_shift(x.f, 2 * x.k, signed=False)) + b += [b[0].long_one()] + return u, [b[i + 1] - b[i] for i in reversed(range(x.k))] + +def SqrtComp(z, old=False): + f = types.sfix.f + k = len(z) + if isinstance(z[0], types.sint): + return types.sfix._new(sum(z[i] * types.cfix( + 2 ** (-(i - f + 1) / 2)).v for i in range(k))) + k_prime = k // 2 + f_prime = f // 2 + c1 = types.sfix(2 ** ((f + 1) / 2 + 1)) + c0 = types.sfix(2 ** (f / 2 + 1)) + a = [z[2 * i].bit_or(z[2 * i + 1]) for i in range(k_prime)] + tmp = types.sfix._new(types.sint.bit_compose(reversed(a[:2 * f_prime]))) + if old: + b = sum(types.sint.conv(zi).if_else(i, 0) for i, zi in enumerate(z)) % 2 + else: + b = util.tree_reduce(lambda x, y: x.bit_xor(y), z[::2]) + return types.sint.conv(b).if_else(c1, c0) * tmp + +@types.vectorize +def InvertSqrt(x, old=False): + """ + Reciprocal square root approximation by `Lu et al. + `_ + """ + u, z = Sep(x) + c = 3.14736 + u * (4.63887 * u - 5.77789) + return c * SqrtComp(z, old=old) diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 6af4b414e..3ef709ca8 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -44,7 +44,7 @@ def eqz(self, a, k): d = [None]*k for i,b in enumerate(r[0].bit_decompose_clear(c, k)): d[i] = r[i].bit_xor(b) - return 1 - types.sint.conv(self.kor(d)) + return 1 - types.sintbit.conv(self.kor(d)) class Prime(Masking): """ Non-linear functionality modulo a prime with statistical masking. """ @@ -71,8 +71,11 @@ def _mask(self, a, k): def _trunc_pr(self, a, k, m, signed=None): return TruncPrField(a, k, m, self.kappa) - def bit_dec(self, a, k, m): - return BitDecField(a, k, m, self.kappa) + def bit_dec(self, a, k, m, maybe_mixed=False): + if maybe_mixed: + return BitDecFieldRaw(a, k, m, self.kappa) + else: + return BitDecField(a, k, m, self.kappa) def kor(self, d): return KOR(d, self.kappa) @@ -85,7 +88,7 @@ def __init__(self, prime): def _mod2m(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) - return sint.bit_compose(self.bit_dec(a, k, k)[:m]) + return sint.bit_compose(self.bit_dec(a, k, k, True)[:m]) def _trunc_pr(self, a, k, m, signed): # nearest truncation @@ -96,14 +99,14 @@ def trunc_round_nearest(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) k += 1 - res = sint.bit_compose(self.bit_dec(a, k, k)[m:]) + res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) if signed: res -= cint(1) << (k - m - 2) return res - def bit_dec(self, a, k, m): + def bit_dec(self, a, k, m, maybe_mixed=False): assert k < self.prime.bit_length() - bits = BitDecFull(a) + bits = BitDecFull(a, maybe_mixed=maybe_mixed) if len(bits) < m: raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) return bits[:m] @@ -111,7 +114,7 @@ def bit_dec(self, a, k, m): def eqz(self, a, k): # always signed a += two_power(k) - return 1 - KORL(self.bit_dec(a, k, k)) + return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. @@ -130,8 +133,11 @@ def _mask(self, a, k): def _trunc_pr(self, a, k, m, signed): return TruncPrRing(a, k, m, signed=signed) - def bit_dec(self, a, k, m): - return BitDecRing(a, k, m) + def bit_dec(self, a, k, m, maybe_mixed=False): + if maybe_mixed: + return BitDecRingRaw(a, k, m) + else: + return BitDecRing(a, k, m) def kor(self, d): return KORL(d) diff --git a/Compiler/program.py b/Compiler/program.py index 74e03a1bb..b004c48c6 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -28,9 +28,7 @@ square = 1, bit = 2, inverse = 3, - bittriple = 4, - bitgf2ntriple = 5, - dabit = 6, + dabit = 4, ) field_types = dict( @@ -62,6 +60,7 @@ class defaults: asmoutfile = None stop = False insecure = False + keep_cisc = False class Program(object): """ A program consists of a list of tapes representing the whole @@ -80,14 +79,14 @@ def __init__(self, args, options=defaults): self.init_names(args) self._security = 40 self.prime = None + self.tapes = [] if sum(x != 0 for x in(options.ring, options.field, options.binary)) > 1: raise CompilerError('can only use one out of -B, -R, -F') if options.prime and (options.ring or options.binary): raise CompilerError('can only use one out of -B, -R, -p') if options.ring: - self.bit_length = int(options.ring) - 1 - self.non_linear = Ring(int(options.ring)) + self.set_ring_size(int(options.ring)) else: self.bit_length = int(options.binary) or int(options.field) if options.prime: @@ -108,7 +107,6 @@ def __init__(self, args, options=defaults): if self.verbose: print('Galois length:', self.galois_length) self.tape_counter = 0 - self.tapes = [] self._curr_tape = None self.DEBUG = options.debug self.allocated_mem = RegType.create_dict(lambda: USER_MEM) @@ -204,6 +202,16 @@ def init_names(self, args): for arg in args[1:]) self.progname = progname + def set_ring_size(self, ring_size): + from .non_linear import Ring + for tape in self.tapes: + prev = tape.req_bit_length['p'] + if prev and prev != ring_size: + raise CompilerError('cannot have different ring sizes') + self.bit_length = ring_size - 1 + self.non_linear = Ring(ring_size) + self.options.ring = str(ring_size) + def new_tape(self, function, args=[], name=None, single_thread=False): """ Create a new tape from a function. See @@ -414,7 +422,7 @@ def finalize_memory(self): self.curr_tape.start_new_basicblock(None, 'memory-usage') # reset register counter to 0 self.curr_tape.init_registers() - for mem_type,size in list(self.allocated_mem.items()): + for mem_type,size in sorted(self.allocated_mem.items()): if size: #print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: @@ -488,7 +496,7 @@ def use_split(self, change=None): else: if change and not self.options.ring: raise CompilerError('splitting only supported for rings') - assert change > 1 + assert change > 1 or change == False self._split = change def use_square(self, change=None): @@ -575,7 +583,7 @@ def __init__(self, parent, name, scope, exit_condition=None): scope.children.append(self) self.alloc_pool = scope.alloc_pool else: - self.alloc_pool = defaultdict(set) + self.alloc_pool = defaultdict(list) self.purged = False self.n_rounds = 0 self.n_to_merge = 0 @@ -647,9 +655,14 @@ def add_usage(self, req_node): def expand_cisc(self): new_instructions = [] + if self.parent.program.options.keep_cisc: + skip = ['LTZ', 'Trunc'] + else: + skip = [] for inst in self.instructions: - new_instructions.extend(inst.expand_merged()) - self.n_rounds += inst.expanded_rounds() + new_inst, n_rounds = inst.expand_merged(skip) + new_instructions.extend(new_inst) + self.n_rounds += n_rounds self.instructions = new_instructions def __str__(self): @@ -774,7 +787,10 @@ def optimize(self, options): # allocate registers reg_counts = self.count_regs() - if not options.noreallocate: + if options.noreallocate: + if self.program.verbose: + print('Tape register usage:', dict(reg_counts)) + else: if self.program.verbose: print('Tape register usage before re-allocation:', dict(reg_counts)) @@ -1071,7 +1087,7 @@ def __init__(self, reg_type, program, size=None, i=None): if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: - raise CompilerError('vector too large') + raise CompilerError('vector too large: %d' % size) self.size = size self.vectorbase = self self.relative_i = 0 diff --git a/Compiler/types.py b/Compiler/types.py index 433c6bd99..fd81afe26 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -591,12 +591,12 @@ def _store_in_mem(self, address, direct_inst, indirect_inst): def prep_res(cls, other): return cls() - @staticmethod - def bit_compose(bits): + @classmethod + def bit_compose(cls, bits): """ Compose value from bits. :param bits: iterable of any type implementing left shift """ - return sum(b << i for i,b in enumerate(bits)) + return sum(cls.conv(b) << i for i,b in enumerate(bits)) @classmethod def malloc(cls, size, creator_tape=None): @@ -840,6 +840,7 @@ def store_in_mem(self, address): def in_immediate_range(value): return value < 2**31 and value >= -2**31 + @vectorize_init def __init__(self, val=None, size=None): """ :param val: initialization (cint/regint/int/cgf2n or list thereof) @@ -1119,12 +1120,6 @@ def load_int(self, val): elif chunk: sum += chunk - def __mul__(self, other): - """ Clear :math:`\mathrm{GF}(2^n)` multiplication. - - :param other: cgf2n/regint/int """ - return super(cgf2n, self).__mul__(other) - def __neg__(self): """ Identity. """ return self @@ -1209,7 +1204,9 @@ def push(cls, value): def get_random(cls, bit_length): """ Public insecure randomness. - :param bit_length: number of bits (int) """ + :param bit_length: number of bits (int) + :param size: vector size (int, default 1) + """ if isinstance(bit_length, int): bit_length = regint(bit_length) res = cls() @@ -1582,7 +1579,9 @@ class _secret(_register): def get_input_from(cls, player): """ Secret input from player. - :param player: public (regint/cint/int) """ + :param player: public (regint/cint/int) + :param size: vector size (int, default 1) + """ res = cls() asm_input(res, player) return res @@ -1592,7 +1591,9 @@ def get_input_from(cls, player): def get_random_triple(cls): """ Secret random triple according to security model. - :return: :math:`(a, b, ab)` """ + :return: :math:`(a, b, ab)` + :param size: vector size (int, default 1) + """ res = (cls(), cls(), cls()) triple(*res) return res @@ -1602,7 +1603,9 @@ def get_random_triple(cls): def get_random_bit(cls): """ Secret random bit according to security model. - :return: 0/1 50-50 """ + :return: 0/1 50-50 + :param size: vector size (int, default 1) + """ res = cls() bit(res) return res @@ -1612,7 +1615,9 @@ def get_random_bit(cls): def get_random_square(cls): """ Secret random square according to security model. - :return: :math:`(a, a^2)` """ + :return: :math:`(a, a^2)` + :param size: vector size (int, default 1) + """ res = (cls(), cls()) square(*res) return res @@ -1622,7 +1627,9 @@ def get_random_square(cls): def get_random_inverse(cls): """ Secret random inverse tuple according to security model. - :return: :math:`(a, a^{-1})` """ + :return: :math:`(a, a^{-1})` + :param size: vector size (int, default 1) + """ res = (cls(), cls()) inverse(*res) return res @@ -1717,16 +1724,51 @@ def load_other(self, val): else: self.load_clear(self.clear_type(val)) + @classmethod + def bit_compose(cls, bits): + """ Compose value from bits. + + :param bits: iterable of any type convertible to sint """ + from Compiler.GC.types import sbits, sbitintvec + bits = list(bits) + if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits): + if program.use_edabit(): + mask = cls.get_edabit(len(bits), strict=True, size=bits[0].n) + else: + tmp = sint(size=bits[0].n) + randoms(tmp, len(bits)) + n_overflow_bits = min(program.use_split().bit_length(), + int(program.options.ring) - len(bits)) + mask_bits = tmp.bit_decompose(len(bits) + n_overflow_bits, + maybe_mixed=True) + if n_overflow_bits: + overflow = sint.bit_compose( + sint.conv(x) for x in mask_bits[-n_overflow_bits:]) + mask = tmp - (overflow << len(bits)), \ + mask_bits[:-n_overflow_bits] + else: + mask = tmp, mask_bits + t = sbitintvec.get_type(len(bits) + 1) + masked = t.from_vec(mask[1] + [0]) + t.from_vec(bits + [0]) + overflow = masked.v[-1] + masked = cls.bit_compose(x.reveal().to_regint_by_bit() for x in masked.v[:-1]) + return masked - mask[0] + (cls(overflow) << len(bits)) + else: + return super(_secret, cls).bit_compose(bits) + @set_instruction_type @read_mem_value @vectorize def secret_op(self, other, s_inst, m_inst, si_inst, reverse=False): - cls = self.__class__ res = self.prep_res(other) + cls = type(res) if isinstance(other, regint): other = res.clear_type(other) if isinstance(other, cls): - s_inst(res, self, other) + if reverse: + s_inst(res, other, self) + else: + s_inst(res, self, other) elif isinstance(other, res.clear_type): if reverse: m_inst(res, other, self) @@ -1861,10 +1903,12 @@ def require_bit_length(n_bits): def get_random_int(cls, bits): """ Secret random n-bit number according to security model. - :param bits: compile-time integer (int) """ + :param bits: compile-time integer (int) + :param size: vector size (int, default 1) + """ if program.use_edabit(): return sint.get_edabit(bits, True)[0] - elif program.use_split() > 2: + elif program.use_split() > 2 and program.use_split() < 5: tmp = sint() randoms(tmp, bits) x = tmp.split_to_two_summands(bits, True) @@ -1882,7 +1926,10 @@ def get_random_int(cls, bits): @vectorized_classmethod def get_random(cls): - """ Secret random ring element according to security model. """ + """ Secret random ring element according to security model. + + :param size: vector size (int, default 1) + """ res = sint() randomfulls(res) return res @@ -1891,7 +1938,9 @@ def get_random(cls): def get_input_from(cls, player): """ Secret input. - :param player: public (regint/cint/int) """ + :param player: public (regint/cint/int) + :param size: vector size (int, default 1) + """ res = cls() inputmixed('int', res, player) return res @@ -1915,7 +1964,7 @@ def get_edabit(cls, n_bits, strict=False): else: a = [sint.get_random_bit() for i in range(n_bits)] return sint.bit_compose(a), a - program.curr_tape.require_bit_length(n_bits) + program.curr_tape.require_bit_length(n_bits - 1) whole = cls() size = get_global_vector_size() from Compiler.GC.types import sbits, sbitvec @@ -1931,6 +1980,7 @@ def long_one(): return 1 @staticmethod + @vectorize def bit_decompose_clear(a, n_bits): return floatingpoint.bits(a, n_bits) @@ -2055,7 +2105,7 @@ def __lt__(self, other, bit_length=None, security=None): :param other: sint/cint/regint/int :return: 0/1 (sint) """ - res = sint() + res = sintbit() comparison.LTZ(res, self - other, (bit_length or program.bit_length) + 1, security or program.security) @@ -2064,7 +2114,7 @@ def __lt__(self, other, bit_length=None, security=None): @read_mem_value @vectorize def __gt__(self, other, bit_length=None, security=None): - res = sint() + res = sintbit() comparison.LTZ(res, other - self, (bit_length or program.bit_length) + 1, security or program.security) @@ -2185,13 +2235,14 @@ def __rrshift__(self, other): return floatingpoint.Trunc(other, program.bit_length, self, program.security) @vectorize - def bit_decompose(self, bit_length=None, security=None): + def bit_decompose(self, bit_length=None, security=None, maybe_mixed=False): """ Secret bit decomposition. """ if bit_length == 0: return [] bit_length = bit_length or program.bit_length - security = security or program.security - return floatingpoint.BitDec(self, bit_length, bit_length, security) + assert program.security == security or program.security + return program.non_linear.bit_dec(self, bit_length, bit_length, + maybe_mixed) def TruncMul(self, other, k, m, kappa=None, nearest=False): return (self * other).round(k, m, kappa, nearest, signed=True) @@ -2249,6 +2300,7 @@ def two_power(n): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): + comparison.require_ring_size(length, 'splitting') from .GC.types import sbits from .GC.instructions import split columns = [[sbits.get_type(self.size)() @@ -2274,7 +2326,9 @@ def raw_mod2m(self, m): @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. - Result potentially written to ``Player-Data/Private-Output-P.`` + Result potentially written to + ``Player-Data/Private-Output-P``, but not if + :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint): :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` @@ -2288,6 +2342,65 @@ def reveal_to(self, player): else: return super(sint, self).reveal_to(player) +class sintbit(sint): + @classmethod + def prep_res(cls, other): + return sint() + + def load_other(self, other): + if isinstance(other, sint): + movs(self, other) + else: + super(sintbit, self).load_other(other) + + @vectorize + def __and__(self, other): + if isinstance(other, sintbit): + res = sintbit() + muls(res, self, other) + return res + elif util.is_zero(other): + return 0 + elif util.is_one(other): + return self + else: + return NotImplemented + + @vectorize + def __or__(self, other): + if isinstance(other, sintbit): + res = sintbit() + adds(res, self, other - self * other) + return res + elif util.is_zero(other): + return self + elif util.is_one(other): + return 1 + else: + return NotImplemented + + @vectorize + def __xor__(self, other): + if isinstance(other, sintbit): + res = sintbit() + adds(res, self, other - 2 * self * other) + return res + elif util.is_zero(other): + return self + elif util.is_one(other): + return 1 + else: + return NotImplemented + + @vectorize + def __rsub__(self, other): + if util.is_one(other): + res = sintbit() + subsfi(res, self, 1) + return res + else: + return super(sintbit, self).__rsub__(other) + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. """ __slots__ = [] @@ -2437,10 +2550,11 @@ def bit_decompose_embedding(self): return [self.clear_type((masked >> wanted_positions[i]) & one) + r for i,r in enumerate(random_bits)] for t in (sint, sgf2n): - t.bit_type = t t.basic_type = t t.default_type = t +sint.bit_type = sintbit +sgf2n.bit_type = sgf2n class _bitint(object): bits = None @@ -3046,14 +3160,17 @@ def _new(cls, other, k=None, f=None): @staticmethod def int_rep(v, f, k=None): + if isinstance(v, regint): + v = cint(v) res = v * (2 ** f) try: res = int(round(res)) - if k and abs(res) >= 2 ** k: + if k and res >= 2 ** (k - 1) or res < -2 ** (k - 1): + limit = 2 ** (k - f - 1) raise CompilerError( - 'Value out of fixed-point range (maximum %d). ' + 'Value out of fixed-point range [-%d, %d). ' 'Use `sfix.set_precision(f, k)` with k being at least f+%d' - % (2 ** (k - f), math.ceil(math.log(abs(v), 2)) + 1)) + % (limit, limit, res.bit_length() - f + 1)) except TypeError: pass return res @@ -3268,6 +3385,14 @@ def __truediv__(self, other): else: raise TypeError('Incompatible fixed point types in division') + @vectorize + def __rtruediv__(self, other): + """ Fixed-point division. + + :param other: sfix/sint/cfix/cint/regint/int """ + other = parse_type(other, self.k, self.f) + return other / self + def print_plain(self): """ Clear fixed-point output. """ print_float_plain(cint.conv(self.v), cint(-self.f), \ @@ -3468,7 +3593,7 @@ def set_precision(cls, f, k = None): set_precision = classmethod(set_precision) @classmethod - def set_precision_from_args(cls, program): + def set_precision_from_args(cls, program, adapt_ring=False): f = None k = None for arg in program.args: @@ -3484,6 +3609,15 @@ def set_precision_from_args(cls, program): cfix.set_precision(f, k) elif k is not None: raise CompilerError('need to set fractional precision') + if 'nearest' in program.args: + print('Nearest rounding instead of proabilistic ' + 'for fixed-point computation') + cls.round_nearest = True + if adapt_ring and program.options.ring: + need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) + if need != int(program.options.ring): + print('Changing computation modulus to 2^%d' % need) + program.set_ring_size(need) @classmethod def coerce(cls, other): @@ -3609,11 +3743,14 @@ def __truediv__(self, other): :param other: sfix/cfix/sint/cint/regint/int """ if util.is_constant_float(other): assert other != 0 - other_length = self.f + math.ceil(math.log(abs(other), 2)) - if other_length >= self.k: - factor = 2 ** (self.k - other_length - 1) + log = math.ceil(math.log(abs(other), 2)) + other_length = self.f + log + if other_length >= self.k - 1: + factor = 2 ** (self.k - other_length - 2) self *= factor other *= factor + if 2 ** log == other: + return self * 2 ** -log other = self.coerce(other) assert self.k == other.k assert self.f == other.f @@ -3660,7 +3797,9 @@ class sfix(_fix): def get_input_from(cls, player): """ Secret fixed-point input. - :param player: public (regint/cint/int) """ + :param player: public (regint/cint/int) + :param size: vector size (int, default 1) + """ cls.int_type.require_bit_length(cls.k) v = cls.int_type() inputmixed('fix', v, cls.f, player) @@ -3677,6 +3816,7 @@ def get_random(cls, lower, upper): :param lower: float :param upper: float + :param size: vector size (int, default 1) """ log_range = int(math.log(upper - lower, 2)) n_bits = log_range + cls.f @@ -3732,7 +3872,8 @@ def multipliable(v, k, f, size): def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. Raw representation possibly written to - ``Player-Data/Private-Output-P.`` + ``Player-Data/Private-Output-P``, but not if + :py:obj:`player` is a :py:class:`regint`. :param player: public integer (int/regint/cint) :returns: value to be used with :py:func:`~Compiler.library.print_ln_to` @@ -4066,7 +4207,9 @@ def convert_float(v, vlen, plen): def get_input_from(cls, player): """ Secret floating-point input. - :param player: public (regint/cint/int) """ + :param player: public (regint/cint/int) + :param size: vector size (int, default 1) + """ v = sint() p = sint() z = sint() @@ -4444,6 +4587,7 @@ def __init__(self, length, value_type, address=None, debug=None, alloc=True): self.address_cache = {} self.debug = debug self.creator_tape = program.curr_tape + self.sink = None if alloc: self.alloc() @@ -4514,6 +4658,17 @@ def f(i): return self._store(value, self.get_address(index)) + def maybe_get(self, condition, index): + return condition * self[condition * index] + + def maybe_set(self, condition, index, value): + if self.sink is None: + self.sink = self.value_type.Array(1) + addresses = (condition.if_else(x, y) for x, y in + zip(util.tuplify(self.get_address(index)), + util.tuplify(self.sink.get_address(0)))) + self._store(value, util.untuplify(tuple(addresses))) + # the following two are useful for compile-time lengths # and thus differ from the usual Python syntax def get_range(self, start, size): @@ -4590,11 +4745,22 @@ def get_vector(self, base=0, size=None): get_part_vector = get_vector + def get_part(self, base, size): + return Array(size, self.value_type, self.get_address(base)) + def get(self, indices): return self.value_type.load_mem( regint.inc(len(indices), self.address, 0) + indices, size=len(indices)) + def get_slice_vector(self, slice): + assert self.value_type.n_elements() == 1 + assert len(slice) <= self.total_size() + base = regint.inc(len(slice), slice.address, 1, 1) + inc = regint.inc(len(slice), 0, 1, 1, 1) + addresses = slice.value_type.load_mem(base) + inc + return self.value_type.load_mem(self.address + addresses) + def expand_to_vector(self, index, size): assert self.value_type.n_elements() == 1 address = regint(size=size) @@ -4641,6 +4807,12 @@ def __mul__(self, value): :param other: vector or container of same length and type that supports operations with type of this array """ return self.get_vector() * value + def __truediv__(self, value): + """ Vector division. + + :param other: vector or container of same length and type that supports operations with type of this array """ + return self.get_vector() / value + def __pow__(self, value): """ Vector power-of computation. @@ -4674,6 +4846,16 @@ def reveal_list(self): reveal_nested = reveal_list + def sort(self, n_threads=None): + """ + Sort in place using Batchers' odd-even merge mergesort + with complexity :math:`O(n (\log n)^2)`. + + :param n_threads: number of threads to use (single thread by + default) + """ + library.loopy_odd_even_merge_sort(self, n_threads=n_threads) + def __str__(self): return '%s array of length %s at %s' % (self.value_type, len(self), self.address) @@ -4784,6 +4966,15 @@ def assign_part_vector(self, vector, base=0): assert vector.size <= self.total_size() vector.store_in_mem(self.address + base * part_size) + def get_slice_vector(self, slice): + assert self.value_type.n_elements() == 1 + part_size = reduce(operator.mul, self.sizes[1:]) + assert len(slice) * part_size <= self.total_size() + base = regint.inc(len(slice) * part_size, slice.address, 1, part_size) + inc = regint.inc(len(slice) * part_size, 0, 1, 1, part_size) + addresses = slice.value_type.load_mem(base) * part_size + inc + return self.value_type.load_mem(self.address + addresses) + def get_addresses(self, *indices): assert self.value_type.n_elements() == 1 assert len(indices) == len(self.sizes) @@ -4816,6 +5007,10 @@ def same_shape(self): """ :return: new multidimensional array with same shape and basic type """ return MultiArray(self.sizes, self.value_type) + def get_part(self, start, size): + return MultiArray([size] + list(self.sizes[1:]), self.value_type, + address=self[start].address) + def input_from(self, player, budget=None, raw=False): """ Fill with inputs from player if supported by type. @@ -4978,7 +5173,7 @@ def direct_mul_trans(self, other, reduce=True, indices=None): indices = [regint.inc(i) for i in self.sizes + other.sizes[::-1]] assert len(indices[1]) == len(indices[2]) indices = list(indices) - indices[3] *= other.sizes[0] + indices[3] *= other.sizes[1] return self.value_type.direct_matrix_mul( self.address, other.address, None, self.sizes[1], 1, reduce=reduce, indices=indices) diff --git a/Compiler/util.py b/Compiler/util.py index fa41f41c0..aa491e422 100644 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -195,7 +195,7 @@ def is_all_ones(x, n): else: return False -def max(x, y=None): +def max(x, y=None, n_threads=None): if y is None: return tree_reduce(max, x) else: diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp index ad27fe5e0..9fff469ea 100644 --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -7,6 +7,7 @@ #include "Networking/CryptoPlayer.h" #include "Math/gfp.h" #include "ECDSA/P256Element.h" +#include "GC/VectorInput.h" #include "ECDSA/preprocessing.hpp" #include "ECDSA/sign.hpp" @@ -20,6 +21,8 @@ #include "Protocols/MascotPrep.hpp" #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" +#include "GC/VectorProtocol.hpp" +#include "GC/CcdPrep.hpp" #include "OT/NPartyTripleGenerator.hpp" #include diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index 0cc65edf2..dc2edab31 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -5,6 +5,7 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" +#include "GC/VectorInput.h" #include "Protocols/Share.hpp" #include "Protocols/MAC_Check.hpp" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 9d86d589d..de655e13e 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -19,6 +19,8 @@ #include "Processor/Data_Files.hpp" #include "Processor/Input.hpp" #include "GC/TinyPrep.hpp" +#include "GC/VectorProtocol.hpp" +#include "GC/CcdPrep.hpp" #include diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index fb9f6001c..334d5d1ba 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -13,7 +13,6 @@ #include "Protocols/MaliciousShamirShare.h" #include "Protocols/Rep3Share.h" #include "GC/TinierSecret.h" -#include "GC/TinierPrep.h" #include "GC/MaliciousCcdSecret.h" #include "GC/TinyMC.h" @@ -128,16 +127,4 @@ void check(vector>& tuples, T sk, MC.Check(P); } -template<> -void ReplicatedPrep>::buffer_bits() -{ - throw not_implemented(); -} - -template<> -void ReplicatedPrep>::buffer_bits() -{ - throw not_implemented(); -} - #endif /* ECDSA_PREPROCESSING_HPP_ */ diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h index 1540f9424..bbf0b112e 100644 --- a/FHE/AddableVector.h +++ b/FHE/AddableVector.h @@ -149,11 +149,6 @@ class AddableVector: public vector return res; } - bool is_binary() const - { - throw not_implemented(); - } - size_t report_size(ReportType type) { size_t res = 4; diff --git a/FHE/AddableVector.cpp b/FHE/AddableVector.hpp similarity index 81% rename from FHE/AddableVector.cpp rename to FHE/AddableVector.hpp index a99c05931..6a6f3dc03 100644 --- a/FHE/AddableVector.cpp +++ b/FHE/AddableVector.hpp @@ -6,6 +6,7 @@ #include "AddableVector.h" #include "Rq_Element.h" #include "FHE_Keys.h" +#include "P2Data.h" template AddableVector AddableVector::mul_by_X_i(int j, @@ -33,7 +34,3 @@ AddableVector AddableVector::mul_by_X_i(int j, } return res; } - -template -AddableVector AddableVector< - Int_Random_Coins::rand_type>::mul_by_X_i(int j, const FHE_PK& pk) const; diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h index 3247450c8..d455f1268 100644 --- a/FHE/Ciphertext.h +++ b/FHE/Ciphertext.h @@ -23,8 +23,6 @@ class Ciphertext word pk_id; public: - static string type_string() { return "ciphertext"; } - static int t() { return 0; } static int size() { return 0; } const FHE_Params& get_params() const { return *params; } @@ -41,8 +39,6 @@ class Ciphertext set(a0, a1, C.get_pk_id()); } - ~Ciphertext() { ; } - // Rely on default copy assignment/constructor word get_pk_id() const { return pk_id; } diff --git a/FHE/DiscreteGauss.cpp b/FHE/DiscreteGauss.cpp index 415ef25ed..abfd8a49b 100644 --- a/FHE/DiscreteGauss.cpp +++ b/FHE/DiscreteGauss.cpp @@ -32,52 +32,6 @@ int DiscreteGauss::sample(PRNG &G, int stretch) const -void RandomVectors::set(int nn,int hh,double R) -{ - n=nn; - h=hh; - DG.set(R); -} - -void RandomVectors::set_n(int nn) -{ - n = nn; -} - -vector RandomVectors::sample_Gauss(PRNG& G, int stretch) const -{ - vector ans(n); - for (int i=0; i RandomVectors::sample_Hwt(PRNG& G) const -{ - if (h > n/2 or h <= 0) { return sample_Gauss(G); } - vector ans(n); - for (int i=0; i RandomVectors::sample_Half(PRNG& G) const -{ - vector ans(n); - for (int i=0; i RandomVectors::sample_Uniform(PRNG& G,const bigint& B) const -{ - vector ans(n); - bigint v; - for (int i=0; i sample_Gauss(PRNG& G, int stretch = 1) const; - - // Next samples from Hwt distribution unless hwt>n/2 in which - // case it uses Gauss - vector sample_Hwt(PRNG& G) const; - - // Sample from {-1,0,1} with Pr(-1)=Pr(1)=1/4 and Pr(0)=1/2 - vector sample_Half(PRNG& G) const; - - // Sample from (-B,0,B) with uniform prob - vector sample_Uniform(PRNG& G,const bigint& B) const; - - bool operator!=(const RandomVectors& other) const; -}; - template class RandomGenerator : public Generator { @@ -103,7 +58,7 @@ class UniformGenerator : public RandomGenerator void get(T& x) const { this->G.get(x, n_bits, positive); } }; -template +template class GaussianGenerator : public RandomGenerator { DiscreteGauss DG; diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp index f15145250..7552e5b4a 100644 --- a/FHE/FFT.cpp +++ b/FHE/FFT.cpp @@ -1,6 +1,7 @@ #include "FHE/FFT.h" #include "Math/Zp_Data.h" +#include "Processor/BaseMachine.h" #include "Math/modp.hpp" @@ -115,17 +116,38 @@ void FFT_Iter(vector& ioput, int n, const T& root, const P& PrD) */ void FFT_Iter2(vector& ioput, int n, const modp& root, const Zp_Data& PrD) { + FFT_Iter(ioput, n, root, PrD, false); +} + +void FFT_Iter2(vector& ioput, int n, const vector& roots, + const Zp_Data& PrD) +{ + FFT_Iter(ioput, n, roots, PrD, false); +} + +void FFT_Iter(vector& ioput, int n, const modp& root, const Zp_Data& PrD, + bool start_with_one) +{ + vector roots(n + 1); + assignOne(roots[0], PrD); + for (int i = 1; i < n + 1; i++) + Mul(roots[i], roots[i - 1], root, PrD); + FFT_Iter(ioput, n, roots, PrD, start_with_one); +} + +void FFT_Iter(vector& ioput, int n, const vector& roots, + const Zp_Data& PrD, bool start_with_one) +{ + assert(roots.size() > size_t(n)); + int i, j, m; - modp t; // Bit-reversal of input for( i = j = 0; i < n; ++i ) { if( j >= i ) { - t = ioput[i]; - ioput[i] = ioput[j]; - ioput[j] = t; + swap(ioput[i], ioput[j]); } m = n / 2; @@ -136,27 +158,38 @@ void FFT_Iter2(vector& ioput, int n, const modp& root, const Zp_Data& PrD) } j += m; } - modp u, alpha, alpha2; m = 0; j = 0; i = 0; // Do the transform + vector alpha2; + alpha2.reserve(n / 2); for (int s = 1; s < n; s = 2*s) { m = 2*s; - Power(alpha, root, n/m, PrD); - alpha2 = alpha; - Mul(alpha, alpha, alpha, PrD); - for (int j = 0; j < m/2; ++j) + + alpha2.clear(); + if (start_with_one) { - //root = root_table[(2*j+1)*n/m]; - for (int k = j; k < n; k += m) - { - Mul(t, alpha2, ioput[k + m/2], PrD); - u = ioput[k]; - Add(ioput[k], u, t, PrD); - Sub(ioput[k + m/2], u, t, PrD); - } - Mul(alpha2, alpha2, alpha, PrD); + for (int j = 0; j < m / 2; j++) + alpha2.push_back(roots[j * n / m]); + } + else + { + for (int j = 0; j < m / 2; j++) + alpha2.push_back(roots.at((j * 2 + 1) * (n / m))); + } + + if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton()) + { + auto& queues = BaseMachine::s().queues; + FftJob job(ioput, alpha2, m, PrD); + int start = queues.distribute(job, n / 2); + for (int i = start; i < n / 2; i++) + FFT_Iter2_body(ioput, alpha2, i, m, PrD); + queues.wrap_up(job); } + else + for (int i = 0; i < n / 2; i++) + FFT_Iter2_body(ioput, alpha2, i, m, PrD); } } diff --git a/FHE/FFT.h b/FHE/FFT.h index b0935d65e..c41563b49 100644 --- a/FHE/FFT.h +++ b/FHE/FFT.h @@ -30,8 +30,29 @@ void FFT2(vector& a,int N,const modp& theta,const Zp_Data& PrD); template void FFT_Iter(vector& a,int N,const T& theta,const P& PrD); +void FFT_Iter(vector& a, int N, const modp& theta, const Zp_Data& PrD, + bool start_with_one = true); void FFT_Iter2(vector& a,int N,const modp& theta,const Zp_Data& PrD); +// variants with precomputed roots + +void FFT_Iter(vector& a, int N, const vector& theta, + const Zp_Data& PrD, bool start_with_one = true); +void FFT_Iter2(vector& a, int N, const vector& theta, + const Zp_Data& PrD); + +inline void FFT_Iter2_body(vector& ioput, const vector& alpha2, int i, + int m, const Zp_Data& PrD) +{ + int j = i % (m / 2); + int kk = i / (m / 2); + int k = j + kk * m; + modp t, u; + Mul(t, alpha2[j], ioput[k + m / 2], PrD); + u = ioput[k]; + Add(ioput[k], u, t, PrD); + Sub(ioput[k + m / 2], u, t, PrD); +} /* BFFT perform FFT and inverse FFT mod PrD for non power of two cyclotomics. * The modulus in PrD (contained in FFT_Data) must be set up diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp index ecf87ac92..c71a4c5da 100644 --- a/FHE/FFT_Data.cpp +++ b/FHE/FFT_Data.cpp @@ -6,24 +6,6 @@ #include "Math/modp.hpp" -void FFT_Data::assign(const FFT_Data& FFTD) -{ - prData=FFTD.prData; - R=FFTD.R; - - root=FFTD.root; - twop=FFTD.twop; - - two_root=FFTD.two_root; - powers=FFTD.powers; - powers_i=FFTD.powers_i; - b=FFTD.b; - - iphi=FFTD.iphi; - -} - - void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { @@ -49,6 +31,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) Inv(root[1],root[0],PrD); to_modp(iphi,Rg.phi_m(),PrD); Inv(iphi,iphi,PrD); + compute_roots(Rg.m()); } } else @@ -57,6 +40,7 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) { throw invalid_params(); } root[0]=Find_Primitive_Root_2m(Rg.m(),Rg.Phi(),PrD); Inv(root[1],root[0],PrD); + compute_roots(2 * Rg.m()); int ptwop=twop; if (twop<0) { ptwop=-twop; } @@ -97,6 +81,14 @@ void FFT_Data::init(const Ring& Rg,const Zp_Data& PrD) } } +void FFT_Data::compute_roots(int n) +{ + roots.resize(n + 1); + assignOne(roots[0], prData); + for (int i = 1; i < n + 1; i++) + Mul(roots[i], roots[i - 1], root[0], prData); +} + void FFT_Data::hash(octetStream& o) const { @@ -111,6 +103,7 @@ void FFT_Data::pack(octetStream& o) const R.pack(o); prData.pack(o); o.store(root); + o.store(roots); o.store(twop); o.store(two_root); o.store(b); @@ -125,6 +118,7 @@ void FFT_Data::unpack(octetStream& o) R.unpack(o); prData.unpack(o); o.get(root); + o.get(roots); o.get(twop); o.get(two_root); o.get(b); @@ -133,7 +127,6 @@ void FFT_Data::unpack(octetStream& o) o.get(powers_i); } - bool FFT_Data::operator!=(const FFT_Data& other) const { if (R != other.R or prData != other.prData or root != other.root diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h index fc339c81e..c5d6b2063 100644 --- a/FHE/FFT_Data.h +++ b/FHE/FFT_Data.h @@ -19,6 +19,7 @@ class FFT_Data Zp_Data prData; vector root; // 2m'th Root of Unity mod pr and it's inverse + vector roots; // precomputed powers of root // When twop is equal to zero, m is a power of two // When twop is positive it is equal to 2^e where 2^e>2*m and 2^e divides p-1 @@ -34,6 +35,8 @@ class FFT_Data modp iphi; // 1/phi_m mod pr vector< vector > powers,powers_i; + void compute_roots(int n); + public: typedef gfp T; typedef bigint S; @@ -47,17 +50,9 @@ class FFT_Data void pack(octetStream& o) const; void unpack(octetStream& o); - void assign(const FFT_Data& FFTD); - FFT_Data() { ; } - FFT_Data(const FFT_Data& FFTD) - { assign(FFTD); } FFT_Data(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } - FFT_Data& operator=(const FFT_Data& FFTD) - { if (this!=&FFTD) { assign(FFTD); } - return *this; - } const Zp_Data& get_prD() const { return prData; } const bigint& get_prime() const { return prData.pr; } @@ -72,6 +67,7 @@ class FFT_Data int get_twop() const { return twop; } modp get_root(int i) const { return root[i]; } modp get_iphi() const { return iphi; } + const vector& get_roots() const { return roots; } const Ring& get_R() const { return R; } diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp index 990a0d206..8a6580926 100644 --- a/FHE/FHE_Keys.cpp +++ b/FHE/FHE_Keys.cpp @@ -42,7 +42,7 @@ Rq_Element FHE_PK::sample_secret_key(PRNG& G) { Rq_Element sk = FHE_SK(*this).s(); // Generate the secret key - sk.from_vec((*params).sampleHwt(G)); + sk.from(GaussianGenerator(params->get_DG(), G)); return sk; } @@ -55,7 +55,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) // b0=a0*s+p*e0 Rq_Element e0((*PK.params).FFTD(),evaluation,evaluation); - e0.from_vec((*PK.params).sampleGaussian(G, noise_boost)); + e0.from(GaussianGenerator(params->get_DG(), G, noise_boost)); mul(PK.b0,PK.a0,sk); mul(e0,e0,PK.pr); add(PK.b0,PK.b0,e0); @@ -72,7 +72,7 @@ void FHE_PK::KeyGen(Rq_Element& sk, PRNG& G, int noise_boost) // bs=as*s+p*es Rq_Element es((*PK.params).FFTD(),evaluation,evaluation); - es.from_vec((*PK.params).sampleGaussian(G, noise_boost)); + es.from(GaussianGenerator(params->get_DG(), G, noise_boost)); mul(PK.Sw_b,PK.Sw_a,sk); mul(es,es,PK.pr); add(PK.Sw_b,PK.Sw_b,es); @@ -120,13 +120,14 @@ void FHE_PK::check_noise(const Rq_Element& x, bool check_modulo) const } -template<> +template void FHE_PK::encrypt(Ciphertext& c, - const Plaintext& mess,const Random_Coins& rc) const + const Plaintext& mess,const Random_Coins& rc) const { if (&c.get_params()!=params) { throw params_mismatch(); } if (&rc.get_params()!=params) { throw params_mismatch(); } - if (pr==2) { throw pr_mismatch(); } + if (T::characteristic_two ^ (pr == 2)) + throw pr_mismatch(); Rq_Element mm((*params).FFTD(),polynomial,polynomial); mm.from(mess.get_iterator()); @@ -134,35 +135,6 @@ void FHE_PK::encrypt(Ciphertext& c, quasi_encrypt(c,mm,rc); } - - -template<> -void FHE_PK::encrypt(Ciphertext& c, - const Plaintext& mess,const Random_Coins& rc) const -{ - if (&c.get_params()!=params) { throw params_mismatch(); } - if (&rc.get_params()!=params) { throw params_mismatch(); } - if (pr==2) { throw pr_mismatch(); } - - mess.to_poly(); - encrypt(c, mess.get_poly(), rc); -} - - - - -template<> -void FHE_PK::encrypt(Ciphertext& c, - const Plaintext& mess,const Random_Coins& rc) const -{ - if (&c.get_params()!=params) { throw params_mismatch(); } - if (&rc.get_params()!=params) { throw params_mismatch(); } - if (pr!=2) { throw pr_mismatch(); } - - mess.to_poly(); - encrypt(c, mess.get_poly(), rc); -} - void FHE_PK::quasi_encrypt(Ciphertext& c, const Rq_Element& mess,const Random_Coins& rc) const { @@ -212,42 +184,12 @@ Ciphertext FHE_PK::encrypt( } -template<> -void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const +template +void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const { if (&c.get_params()!=params) { throw params_mismatch(); } - if (pr==2) { throw pr_mismatch(); } - - Rq_Element ans; - - mul(ans,c.c1(),sk); - sub(ans,c.c0(),ans); - ans.change_rep(polynomial); - mess.set_poly_mod(ans.get_iterator(), ans.get_modulus()); -} - - - -template<> -void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const -{ - if (&c.get_params()!=params) { throw params_mismatch(); } - if (pr==2) { throw pr_mismatch(); } - - Rq_Element ans; - - mul(ans,c.c1(),sk); - sub(ans,c.c0(),ans); - mess.set_poly_mod(ans.to_vec_bigint(),ans.get_modulus()); -} - - - -template<> -void FHE_SK::decrypt(Plaintext& mess,const Ciphertext& c) const -{ - if (&c.get_params()!=params) { throw params_mismatch(); } - if (pr!=2) { throw pr_mismatch(); } + if (T::characteristic_two ^ (pr == 2)) + throw pr_mismatch(); Rq_Element ans; diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp index 7f5563390..8ae6c2885 100644 --- a/FHE/FHE_Params.cpp +++ b/FHE/FHE_Params.cpp @@ -3,14 +3,6 @@ #include "FHE/Ring_Element.h" #include "Tools/Exceptions.h" -void FHE_Params::set(const Ring& R, - const vector& primes,double r,int hwt) -{ - set(R, primes); - - Chi.set(R.phi_m(),hwt,r); -} - void FHE_Params::set(const Ring& R, const vector& primes) { @@ -20,7 +12,6 @@ void FHE_Params::set(const Ring& R, for (size_t i = 0; i < FFTData.size(); i++) FFTData[i].init(R,primes[i]); - Chi.set_n(R.phi_m()); set_sec(40); } diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index d918c9567..ac56668a2 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -21,7 +21,7 @@ class FHE_Params vector FFTData; // Random generator for Multivariate Gaussian Distribution etc - RandomVectors Chi; + mutable DiscreteGauss Chi; // Data for distributed decryption int sec_p; @@ -29,27 +29,17 @@ class FHE_Params public: - FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(-1, 0.7), sec_p(-1) {} + FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(0.7), sec_p(-1) {} int n_mults() const { return FFTData.size() - 1; } // Rely on default copy assignment/constructor (not that they should // ever be needed) - void set(const Ring& R,const vector& primes,double r,int hwt); void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); - vector sampleGaussian(PRNG& G, int noise_boost = 1) const - { return Chi.sample_Gauss(G, noise_boost); } - vector sampleHwt(PRNG& G) const - { return Chi.sample_Hwt(G); } - vector sampleHalf(PRNG& G) const - { return Chi.sample_Half(G); } - vector sampleUniform(PRNG& G,const bigint& Bd) const - { return Chi.sample_Uniform(G,Bd); } - const vector& FFTD() const { return FFTData; } const bigint& p0() const { return FFTData[0].get_prime(); } @@ -59,9 +49,8 @@ class FHE_Params int secp() const { return sec_p; } const bigint& B() const { return Bval; } double get_R() const { return Chi.get_R(); } - void set_R(double R) const { return Chi.get_DG().set(R); } - DiscreteGauss get_DG() const { return Chi.get_DG(); } - int get_h() const { return Chi.get_h(); } + void set_R(double R) const { return Chi.set(R); } + DiscreteGauss get_DG() const { return Chi; } int phi_m() const { return FFTData[0].phi_m(); } const Ring& get_ring() { return FFTData[0].get_R(); } diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index 3fa3cc86a..e9a29269f 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -52,10 +52,12 @@ int generate_semi_setup(int plaintext_length, int sec, bigint p; generate_prime(p, lgp, m); int lgp0, lgp1; + FHE_Params tmp_params; while (true) { + tmp_params = params; SemiHomomorphicNoiseBounds nb(p, phi_N(m), 1, sec, - numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, params); + numBits(NonInteractiveProof::slack(sec, phi_N(m))), true, tmp_params); bigint p1 = 2 * p * m, p0 = p; while (nb.min_p0(params.n_mults() > 0, p1) > p0) { @@ -75,6 +77,7 @@ int generate_semi_setup(int plaintext_length, int sec, } } + params = tmp_params; int extra_slack = common_semi_setup(params, m, p, lgp0, lgp1, round_up); FTD.init(params.get_ring(), p); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index f343d7f71..633e5b91f 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -13,29 +13,24 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p, const FHE_Params& params) : p(p), phi_m(phi_m), n(n), sec(sec), slack(numBits(Proof::slack(slack_param, sec, phi_m))), - sigma(params.get_R()), h(params.get_h()) + sigma(params.get_R()) { if (sigma <= 0) this->sigma = sigma = FHE_Params().get_R(); -#ifdef VERBOSE - cerr << "Standard deviation: " << this->sigma << endl; -#endif - if (h > 0) - h += extra_h * sec; - else if (extra_h) + if (extra_h) { sigma *= 1.4; params.set_R(params.get_R() * 1.4); } +#ifdef VERBOSE + cerr << "Standard deviation: " << this->sigma << endl; +#endif produce_epsilon_constants(); // according to documentation of SCALE-MAMBA 1.7 // excluding a factor of n because we don't always add up n ciphertexts - if (h > 0) - V_s = sqrt(h); - else - V_s = sigma * sqrt(phi_m); + V_s = sigma * sqrt(phi_m); B_clean = (bigint(phi_m) << (sec + 1)) * p * (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * V_s); // unify parameters by taking maximum over TopGear or not diff --git a/FHE/NoiseBounds.h b/FHE/NoiseBounds.h index 466190320..ccd50808a 100644 --- a/FHE/NoiseBounds.h +++ b/FHE/NoiseBounds.h @@ -22,7 +22,6 @@ class SemiHomomorphicNoiseBounds const int sec; int slack; mpf_class sigma; - int h; bigint B_clean; bigint B_scale; diff --git a/FHE/PPData.cpp b/FHE/PPData.cpp index b73277e19..282d830bb 100644 --- a/FHE/PPData.cpp +++ b/FHE/PPData.cpp @@ -5,14 +5,6 @@ -void PPData::assign(const PPData& PPD) -{ - R=PPD.R; - prData=PPD.prData; - root=PPD.root; -} - - void PPData::init(const Ring& Rg,const Zp_Data& PrD) { R=Rg; diff --git a/FHE/PPData.h b/FHE/PPData.h index fcb5a3fd1..46c8c8e6e 100644 --- a/FHE/PPData.h +++ b/FHE/PPData.h @@ -27,17 +27,9 @@ class PPData void init(const Ring& Rg,const Zp_Data& PrD); - void assign(const PPData& PPD); - PPData() { ; } - PPData(const PPData& PPD) - { assign(PPD); } PPData(const Ring& Rg,const Zp_Data& PrD) { init(Rg,PrD); } - PPData& operator=(const PPData& PPD) - { if (this!=&PPD) { assign(PPD); } - return *this; - } const Zp_Data& get_prD() const { return prData; } const bigint& get_prime() const { return prData.pr; } diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp index b9353df80..b22b5b949 100644 --- a/FHE/Plaintext.cpp +++ b/FHE/Plaintext.cpp @@ -5,6 +5,7 @@ #include "FHE/P2Data.h" #include "FHE/Rq_Element.h" #include "FHE_Keys.h" +#include "FHE/AddableVector.hpp" #include "Math/Z2k.hpp" #include "Math/modp.hpp" @@ -258,37 +259,9 @@ void Plaintext::randomize(PRNG& G,condition cond) } -template<> -void Plaintext_::randomize(PRNG& G, bigint B, bool Diag, bool binary, PT_Type t) -{ - if (Diag or binary) - throw not_implemented(); - if (B == 0) - throw runtime_error("cannot randomize modulo 0"); - - allocate(t); - switch (t) - { - case Polynomial: - rand_poly(b, G, B, false); - break; - case Evaluation: - for (int i = 0; i < n_slots; i++) - a[i] = G.randomBnd(B); - break; - default: - throw runtime_error("wrong type for randomization with bound"); - break; - } -} - - template -void Plaintext::randomize(PRNG& G, int n_bits, bool Diag, bool binary, PT_Type t) +void Plaintext::randomize(PRNG& G, int n_bits, bool Diag, PT_Type t) { - if (binary) - throw not_implemented(); - allocate(t); switch(t) { @@ -614,10 +587,11 @@ void Plaintext::negate() -template -Rq_Element Plaintext::mul_by_X_i(int i, const FHE_PK& pk) const +template +AddableVector Plaintext::mul_by_X_i(int i, + const FHE_PK& pk) const { - return Rq_Element(pk.get_params(), *this).mul_by_X_i(i); + return AddableVector(get_poly()).mul_by_X_i(i, pk); } diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h index 5781e1951..52ff8b6d4 100644 --- a/FHE/Plaintext.h +++ b/FHE/Plaintext.h @@ -25,6 +25,7 @@ using namespace std; class FHE_PK; class Rq_Element; +template class AddableVector; // Forward declaration as apparently this is needed for friends in templates template class Plaintext; @@ -64,13 +65,6 @@ class Plaintext const FD& get_field() const { return *Field_Data; } unsigned int num_slots() const { return n_slots; } - void assign(const Plaintext& p) - { Field_Data=p.Field_Data; - a=p.a; b=p.b; type=p.type; - n_slots = p.n_slots; - degree = p.degree; - } - Plaintext(const FD& FieldD, PT_Type type = Polynomial) { Field_Data=&FieldD; set_sizes(); allocate(type); } @@ -142,8 +136,7 @@ class Plaintext void to_poly() const; void randomize(PRNG& G,condition cond=Full); - void randomize(PRNG& G, bigint B, bool Diag=false, bool binary=false, PT_Type type=Polynomial); - void randomize(PRNG& G, int n_bits, bool Diag=false, bool binary=false, PT_Type type=Polynomial); + void randomize(PRNG& G, int n_bits, bool Diag=false, PT_Type type=Polynomial); void assign_zero(PT_Type t = Evaluation); void assign_one(PT_Type t = Evaluation); @@ -171,13 +164,12 @@ class Plaintext void negate(); - Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const; + AddableVector mul_by_X_i(int i, const FHE_PK& pk) const; bool equals(const Plaintext& x) const; bool operator!=(const Plaintext& x) { return !equals(x); } bool is_diagonal() const; - bool is_binary() const { throw not_implemented(); } /* Pack and unpack into an octetStream * For unpack we assume the FFTD has been assigned correctly already diff --git a/FHE/Random_Coins.h b/FHE/Random_Coins.h index 65f6c3cf2..ad0d9fdcb 100644 --- a/FHE/Random_Coins.h +++ b/FHE/Random_Coins.h @@ -52,8 +52,6 @@ class Random_Coins { params=&p; } Random_Coins(const FHE_PK& pk); - - ~Random_Coins() { ; } // Rely on default copy assignment/constructor diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 2acb57ef5..9c2545ed8 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -33,17 +33,36 @@ Ring_Element::Ring_Element(const FFT_Data& fftd,RepType r) } +void Ring_Element::prepare(const Ring_Element& other) +{ + assert(this != &other); + FFTD = other.FFTD; + rep = other.rep; + prepare_push(); +} + +void Ring_Element::prepare_push() +{ + element.clear(); + element.reserve(FFTD->phi_m()); +} + + +void Ring_Element::allocate() +{ + element.resize(FFTD->phi_m()); +} + + void Ring_Element::assign_zero() { - element.resize((*FFTD).phi_m()); - for (int i=0; i<(*FFTD).phi_m(); i++) - { assignZero(element[i],(*FFTD).get_prD()); } + element.clear(); } void Ring_Element::assign_one() { - element.resize((*FFTD).phi_m()); + allocate(); modp fill; if (rep==polynomial) { assignZero(fill,(*FFTD).get_prD()); } else { assignOne(fill,(*FFTD).get_prD()); } @@ -56,6 +75,9 @@ void Ring_Element::assign_one() void Ring_Element::negate() { + if (element.empty()) + return; + for (int i=0; i<(*FFTD).phi_m(); i++) { Negate(element[i],element[i],(*FFTD).get_prD()); } } @@ -66,20 +88,58 @@ void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } - ans.partial_assign(a); - for (int i=0; i<(*ans.FFTD).phi_m(); i++) - { Add(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); } -} + if (a.element.empty()) + { + ans = b; + return; + } + else if (b.element.empty()) + { + ans = a; + return; + } + if (&ans == &a) + { + ans += b; + return; + } + else if (&ans == &b) + { + ans += a; + return; + } + ans.prepare(a); + for (int i=0; i<(*ans.FFTD).phi_m(); i++) + ans.element.push_back(a.element[i].add(b.element[i], a.FFTD->get_prD())); +} void sub(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } - ans.partial_assign(a); + if (a.element.empty()) + { + ans = b; + ans.negate(); + return; + } + else if (b.element.empty()) + { + ans = a; + return; + } + + if (&ans == &a) + { + ans -= b; + return; + } + + ans.prepare(a); for (int i=0; i<(*ans.FFTD).phi_m(); i++) - { Sub(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); } + ans.element.push_back(a.element[i].sub(b.element[i], a.FFTD->get_prD())); } @@ -88,13 +148,29 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) { if (a.rep!=b.rep) { throw rep_mismatch(); } if (a.FFTD!=b.FFTD) { throw pr_mismatch(); } - ans.partial_assign(a); - if (ans.rep==evaluation) + if (a.element.empty() or b.element.empty()) + { + ans = Ring_Element(*a.FFTD, a.rep); + return; + } + + if (a.rep==evaluation) { // In evaluation representation, so we can just multiply componentwise + if (&ans == &a) + { + ans *= b; + return; + } + else if (&ans == &b) + { + ans *= a; + return; + } + ans.prepare(a); for (int i=0; i<(*ans.FFTD).phi_m(); i++) - { Mul(ans.element[i],a.element[i],b.element[i],(*a.FFTD).get_prD()); } + ans.element.push_back(a.element[i].mul(b.element[i], a.FFTD->get_prD())); } - else if ((*ans.FFTD).get_twop()!=0) + else if ((*a.FFTD).get_twop()!=0) { // This is the case where m is not a power of two // Here we have to do a poly mult followed by a reduction @@ -116,11 +192,13 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) // Now apply reduction, assumes Ring.poly is monic reduce(aa, 2*(*a.FFTD).phi_m(), (*a.FFTD).phi_m(), *a.FFTD); // Now stick into answer + ans.partial_assign(a); for (int i=0; i<(*ans.FFTD).phi_m(); i++) { ans.element[i]=aa[i]; } } - else if ((*ans.FFTD).get_twop()==0) + else if ((*a.FFTD).get_twop()==0) { // m a power of two case + ans.partial_assign(a); Ring_Element aa(*ans.FFTD,ans.rep); modp temp; for (int i=0; i<(*ans.FFTD).phi_m(); i++) @@ -143,31 +221,89 @@ void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b) void mul(Ring_Element& ans,const Ring_Element& a,const modp& b) { - ans.partial_assign(a); + if (&ans == &a) + { + ans *= b; + return; + } + + ans.prepare(a); + if (a.element.empty()) + return; + for (int i=0; i<(*ans.FFTD).phi_m(); i++) - { Mul(ans.element[i],a.element[i],b,(*a.FFTD).get_prD()); } + ans.element.push_back(a.element[i].mul(b, a.FFTD->get_prD())); +} + + +Ring_Element& Ring_Element::operator +=(const Ring_Element& other) +{ + assert(element.size() == other.element.size()); + assert(FFTD == other.FFTD); + assert(rep == other.rep); + for (size_t i = 0; i < element.size(); i++) + element[i] = element[i].add(other.element[i], FFTD->get_prD()); + return *this; +} + + +Ring_Element& Ring_Element::operator -=(const Ring_Element& other) +{ + assert(element.size() == other.element.size()); + assert(FFTD == other.FFTD); + assert(rep == other.rep); + for (size_t i = 0; i < element.size(); i++) + element[i] = element[i].sub(other.element[i], FFTD->get_prD()); + return *this; +} + + +Ring_Element& Ring_Element::operator *=(const Ring_Element& other) +{ + assert(element.size() == other.element.size()); + assert(FFTD == other.FFTD); + assert(rep == other.rep); + assert(rep == evaluation); + for (size_t i = 0; i < element.size(); i++) + element[i] = element[i].mul(other.element[i], FFTD->get_prD()); + return *this; +} + + +Ring_Element& Ring_Element::operator *=(const modp& other) +{ + for (size_t i = 0; i < element.size(); i++) + element[i] = element[i].mul(other, FFTD->get_prD()); + return *this; } Ring_Element Ring_Element::mul_by_X_i(int j) const { Ring_Element ans; + ans.prepare(*this); + if (element.empty()) + return ans; + auto& a = *this; - ans.partial_assign(a); if (ans.rep == evaluation) { modp xj, xj2; Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD()); Sqr(xj2, xj, (*a.FFTD).get_prD()); + ans.prepare_push(); + modp tmp; for (int i= 0; i < (*ans.FFTD).phi_m(); i++) { - Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD()); + Mul(tmp, a.element[i], xj, (*a.FFTD).get_prD()); + ans.element.push_back(tmp); Mul(xj, xj, xj2, (*a.FFTD).get_prD()); } } else { Ring_Element aa(*ans.FFTD, ans.rep); + aa.allocate(); for (int i= 0; i < (*ans.FFTD).phi_m(); i++) { int k= j + i, s= 1; @@ -193,6 +329,7 @@ Ring_Element Ring_Element::mul_by_X_i(int j) const void Ring_Element::randomize(PRNG& G,bool Diag) { + allocate(); if (Diag==false) { for (int i=0; i<(*FFTD).phi_m(); i++) { element[i].randomize(G,(*FFTD).get_prD()); } @@ -213,12 +350,18 @@ void Ring_Element::randomize(PRNG& G,bool Diag) void Ring_Element::change_rep(RepType r) { + if (element.empty()) + { + rep = r; + return; + } + if (rep==r) { return; } if (r==evaluation) { rep=evaluation; if ((*FFTD).get_twop()==0) { // m a power of two variant - FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_root(0),(*FFTD).get_prD()); + FFT_Iter2(element,(*FFTD).phi_m(),(*FFTD).get_roots(),(*FFTD).get_prD()); } else { // Non m power of two variant and FFT enabled @@ -258,6 +401,11 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { + if (element.empty() and a.element.empty()) + return true; + else if (element.empty() or a.element.empty()) + throw not_implemented(); + if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } for (int i=0; i<(*FFTD).phi_m(); i++) @@ -266,34 +414,11 @@ bool Ring_Element::equals(const Ring_Element& a) const } -void Ring_Element::from_vec(const vector& v) -{ - RepType t=rep; - rep=polynomial; - bigint tmp; - for (int i=0; i<(*FFTD).phi_m(); i++) - { - tmp = v[i]; - element[i].convert_destroy(tmp, FFTD->get_prD()); - } - change_rep(t); -// cout << "RE:from_vec:: " << *this << endl; -} - -void Ring_Element::from_vec(const vector& v) -{ - RepType t=rep; - rep=polynomial; - for (int i=0; i<(*FFTD).phi_m(); i++) - { to_modp(element[i],v[i],(*FFTD).get_prD()); } - change_rep(t); -// cout << "RE:from_vec:: " << *this << endl; -} - ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) throw runtime_error("simple iterator only available in polynomial represention"); + assert(not element.empty()); return {element, (*FFTD).get_prD()}; } @@ -318,6 +443,9 @@ vector Ring_Element::to_vec_bigint() const void Ring_Element::to_vec_bigint(vector& v) const { v.resize(FFTD->phi_m()); + if (element.empty()) + return; + if (rep==polynomial) { for (int i=0; i<(*FFTD).phi_m(); i++) { to_bigint(v[i],element[i],(*FFTD).get_prD()); } @@ -336,11 +464,10 @@ void Ring_Element::to_vec_bigint(vector& v) const modp Ring_Element::get_constant() const { - if (rep==polynomial) - { return element[0]; } - Ring_Element a=*this; - a.change_rep(polynomial); - return a.element[0]; + if (element.empty()) + return {}; + else + return element[0]; } @@ -364,9 +491,14 @@ void get(octetStream& o,vector& v,const Zp_Data& ZpD) + to_string(ZpD.pr_bit_length)); unsigned int length; o.get(length); - v.resize(length); + v.clear(); + v.reserve(length); + modp tmp; for (unsigned int i=0; iphi_m()) + if (not element.empty() and (int)element.size() != FFTD->phi_m()) throw runtime_error("invalid element size"); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index ba147062e..f221d0d2b 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -41,12 +41,6 @@ class Ring_Element vector element; - // Define a copy - void assign(const Ring_Element& e) - { rep=e.rep; FFTD=e.FFTD; - element=e.element; - } - public: // Used to basically make sure *this is able to cope @@ -57,6 +51,10 @@ class Ring_Element element.resize((*FFTD).phi_m()); } + void prepare(const Ring_Element& e); + void prepare_push(); + void allocate(); + void set_data(const FFT_Data& prd) { FFTD=&prd; } const FFT_Data& get_FFTD() const { return *FFTD; } const Zp_Data& get_prD() const { return (*FFTD).get_prD(); } @@ -80,19 +78,6 @@ class Ring_Element element.push_back(x); } - // Copy Constructor - Ring_Element(const Ring_Element& e) - { assign(e); } - - // Destructor - ~Ring_Element() { ; } - - // Copy Assignment - Ring_Element& operator=(const Ring_Element& e) - { if (this!=&e) { assign(e); } - return *this; - } - /* Functional Operators */ void negate(); friend void add(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b); @@ -102,6 +87,11 @@ class Ring_Element Ring_Element mul_by_X_i(int i) const; + Ring_Element& operator+=(const Ring_Element& other); + Ring_Element& operator-=(const Ring_Element& other); + Ring_Element& operator*=(const Ring_Element& other); + Ring_Element& operator*=(const modp& other); + void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; @@ -112,8 +102,6 @@ class Ring_Element // Converting to and from a vector of bigint/int's // I/O is assumed to be in poly rep, so from_vec it internally alters // the representation to the current representation - void from_vec(const vector& v); - void from_vec(const vector& v); vector to_vec_bigint() const; void to_vec_bigint(vector& v) const; @@ -136,8 +124,18 @@ class Ring_Element // This gets the constant term of the poly rep as a modp element modp get_constant() const; - modp get_element(int i) const { return element[i]; } - void set_element(int i,const modp& a) { element[i]=a; } + modp get_element(int i) const + { + if (element.empty()) + return {}; + else + return element[i]; + } + void set_element(int i,const modp& a) + { + allocate(); + element[i] = a; + } /* Pack and unpack into an octetStream * For unpack we assume the FFTD has been assigned correctly already @@ -164,7 +162,11 @@ class RingWriteIterator : public WriteConversionIterator public: RingWriteIterator(Ring_Element& element) : WriteConversionIterator(element.element, element.FFTD->get_prD()), - element(element), rep(element.rep) { element.rep = polynomial; } + element(element), rep(element.rep) + { + element.rep = polynomial; + element.allocate(); + } ~RingWriteIterator() { element.change_rep(rep); } }; @@ -175,7 +177,11 @@ class RingReadIterator : public ConversionIterator public: RingReadIterator(const Ring_Element& element) : ConversionIterator(this->element.element, element.FFTD->get_prD()), - element(element) { this->element.change_rep(polynomial); } + element(element) + { + this->element.change_rep(polynomial); + this->element.allocate(); + } }; @@ -189,10 +195,13 @@ void Ring_Element::from(const Generator& generator) RepType t=rep; rep=polynomial; T tmp; + modp tmp2; + prepare_push(); for (int i=0; i<(*FFTD).phi_m(); i++) { generator.get(tmp); - element[i].convert_destroy(tmp, (*FFTD).get_prD()); + tmp2.convert_destroy(tmp, (*FFTD).get_prD()); + element.push_back(tmp2); } change_rep(t); } diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp index fb192087e..af7a664b5 100644 --- a/FHE/Rq_Element.cpp +++ b/FHE/Rq_Element.cpp @@ -48,15 +48,6 @@ void Rq_Element::partial_assign(const Rq_Element& other) { lev=other.lev; a.resize(other.a.size()); - for (size_t i = 0; i < a.size(); i++) - a[i].partial_assign(other.a[i]); -} - -void Rq_Element::assign(const Rq_Element& other) -{ - partial_assign(other); - for (int i=0; i<=lev; ++i) - a[i] = other.a[i]; } void Rq_Element::negate() @@ -134,20 +125,6 @@ bool Rq_Element::equals(const Rq_Element& other) const } -void Rq_Element::from_vec(const vector& v,int level) -{ - set_level(level); - for (int i=0;i<=lev;++i) - a[i].from_vec(v); -} - -void Rq_Element::from_vec(const vector& v,int level) -{ - set_level(level); - for (int i=0;i<=lev;++i) - a[i].from_vec(v); -} - vector Rq_Element::to_vec_bigint() const { vector v; diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h index db3d4649d..d5e718419 100644 --- a/FHE/Rq_Element.h +++ b/FHE/Rq_Element.h @@ -44,7 +44,6 @@ class Rq_Element void assign_zero(const vector& prd); void assign_zero(); void assign_one(); - void assign(const Rq_Element& e); void partial_assign(const Rq_Element& e); // Must be careful not to call by mistake @@ -85,10 +84,6 @@ class Rq_Element a[1] = Ring_Element(prd[1], r, b1); } - // Destructor - ~Rq_Element() - { ; } - const Ring_Element& get(int i) const { return a[i]; } /* Functional Operators */ @@ -131,8 +126,6 @@ class Rq_Element void partial_assign(const Rq_Element& a, const Rq_Element& b); // Converting to and from a vector of bigint's Again I/O is in poly rep - void from_vec(const vector& v,int level=-1); - void from_vec(const vector& v,int level=-1); vector to_vec_bigint() const; void to_vec_bigint(vector& v) const; diff --git a/FHEOffline/DataSetup.hpp b/FHEOffline/DataSetup.hpp index c1124a8e0..f22a15fc2 100644 --- a/FHEOffline/DataSetup.hpp +++ b/FHEOffline/DataSetup.hpp @@ -49,7 +49,8 @@ void read_or_generate_secrets(T& setup, Player& P, U& machine, if (not error.empty()) { - cerr << "Running secrets generation because " << error << endl; + cerr << "Running secrets generation because no suitable material " + "from a previous run was found (" << error << ")" << endl; setup.key_and_mac_generation(P, machine, num_runs, V()); ofstream output(filename); diff --git a/FHEOffline/DistKeyGen.cpp b/FHEOffline/DistKeyGen.cpp index 255ec4519..482a87e59 100644 --- a/FHEOffline/DistKeyGen.cpp +++ b/FHEOffline/DistKeyGen.cpp @@ -109,11 +109,11 @@ DistKeyGen::DistKeyGen(const FHE_Params& params, const bigint& p) : */ void DistKeyGen::Gen_Random_Data(PRNG& G) { - secret.from_vec(params.sampleHwt(G)); + secret.from(GaussianGenerator(params.get_DG(), G)); rc1.generate(G); rc2.generate(G); a.randomize(G); - e.from_vec(params.sampleGaussian(G)); + e.from(GaussianGenerator(params.get_DG(), G)); } DistKeyGen& DistKeyGen::operator+=(const DistKeyGen& other) diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp index 89017c754..8d0c49d92 100644 --- a/FHEOffline/Multiplier.cpp +++ b/FHEOffline/Multiplier.cpp @@ -45,7 +45,7 @@ template void Multiplier::multiply_and_add(Plaintext_& res, const Ciphertext& enc_a, const Rq_Element& b, OT_ROLE role) { - octetStream o; + o.reset_write_head(); if (role & SENDER) { diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h index 17159a7f7..1a147b918 100644 --- a/FHEOffline/Multiplier.h +++ b/FHEOffline/Multiplier.h @@ -36,6 +36,8 @@ class Multiplier size_t volatile_capacity; MemoryUsage memory_usage; + octetStream o; + public: Multiplier(int offset, PairwiseGenerator& generator); Multiplier(int offset, PairwiseMachine& machine, Player& P, diff --git a/FHEOffline/Player-Offline.h b/FHEOffline/Player-Offline.h deleted file mode 100644 index 10ae1c823..000000000 --- a/FHEOffline/Player-Offline.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Player-Offline.h - * - */ - -#ifndef FHEOFFLINE_PLAYER_OFFLINE_H_ -#define FHEOFFLINE_PLAYER_OFFLINE_H_ - -class thread_info -{ - public: - - int thread_num; - int covert; - Names* Nms; - FHE_PK* pk_p; - FHE_PK* pk_2; - FHE_SK* sk_p; - FHE_SK* sk_2; - Ciphertext *calphap; - Ciphertext *calpha2; - gfp *alphapi; - gf2n_short *alpha2i; - - FFT_Data *FTD; - P2Data *P2D; - - int nm2,nmp,nb2,nbp,ni2,nip,ns2,nsp,nvp; - bool skip_2() { return nm2 + ni2 + nb2 + ns2 == 0; } -}; - -#endif /* FHEOFFLINE_PLAYER_OFFLINE_H_ */ diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp index d121372fb..5714b7224 100644 --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -589,7 +589,7 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, P.receive_player(j, cleartexts); C.resize(personal_EC.machine->sec, pk.get_params()); Verifier(personal_EC.proof, FieldD).NIZKPoK(C, ciphertexts, - cleartexts, pk, false); + cleartexts, pk); } inputs[j].clear(); diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp index e5bc641db..c65927207 100644 --- a/FHEOffline/Proof.cpp +++ b/FHEOffline/Proof.cpp @@ -88,6 +88,7 @@ class AbsoluteBoundChecker bool Proof::check_bounds(T& z, X& t, int i) const { + (void)i; unsigned int j,k; // Check Bound 1 and Bound 2 @@ -99,9 +100,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const auto& te = z[j]; if (plain_checker.outside(te, dist)) { +#ifdef VERBOSE cout << "Fail on Check 1 " << i << " " << j << endl; cout << te << " " << plain_check << endl; cout << tau << " " << sec << " " << n_proofs << endl; +#endif return false; } } @@ -113,9 +116,11 @@ bool Proof::check_bounds(T& z, X& t, int i) const auto& te = coeffs.at(j); if (rand_checker.outside(te, dist)) { +#ifdef VERBOSE cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl; cout << te << " " << rand_check << endl; cout << rho << " " << sec << " " << n_proofs << endl; +#endif return false; } } diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp index 96204976f..d92f30806 100644 --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -6,6 +6,7 @@ #include "Tools/random.h" #include "Math/Z2k.hpp" #include "Math/modp.hpp" +#include "FHE/AddableVector.hpp" template @@ -28,7 +29,7 @@ Prover::Prover(Proof& proof, const FD& FieldD) : template void Prover::Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector& c, - const FHE_PK& pk, bool binary) + const FHE_PK& pk) { size_t allocate = 3 * c.size() * c[0].report_size(USED); ciphertexts.resize_precise(allocate); @@ -51,7 +52,7 @@ void Prover::Stage_1(const Proof& P, octetStream& ciphertexts, // AE.randomize(Diag,binary); // rd=RandPoly(phim,bd<<1); // y[i]=AE.plaintext()+pr*rd; - y[i].randomize(G, P.B_plain_length, P.get_diagonal(), binary); + y[i].randomize(G, P.B_plain_length, P.get_diagonal()); if (P.get_diagonal()) assert(y[i].is_diagonal()); s[i].resize(3, P.phim); @@ -114,8 +115,7 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl const FHE_PK& pk, const AddableVector& c, const vector& x, - const Proof::Randomness& r, - bool binary) + const Proof::Randomness& r) { // AElement AE; // for (i=0; i::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl int cnt=0; while (!ok) { cnt++; - Stage_1(P,ciphertexts,c,pk,binary); + Stage_1(P,ciphertexts,c,pk); P.set_challenge(ciphertexts); // Check check whether we are OK, or whether we should abort ok = Stage_2(P,cleartexts,x,r,pk); } +#ifdef VERBOSE if (cnt > 1) cout << "\t\tNumber iterations of prover = " << cnt << endl; +#endif return report_size(CAPACITY) + volatile_memory; } diff --git a/FHEOffline/Prover.h b/FHEOffline/Prover.h index 91bf05cf1..d0dd2f887 100644 --- a/FHEOffline/Prover.h +++ b/FHEOffline/Prover.h @@ -24,8 +24,7 @@ class Prover Prover(Proof& proof, const FD& FieldD); void Stage_1(const Proof& P, octetStream& ciphertexts, const AddableVector& c, - const FHE_PK& pk, - bool binary = false); + const FHE_PK& pk); bool Stage_2(Proof& P, octetStream& cleartexts, const vector& x, @@ -40,8 +39,7 @@ class Prover const FHE_PK& pk, const AddableVector& c, const vector& x, - const Proof::Randomness& r, - bool binary=false); + const Proof::Randomness& r); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp index 4685e6169..912920679 100644 --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -11,6 +11,7 @@ #include "Protocols/MAC_Check.h" #include "Protocols/MAC_Check.hpp" +#include "Math/modp.hpp" template SimpleEncCommitBase::SimpleEncCommitBase(const MachineBase& machine) : @@ -63,7 +64,10 @@ SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk, template SimpleEncCommitFactory::~SimpleEncCommitFactory() { - cout << "EncCommit called " << n_calls << " times" << endl; +#ifdef VERBOSE_HE + if (n_calls > 0) + cout << "EncCommit called " << n_calls << " times" << endl; +#endif } template @@ -131,7 +135,7 @@ size_t NonInteractiveProofSimpleEncCommit::generate_proof(AddableVector > prover(proof, FTD); #endif size_t prover_memory = prover.NIZKPoK(proof, ciphertexts, cleartexts, - pk, c, m, r, false); + pk, c, m, r); timers["Proving"].stop(); if (proof.top_gear) @@ -192,7 +196,7 @@ size_t NonInteractiveProofSimpleEncCommit::create_more(octetStream& cipherte #endif timers["Verifying"].start(); verifier.NIZKPoK(others_ciphertexts, ciphertexts, - cleartexts, get_pk_for_verification(i), false); + cleartexts, get_pk_for_verification(i)); timers["Verifying"].stop(); add_ciphertexts(others_ciphertexts, i); this->memory_usage.update("verifier", verifier.report_size(CAPACITY)); @@ -251,7 +255,7 @@ void SummingEncCommit::create_more() #endif this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof); this->timers["Stage 1 of proof"].start(); - prover.Stage_1(proof, ciphertexts, this->c, this->pk, false); + prover.Stage_1(proof, ciphertexts, this->c, this->pk); this->timers["Stage 1 of proof"].stop(); this->c.unpack(ciphertexts, this->pk); @@ -291,8 +295,10 @@ void SummingEncCommit::create_more() for (int i = 1; i < P.num_players(); i++) { +#ifdef VERBOSE_HE cout << "Sending cleartexts with " << 1e-9 * cleartexts.get_length() << " GB in round " << i << endl; +#endif TimeScope(this->timers["Exchanging cleartexts"]); P.pass_around(cleartexts); preimages.add(cleartexts); @@ -312,7 +318,7 @@ void SummingEncCommit::create_more() Verifier verifier(proof); #endif verifier.Stage_2(this->c, ciphertexts, cleartexts, - this->pk, false); + this->pk); this->timers["Verifying"].stop(); this->cnt = proof.U - 1; diff --git a/FHEOffline/Verifier.cpp b/FHEOffline/Verifier.cpp index 9c26e94c0..f40df8925 100644 --- a/FHEOffline/Verifier.cpp +++ b/FHEOffline/Verifier.cpp @@ -25,7 +25,10 @@ bool Check_Decoding(const Plaintext& AE,bool Diag) // return false; // } if (Diag && !AE.is_diagonal()) - { cout << "Fail Check 5 " << endl; + { +#ifdef VERBOSE + cout << "Fail Check 5 " << endl; +#endif return false; } return true; @@ -62,7 +65,7 @@ template void Verifier::Stage_2( AddableVector& c,octetStream& ciphertexts, octetStream& cleartexts, - const FHE_PK& pk,bool binary) + const FHE_PK& pk) { unsigned int i, V; @@ -90,18 +93,19 @@ void Verifier::Stage_2( rc.assign(t[0], t[1], t[2]); pk.encrypt(d2,z,rc); if (!(d1 == d2)) - { cout << "Fail Check 6 " << i << endl; + { +#ifdef VERBOSE + cout << "Fail Check 6 " << i << endl; +#endif throw runtime_error("ciphertexts don't match"); } if (!Check_Decoding(z,P.get_diagonal(),FieldD)) - { cout << "\tCheck : " << i << endl; + { +#ifdef VERBOSE + cout << "\tCheck : " << i << endl; +#endif throw runtime_error("cleartext isn't diagonal"); } - if (binary && !z.is_binary()) - { - cout << "Not binary " << i << endl; - throw runtime_error("cleartext isn't binary"); - } } } @@ -112,17 +116,15 @@ void Verifier::Stage_2( template void Verifier::NIZKPoK(AddableVector& c, octetStream& ciphertexts, octetStream& cleartexts, - const FHE_PK& pk, - bool binary) + const FHE_PK& pk) { P.set_challenge(ciphertexts); - Stage_2(c,ciphertexts,cleartexts,pk,binary); + Stage_2(c,ciphertexts,cleartexts,pk); if (P.top_gear) { assert(not P.get_diagonal()); - assert(not binary); c += c; } } diff --git a/FHEOffline/Verifier.h b/FHEOffline/Verifier.h index dd9614488..68a6605eb 100644 --- a/FHEOffline/Verifier.h +++ b/FHEOffline/Verifier.h @@ -21,14 +21,14 @@ class Verifier void Stage_2( AddableVector& c, octetStream& ciphertexts, - octetStream& cleartexts,const FHE_PK& pk,bool binary=false); + octetStream& cleartexts,const FHE_PK& pk); /* This is the non-interactive version using the ROM - Creates space for all output values - Diag flag mirrors that in Prover */ void NIZKPoK(AddableVector& c,octetStream& ciphertexts,octetStream& cleartexts, - const FHE_PK& pk,bool binary=false); + const FHE_PK& pk); size_t report_size(ReportType type) { return z.report_size(type) + t.report_size(type); } }; diff --git a/GC/CcdPrep.h b/GC/CcdPrep.h index e38ad5f08..f3da07ca1 100644 --- a/GC/CcdPrep.h +++ b/GC/CcdPrep.h @@ -19,13 +19,13 @@ template class CcdPrep : public BufferPrep { typename T::part_type::LivePrep part_prep; - typename T::part_type::MAC_Check part_MC; SubProcessor* part_proc; ShareThread& thread; public: CcdPrep(DataPositions& usage, ShareThread& thread) : - BufferPrep(usage), part_prep(usage), part_proc(0), thread(thread) + BufferPrep(usage), part_prep(usage, thread), part_proc(0), + thread(thread) { } @@ -34,17 +34,9 @@ class CcdPrep : public BufferPrep { } - ~CcdPrep() - { - if (part_proc) - delete part_proc; - } + ~CcdPrep(); - void set_protocol(typename T::Protocol& protocol) - { - part_proc = new SubProcessor(part_MC, - part_prep, protocol.get_part().P); - } + void set_protocol(typename T::Protocol& protocol); Preprocessing& get_part() { @@ -53,7 +45,16 @@ class CcdPrep : public BufferPrep void buffer_triples() { - throw not_implemented(); + assert(part_proc); + this->triples.push_back({}); + for (auto& x : this->triples.back()) + x.resize_regs(T::default_length); + for (int i = 0; i < T::default_length; i++) + { + auto triple = part_prep.get_triple(1); + for (int j = 0; j < 3; j++) + this->triples.back()[j].get_bit(j) = triple[j]; + } } void buffer_bits() @@ -72,6 +73,25 @@ class CcdPrep : public BufferPrep { throw not_implemented(); } + + void buffer_inputs(int player) + { + this->inputs[player].push_back({}); + this->inputs[player].back().share.resize_regs(T::default_length); + for (int i = 0; i < T::default_length; i++) + { + typename T::part_type::open_type tmp; + part_prep.get_input(this->inputs[player].back().share.get_reg(i), + tmp, player); + this->inputs[player].back().value ^= + (typename T::clear(tmp.get_bit(0)) << i); + } + } + + size_t data_sent() + { + return part_prep.data_sent(); + } }; } /* namespace GC */ diff --git a/GC/CcdPrep.hpp b/GC/CcdPrep.hpp new file mode 100644 index 000000000..62f0f097e --- /dev/null +++ b/GC/CcdPrep.hpp @@ -0,0 +1,33 @@ +/* + * CcdPrep.hpp + * + */ + +#ifndef GC_CCDPREP_HPP_ +#define GC_CCDPREP_HPP_ + +#include "CcdPrep.h" + +#include "Processor/Processor.hpp" + +namespace GC +{ + +template +CcdPrep::~CcdPrep() +{ + if (part_proc) + delete part_proc; +} + +template +void CcdPrep::set_protocol(typename T::Protocol& protocol) +{ + assert(thread.MC); + part_proc = new SubProcessor( + thread.MC->get_part_MC(), part_prep, protocol.get_part().P); +} + +} + +#endif /* GC_CCDPREP_HPP_ */ diff --git a/GC/CcdShare.h b/GC/CcdShare.h index fcd16bff0..cd6c3618f 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -34,11 +34,6 @@ class CcdShare : public ShamirShare, public ShareSecret> static const int default_length = 1; - static DataFieldType field_type() - { - return DATA_GF2; - } - static string name() { return "CCD"; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 9cdf97e8c..a23c303b4 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -70,8 +70,6 @@ class FakeSecret : public ShareInterface, public BitVec static const true_type invertible; static const true_type characteristic_two; - static DataFieldType field_type() { return DATA_GF2; } - static MC* new_mc(mac_key_type key) { return new MC(key); } static void store_clear_in_dynamic(Memory& mem, diff --git a/GC/MaliciousCcdShare.h b/GC/MaliciousCcdShare.h index 1c5841824..6f9410961 100644 --- a/GC/MaliciousCcdShare.h +++ b/GC/MaliciousCcdShare.h @@ -39,11 +39,6 @@ class MaliciousCcdShare: public MaliciousShamirShare, public ShareSecret< static const int default_length = 1; - static DataFieldType field_type() - { - return DATA_GF2; - } - static string name() { return "Malicious CCD"; diff --git a/GC/NoShare.h b/GC/NoShare.h index fed195e8b..9cea3fa0b 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -54,6 +54,11 @@ class NoValue : public ValueInterface return "no"; } + static DataFieldType field_type() + { + throw not_implemented(); + } + static void fail() { throw runtime_error("VM does not support binary circuits"); @@ -101,16 +106,10 @@ class NoShare : public ShareInterface typedef NoValue clear; typedef NoValue mac_key_type; - typedef NoShare bit_type; - typedef NoShare part_type; typedef NoShare small_type; typedef BlackHole out_type; - static const int default_length = 1; - - static const bool needs_ot = false; - static const bool expensive_triples = false; static const bool is_real = false; static MC* new_mc(mac_key_type) @@ -118,21 +117,6 @@ class NoShare : public ShareInterface return new MC; } - template - static void generate_mac_key(mac_key_type, T) - { - } - - static DataFieldType field_type() - { - throw not_implemented(); - } - - static string type_short() - { - return ""; - } - static string type_string() { return "no"; @@ -155,7 +139,6 @@ class NoShare : public ShareInterface static void ands(Processor&, const vector&) { fail(); } static void andrs(Processor&, const vector&) { fail(); } - static void input(Processor&, InputArgs&) { fail(); } static void trans(Processor&, Integer, const vector&) { fail(); } static NoShare constant(const GC::Clear&, int, mac_key_type) { fail(); return {}; } @@ -166,11 +149,8 @@ class NoShare : public ShareInterface void load_clear(Integer, Integer) { fail(); } void random_bit() { fail(); } - void and_(int, NoShare&, NoShare&, bool) { fail(); } - void xor_(int, NoShare&, NoShare&) { fail(); } void bitdec(vector&, const vector&) const { fail(); } void bitcom(vector&, const vector&) const { fail(); } - void reveal(Integer, Integer) { fail(); } void assign(const char*) { fail(); } @@ -183,13 +163,11 @@ class NoShare : public ShareInterface NoShare operator-(const NoShare&) const { fail(); return {}; } NoShare operator*(const NoValue&) const { fail(); return {}; } - NoShare operator+(int) const { fail(); return {}; } NoShare operator&(int) const { fail(); return {}; } NoShare operator>>(int) const { fail(); return {}; } NoShare& operator+=(const NoShare&) { fail(); return *this; } - NoShare lsb() const { fail(); return {}; } NoShare get_bit(int) const { fail(); return {}; } void invert(int, NoShare) { fail(); } diff --git a/GC/PersonalPrep.hpp b/GC/PersonalPrep.hpp index dfeb77bc4..df1725854 100644 --- a/GC/PersonalPrep.hpp +++ b/GC/PersonalPrep.hpp @@ -88,7 +88,7 @@ void PersonalPrep::buffer_personal_triples(vector>& triples, input.reset_all(P); for (size_t i = begin; i < end; i++) { - typename T::clear x[2]; + typename T::open_type x[2]; for (int j = 0; j < 2; j++) this->get_input(triples[i][j], x[j], input_player); if (P.my_num() == input_player) diff --git a/GC/Processor.h b/GC/Processor.h index 25fcca90b..d0a84a6a2 100644 --- a/GC/Processor.h +++ b/GC/Processor.h @@ -84,6 +84,7 @@ class Processor : public ::ProcessorBase, public GC::RuntimeBranching void xors(const vector& args); void xors(const vector& args, size_t start, size_t end); + void xorc(const ::BaseInstruction& instruction); void nots(const ::BaseInstruction& instruction); void andm(const ::BaseInstruction& instruction); void and_(const vector& args, bool repeat); diff --git a/GC/Processor.hpp b/GC/Processor.hpp index 8cf85e0cd..d2151fe2b 100644 --- a/GC/Processor.hpp +++ b/GC/Processor.hpp @@ -18,6 +18,7 @@ using namespace std; #include "GC/Machine.hpp" #include "Processor/ProcessorBase.hpp" +#include "Processor/IntInput.hpp" #include "Math/bigint.hpp" namespace GC @@ -82,8 +83,12 @@ U GC::Processor::get_long_input(const int* params, { if (not T::actual_inputs) return {}; - U res = input_proc.get_input>(interactive, - ¶ms[1]).items[0]; + U res; + if (params[1] == 0) + res = input_proc.get_input>(interactive, 0).items[0]; + else + res = input_proc.get_input>(interactive, + ¶ms[1]).items[0]; int n_bits = *params; check_input(res, n_bits); return res; @@ -229,6 +234,18 @@ void Processor::xors(const vector& args, size_t start, size_t end) } } +template +void Processor::xorc(const ::BaseInstruction& instruction) +{ + int total = instruction.get_n(); + for (int i = 0; i < DIV_CEIL(total, T::default_length); i++) + { + int n = min(T::default_length, total - i * T::default_length); + C[instruction.get_r(0) + i] = BitVec(C[instruction.get_r(1) + i]).mask(n) + ^ BitVec(C[instruction.get_r(2) + i]).mask(n); + } +} + template void Processor::nots(const ::BaseInstruction& instruction) { diff --git a/GC/Rep4Secret.h b/GC/Rep4Secret.h index f17ae1e37..a708cc315 100644 --- a/GC/Rep4Secret.h +++ b/GC/Rep4Secret.h @@ -7,7 +7,6 @@ #define GC_REP4SECRET_H_ #include "ShareSecret.h" -#include "Processor/NoLivePrep.h" #include "Protocols/Rep4MC.h" #include "Protocols/Rep4Share.h" diff --git a/GC/SemiHonestRepPrep.cpp b/GC/SemiHonestRepPrep.cpp deleted file mode 100644 index efb21d55a..000000000 --- a/GC/SemiHonestRepPrep.cpp +++ /dev/null @@ -1,11 +0,0 @@ -/* - * ReplicatedPrep.cpp - * - */ - -#include - -namespace GC -{ - -} /* namespace GC */ diff --git a/GC/ShareParty.hpp b/GC/ShareParty.hpp index c38bac03c..2250ff461 100644 --- a/GC/ShareParty.hpp +++ b/GC/ShareParty.hpp @@ -119,7 +119,9 @@ ShareParty::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt, try { - read_mac_key(get_prep_sub_dir(PREP_DIR, network_opts.nplayers), this->N, + read_mac_key( + get_prep_sub_dir(PREP_DIR, network_opts.nplayers), + this->N, this->mac_key); } catch (exception& e) diff --git a/GC/ShareSecret.h b/GC/ShareSecret.h index 3a76dd233..a2bb5958c 100644 --- a/GC/ShareSecret.h +++ b/GC/ShareSecret.h @@ -38,6 +38,8 @@ template class ShareSecret { public: + typedef U whole_type; + typedef Memory DynamicMemory; typedef SwitchableOutput out_type; diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 3066e2e02..c8b5a327e 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -21,6 +21,7 @@ #include "ShareParty.h" #include "ShareThread.hpp" #include "Thread.hpp" +#include "VectorProtocol.hpp" namespace GC { diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 91087647c..0280fb50a 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -29,7 +29,8 @@ ShareThread::ShareThread(const Names& N, OnlineOptions& opts, DataPositions& *static_cast*>(new typename T::LivePrep( usage, *this)) : *static_cast*>(new BitPrepFiles(N, - get_prep_sub_dir(PREP_DIR, N.num_players()), usage))) + get_prep_sub_dir(PREP_DIR, N.num_players()), + usage, BaseMachine::thread_num))) { } diff --git a/GC/TinierPrep.h b/GC/TinierPrep.h deleted file mode 100644 index be35f5a96..000000000 --- a/GC/TinierPrep.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * TinierPrep.h - * - */ - -#ifndef GC_TINIERPREP_H_ -#define GC_TINIERPREP_H_ - -#include "TinyPrep.h" - -namespace GC -{ - -template -class TinierPrep : public TinyPrep -{ -public: - TinierPrep(DataPositions& usage, ShareThread& thread, - bool amplify = true) : - TinyPrep(usage, thread, amplify) - { - } - - TinierPrep(SubProcessor*, DataPositions& usage) : - TinierPrep(usage, ShareThread::s()) - { - } - - void buffer_inputs(int player) - { - this->buffer_inputs_(player, this->triple_generator); - } -}; - -} - -#endif /* GC_TINIERPREP_H_ */ diff --git a/GC/TinierSecret.h b/GC/TinierSecret.h index ffb29913d..6c2a1b6fa 100644 --- a/GC/TinierSecret.h +++ b/GC/TinierSecret.h @@ -15,6 +15,9 @@ namespace GC { template class TinierPrep; +template class VectorProtocol; +template class CcdPrep; +template class VectorInput; template class TinierSecret : public VectorSecret> @@ -25,9 +28,9 @@ class TinierSecret : public VectorSecret> public: typedef TinyMC MC; typedef MC MAC_Check; - typedef Beaver Protocol; - typedef ::Input Input; - typedef TinierPrep LivePrep; + typedef VectorProtocol Protocol; + typedef VectorInput Input; + typedef CcdPrep LivePrep; typedef Memory DynamicMemory; typedef NPartyTripleGenerator TripleGenerator; diff --git a/GC/TinierShare.h b/GC/TinierShare.h index ea23a5078..4a57bc46c 100644 --- a/GC/TinierShare.h +++ b/GC/TinierShare.h @@ -9,12 +9,12 @@ #include "Processor/DummyProtocol.h" #include "Protocols/Share.h" #include "Math/Bit.h" -#include "TinierSharePrep.h" namespace GC { template class TinierSecret; +template class TinierSharePrep; template class TinierShare: public Share_, SemiShare>, @@ -55,6 +55,11 @@ class TinierShare: public Share_, SemiShare>, return "Tinier"; } + static string type_short() + { + return "TT"; + } + static ShareThread>& get_party() { return ShareThread>::s(); @@ -103,9 +108,7 @@ class TinierShare: public Share_, SemiShare>, void random() { - TinierSecret tmp; - get_party().DataF.get_one(DATA_BIT, tmp); - *this = tmp.get_reg(0); + *this = get_party().DataF.get_part().get_bit(); } This lsb() const diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h index 7f292ac73..cad2e969b 100644 --- a/GC/TinierSharePrep.h +++ b/GC/TinierSharePrep.h @@ -21,18 +21,26 @@ template class TinierSharePrep : public PersonalPrep { typename T::TripleGenerator* triple_generator; + typename T::whole_type::TripleGenerator* real_triple_generator; MascotParams params; - TinierPrep> whole_prep; + typedef typename T::whole_type secret_type; + ShareThread& thread; void buffer_triples(); void buffer_squares() { throw not_implemented(); } - void buffer_bits() { throw not_implemented(); } + void buffer_bits(); void buffer_inverses() { throw not_implemented(); } void buffer_inputs(int player); + void buffer_secret_triples(); + + void init_real(Player& P); + public: + TinierSharePrep(DataPositions& usage, ShareThread& thread, + int input_player = PersonalPrep::SECURE); TinierSharePrep(DataPositions& usage, int input_player = PersonalPrep::SECURE); TinierSharePrep(SubProcessor*, DataPositions& usage); diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index 02c7dd7a2..964e3614a 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -15,10 +15,16 @@ namespace GC template TinierSharePrep::TinierSharePrep(DataPositions& usage, int input_player) : + TinierSharePrep(usage, ShareThread::s(), input_player) +{ +} + +template +TinierSharePrep::TinierSharePrep(DataPositions& usage, + ShareThread& thread, int input_player) : PersonalPrep(usage, input_player), triple_generator(0), - whole_prep(usage, - ShareThread>::s(), - input_player == PersonalPrep::SECURE) + real_triple_generator(0), + thread(thread) { } @@ -33,6 +39,8 @@ TinierSharePrep::~TinierSharePrep() { if (triple_generator) delete triple_generator; + if (real_triple_generator) + delete real_triple_generator; } template @@ -44,15 +52,14 @@ void TinierSharePrep::set_protocol(typename T::Protocol& protocol) params.generateMACs = true; params.amplify = false; params.check = false; - auto& thread = ShareThread>::s(); + auto& thread = ShareThread::s(); triple_generator = new typename T::TripleGenerator( BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, - OnlineOptions::singleton.batch_size - * TinierSecret::default_length, 1, + OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &protocol.P); triple_generator->multi_threaded = false; this->inputs.resize(thread.P->num_players()); - whole_prep.init(*thread.P); + init_real(protocol.P); } template @@ -63,12 +70,8 @@ void TinierSharePrep::buffer_triples() this->buffer_personal_triples(); return; } - - array, 3> whole; - whole_prep.get(DATA_TRIPLE, whole.data()); - for (size_t i = 0; i < whole[0].get_regs().size(); i++) - this->triples.push_back( - {{ whole[0].get_reg(i), whole[1].get_reg(i), whole[2].get_reg(i) }}); + else + buffer_secret_triples(); } template @@ -81,12 +84,21 @@ void TinierSharePrep::buffer_inputs(int player) inputs.at(player).push_back(x); } +template +void GC::TinierSharePrep::buffer_bits() +{ + this->bits.push_back( + BufferPrep::get_random_from_inputs(thread.P->num_players())); +} + template size_t TinierSharePrep::data_sent() { - size_t res = whole_prep.data_sent(); + size_t res = 0; if (triple_generator) res += triple_generator->data_sent(); + if (real_triple_generator) + res += real_triple_generator->data_sent(); return res; } diff --git a/GC/TinyMC.cpp b/GC/TinyMC.cpp deleted file mode 100644 index ff4320070..000000000 --- a/GC/TinyMC.cpp +++ /dev/null @@ -1,11 +0,0 @@ -/* - * TinyMC.cpp - * - */ - -#include "TinyMC.h" - -namespace GC -{ - -} /* namespace GC */ diff --git a/GC/TinyMC.h b/GC/TinyMC.h index 2e6344897..e0a0b948b 100644 --- a/GC/TinyMC.h +++ b/GC/TinyMC.h @@ -14,7 +14,7 @@ namespace GC template class TinyMC : public MAC_Check_Base { - typename T::check_type::MAC_Check part_MC; + typename T::part_type::MAC_Check part_MC; PointerVector sizes; public: diff --git a/GC/TinyPrep.h b/GC/TinyPrep.h deleted file mode 100644 index 22d2b0bc3..000000000 --- a/GC/TinyPrep.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * TinyPrep.h - * - */ - -#ifndef GC_TINYPREP_H_ -#define GC_TINYPREP_H_ - -#include "Thread.h" -#include "OT/MascotParams.h" -#include "Protocols/Beaver.h" -#include "Protocols/ReplicatedPrep.h" - -namespace GC -{ - -template -class TinyPrep : public BufferPrep -{ -protected: - ShareThread& thread; - - typename T::TripleGenerator* triple_generator; - MascotParams params; - - vector> triple_buffer; - - const bool amplify; - -public: - TinyPrep(DataPositions& usage, ShareThread& thread, bool amplify = true); - ~TinyPrep(); - - void set_protocol(Beaver& protocol); - void init(Player& P); - - void buffer_triples(); - void buffer_bits(); - - void buffer_squares() { throw not_implemented(); } - void buffer_inverses() { throw not_implemented(); } - - void buffer_inputs_(int player, typename T::InputGenerator* input_generator); - - array get_triple_no_count(int n_bits); - - size_t data_sent(); -}; - -template -class TinyOnlyPrep : public TinyPrep -{ - typename T::part_type::TripleGenerator* input_generator; - -public: - TinyOnlyPrep(DataPositions& usage, ShareThread& thread); - ~TinyOnlyPrep(); - - void set_protocol(Beaver& protocol); - - void buffer_inputs(int player) - { - this->buffer_inputs_(player, input_generator); - } - - size_t data_sent(); -}; - -} /* namespace GC */ - -#endif /* GC_TINYPREP_H_ */ diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index b4c7c7b47..eae76ab59 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -3,7 +3,7 @@ * */ -#include "TinyPrep.h" +#include "TinierSharePrep.h" #include "Protocols/MascotPrep.hpp" @@ -11,78 +11,26 @@ namespace GC { template -TinyPrep::TinyPrep(DataPositions& usage, ShareThread& thread, - bool amplify) : - BufferPrep(usage), thread(thread), triple_generator(0), - amplify(amplify) +void TinierSharePrep::init_real(Player& P) { - -} - -template -TinyOnlyPrep::TinyOnlyPrep(DataPositions& usage, ShareThread& thread) : - TinyPrep(usage, thread), input_generator(0) -{ -} - -template -TinyPrep::~TinyPrep() -{ - if (triple_generator) - delete triple_generator; -} - -template -TinyOnlyPrep::~TinyOnlyPrep() -{ - if (input_generator) - delete input_generator; -} - -template -void TinyPrep::set_protocol(Beaver& protocol) -{ - init(protocol.P); -} - -template -void TinyPrep::init(Player& P) -{ - params.generateMACs = true; - params.amplify = false; - params.check = false; - auto& thread = ShareThread::s(); - triple_generator = new typename T::TripleGenerator( + assert(real_triple_generator == 0); + real_triple_generator = new typename T::whole_type::TripleGenerator( BaseMachine::s().fresh_ot_setup(), P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &P); - triple_generator->multi_threaded = false; -} - -template -void TinyOnlyPrep::set_protocol(Beaver& protocol) -{ - TinyPrep::set_protocol(protocol); - input_generator = new typename T::part_type::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, - OnlineOptions::singleton.batch_size, 1, this->params, - this->thread.MC->get_alphai(), &protocol.P); - input_generator->multi_threaded = false; + real_triple_generator->multi_threaded = false; } template -void TinyPrep::buffer_triples() +void TinierSharePrep::buffer_secret_triples() { - auto& triple_generator = this->triple_generator; + auto& triple_generator = real_triple_generator; assert(triple_generator != 0); params.generateBits = false; - vector> triples; - TripleShuffleSacrifice sacrifice; + vector> triples; + TripleShuffleSacrifice sacrifice; size_t required; - if (amplify) - required = sacrifice.minimum_n_inputs_with_combining(); - else - required = sacrifice.minimum_n_inputs(); + required = sacrifice.minimum_n_inputs_with_combining(); while (triples.size() < required) { triple_generator->generatePlainTriples(); @@ -92,9 +40,11 @@ void TinyPrep::buffer_triples() triple_generator->valueBits[2].set_portion(i, triple_generator->plainTriples[i][2]); triple_generator->run_multipliers({}); + assert(triple_generator->plainTriples.size() != 0); for (size_t i = 0; i < triple_generator->plainTriples.size(); i++) { - for (int j = 0; j < T::default_length; j++) + int dl = secret_type::default_length; + for (int j = 0; j < dl; j++) { triples.push_back({}); for (int k = 0; k < 3; k++) @@ -103,10 +53,10 @@ void TinyPrep::buffer_triples() share.set_share( triple_generator->plainTriples.at(i).at(k).get_bit( j)); - typename T::part_type::mac_type mac; + typename T::mac_type mac; mac = thread.MC->get_alphai() * share.get_share(); for (auto& multiplier : triple_generator->ot_multipliers) - mac += multiplier->macs.at(k).at(i * T::default_length + j); + mac += multiplier->macs.at(k).at(i * dl + j); share.set_mac(mac); } } @@ -114,104 +64,10 @@ void TinyPrep::buffer_triples() } sacrifice.triple_sacrifice(triples, triples, *thread.P, thread.MC->get_part_MC()); - if (amplify) - sacrifice.triple_combine(triples, triples, *thread.P, - thread.MC->get_part_MC()); - for (size_t i = 0; i < triples.size() / T::default_length; i++) - { - this->triples.push_back({}); - auto& triple = this->triples.back(); - for (auto& x : triple) - x.resize_regs(T::default_length); - for (int j = 0; j < T::default_length; j++) - { - auto& source_triple = triples[j + i * T::default_length]; - for (int k = 0; k < 3; k++) - triple[k].get_reg(j) = source_triple[k]; - } - } -} - -template -void TinyPrep::buffer_bits() -{ - auto tmp = BufferPrep::get_random_from_inputs(thread.P->num_players()); - for (auto& bit : tmp.get_regs()) - { - this->bits.push_back({}); - this->bits.back().resize_regs(1); - this->bits.back().get_reg(0) = bit; - } -} - -template -void TinyPrep::buffer_inputs_(int player, typename T::InputGenerator* input_generator) -{ - auto& inputs = this->inputs; - inputs.resize(this->thread.P->num_players()); - assert(input_generator); - input_generator->generateInputs(player); - assert(input_generator->inputs.size() >= T::default_length); - for (size_t i = 0; i < input_generator->inputs.size() / T::default_length; i++) - { - inputs[player].push_back({}); - inputs[player].back().share.resize_regs(T::default_length); - for (int j = 0; j < T::default_length; j++) - { - auto& source_input = input_generator->inputs[j - + i * T::default_length]; - inputs[player].back().share.get_reg(j) = source_input.share; - inputs[player].back().value ^= typename T::open_type( - source_input.value.get_bit(0)) << j; - } - } -} - -template -array TinyPrep::get_triple_no_count(int n_bits) -{ - assert(n_bits > 0); - while ((unsigned)n_bits > triple_buffer.size()) - { - array tmp; - this->get(DATA_TRIPLE, tmp.data()); - for (size_t i = 0; i < tmp[0].get_regs().size(); i++) - { - triple_buffer.push_back( - { {tmp[0].get_reg(i), tmp[1].get_reg(i), tmp[2].get_reg(i)} }); - } - } - - array res; - for (int j = 0; j < 3; j++) - res[j].resize_regs(n_bits); - - for (int i = 0; i < n_bits; i++) - { - for (int j = 0; j < 3; j++) - res[j].get_reg(i) = triple_buffer.back()[j]; - triple_buffer.pop_back(); - } - - return res; -} - -template -size_t TinyPrep::data_sent() -{ - size_t res = 0; - if (triple_generator) - res += triple_generator->data_sent(); - return res; -} - -template -size_t TinyOnlyPrep::data_sent() -{ - auto res = TinyPrep::data_sent(); - if (input_generator) - res += input_generator->data_sent(); - return res; + sacrifice.triple_combine(triples, triples, *thread.P, + thread.MC->get_part_MC()); + for (auto& triple : triples) + this->triples.push_back(triple); } } /* namespace GC */ diff --git a/GC/TinySecret.cpp b/GC/TinySecret.cpp deleted file mode 100644 index a8f78241d..000000000 --- a/GC/TinySecret.cpp +++ /dev/null @@ -1,11 +0,0 @@ -/* - * TinySecret.cpp - * - */ - -#include "TinySecret.h" - -namespace GC -{ - -} /* namespace GC */ diff --git a/GC/TinySecret.h b/GC/TinySecret.h index 95a00490b..9f48474da 100644 --- a/GC/TinySecret.h +++ b/GC/TinySecret.h @@ -21,6 +21,9 @@ namespace GC template class TinyOnlyPrep; template class TinyMC; +template class VectorProtocol; +template class VectorInput; +template class CcdPrep; template class VectorSecret : public Secret @@ -50,11 +53,6 @@ class VectorSecret : public Secret static const int default_length = 64; - static DataFieldType field_type() - { - return BitVec::field_type(); - } - static int size() { return part_type::size() * default_length; @@ -166,9 +164,9 @@ class VectorSecret : public Secret } template - void other_input(U& inputter, int from, int) + void other_input(U& inputter, int from, int n_bits) { - inputter.add_other(from); + inputter.add_other(from, n_bits); } template @@ -187,9 +185,9 @@ class TinySecret : public VectorSecret> public: typedef TinyMC MC; typedef MC MAC_Check; - typedef Beaver Protocol; - typedef ::Input Input; - typedef TinyOnlyPrep LivePrep; + typedef VectorProtocol Protocol; + typedef VectorInput Input; + typedef CcdPrep LivePrep; typedef Memory DynamicMemory; typedef OTTripleGenerator TripleGenerator; diff --git a/GC/TinyShare.cpp b/GC/TinyShare.cpp deleted file mode 100644 index cdbd03b62..000000000 --- a/GC/TinyShare.cpp +++ /dev/null @@ -1,11 +0,0 @@ -/* - * TinyShare.cpp - * - */ - -#include "TinyShare.h" - -namespace GC -{ - -} /* namespace GC */ diff --git a/GC/TinyShare.h b/GC/TinyShare.h index 980b28b8a..4fbb6092e 100644 --- a/GC/TinyShare.h +++ b/GC/TinyShare.h @@ -10,13 +10,14 @@ #include "ShareParty.h" #include "Secret.h" #include "Protocols/Spdz2kShare.h" -#include "Processor/NoLivePrep.h" + namespace GC { template class TinySecret; template class ShareThread; +template class TinierSharePrep; template class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> @@ -28,12 +29,18 @@ class TinyShare : public Spdz2kShare<1, S>, public ShareSecret> typedef void DynamicMemory; - typedef NoLivePrep LivePrep; + typedef Beaver Protocol; + typedef MAC_Check_Z2k_ MAC_Check; + typedef MAC_Check Direct_MC; + typedef ::Input Input; + typedef TinierSharePrep LivePrep; typedef SwitchableOutput out_type; typedef This small_type; + typedef NoShare bit_type; + static string name() { return "tiny share"; diff --git a/GC/VectorInput.h b/GC/VectorInput.h index fe25beaa1..c17cd93d4 100644 --- a/GC/VectorInput.h +++ b/GC/VectorInput.h @@ -18,14 +18,14 @@ class VectorInput : public InputBase deque input_lengths; public: - VectorInput(typename T::MAC_Check&, Preprocessing&, Player& P) : - part_input(0, P) + VectorInput(typename T::MAC_Check& MC, Preprocessing& prep, Player& P) : + part_input(MC.get_part_MC(), prep.get_part(), P) { part_input.reset_all(P); } VectorInput(SubProcessor& proc, typename T::MAC_Check&) : - VectorInput(proc.MC, proc.DataF, proc.P) + part_input(proc.MC, proc.DataF, proc.P) { } @@ -41,8 +41,10 @@ class VectorInput : public InputBase input_lengths.push_back(n_bits); } - void add_other(int) + void add_other(int player, int n_bits) { + for (int i = 0; i < n_bits; i++) + part_input.add_other(player); } void send_mine() diff --git a/GC/VectorProtocol.h b/GC/VectorProtocol.h index 4237184e9..1292c2212 100644 --- a/GC/VectorProtocol.h +++ b/GC/VectorProtocol.h @@ -17,6 +17,8 @@ class VectorProtocol : public ProtocolBase typename T::part_type::Protocol part_protocol; public: + Player& P; + VectorProtocol(Player& P); void init_mul(SubProcessor* proc); diff --git a/GC/VectorProtocol.hpp b/GC/VectorProtocol.hpp index d95c3a24e..072cb71fc 100644 --- a/GC/VectorProtocol.hpp +++ b/GC/VectorProtocol.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_VECTORPROTOCOL_HPP_ +#define GC_VECTORPROTOCOL_HPP_ + #include "VectorProtocol.h" namespace GC @@ -10,7 +13,7 @@ namespace GC template VectorProtocol::VectorProtocol(Player& P) : - part_protocol(P) + part_protocol(P), P(P) { } @@ -54,3 +57,5 @@ T VectorProtocol::finalize_mul(int n) } } /* namespace GC */ + +#endif diff --git a/GC/instructions.h b/GC/instructions.h index 467c1d7c1..31dc0592b 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -40,7 +40,7 @@ #define BIT_INSTRUCTIONS \ X(XORS, T::xors(PROC, EXTRA)) \ - X(XORCB, C0.xor_(PC1, PC2)) \ + X(XORCB, processor.xorc(instruction)) \ X(XORCBI, C0.xor_(PC1, IMM)) \ X(NOTS, processor.nots(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index 07764968d..02ad9b983 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -20,13 +20,14 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" -#include "GC/TinierPrep.h" +#include "GC/VectorInput.h" #include "GC/ShareParty.hpp" #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" #include "Math/gfp.hpp" diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 48cb37847..672a29b4e 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -8,9 +8,8 @@ #include "GC/TinySecret.h" #include "GC/TinyMC.h" -#include "GC/TinyPrep.h" -#include "GC/TinierPrep.h" #include "GC/TinierSecret.h" +#include "GC/VectorInput.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" @@ -29,3 +28,4 @@ #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" #include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index f6b790199..080332aea 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -30,6 +30,7 @@ #include "GC/ShareSecret.hpp" #include "GC/VectorProtocol.hpp" #include "GC/Secret.hpp" +#include "GC/CcdPrep.hpp" #include "Math/gfp.hpp" ShamirOptions ShamirOptions::singleton; diff --git a/Machines/ccd-party.cpp b/Machines/ccd-party.cpp index 765699c0a..9945f40b9 100644 --- a/Machines/ccd-party.cpp +++ b/Machines/ccd-party.cpp @@ -12,6 +12,7 @@ #include "GC/ShareSecret.hpp" #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" +#include "GC/CcdPrep.hpp" #include "Machines/ShamirMachine.hpp" int main(int argc, const char** argv) diff --git a/Machines/cowgear-party.cpp b/Machines/cowgear-party.cpp index 3f09697dd..6e7a333d2 100644 --- a/Machines/cowgear-party.cpp +++ b/Machines/cowgear-party.cpp @@ -12,7 +12,6 @@ #include "FHE/NTL-Subs.h" #include "GC/TinierSecret.h" -#include "GC/TinierPrep.h" #include "GC/TinyMC.h" #include "SPDZ.hpp" diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 32a9b6f18..d70e0fe2b 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -18,12 +18,14 @@ int main(int argc, const char** argv) { - OnlineOptions online_opts; - Names N(0, randombytes_random() % (65536 - 1024) + 1024, vector({"localhost"})); + OnlineOptions& online_opts = OnlineOptions::singleton; + Names N; ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); + online_opts = {opt, argc, argv}; opt.parse(argc, argv); opt.syntax = string(argv[0]) + " "; + string progname; if (opt.firstArgs.size() > 1) progname = *opt.firstArgs.at(1); @@ -50,36 +52,14 @@ int main(int argc, const char** argv) int R = ring_opts.ring_size_from_opts_or_schedule(progname); switch (R) { - case 64: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); - break; - case 128: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); - break; - case 256: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); - break; - case 192: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); - break; - case 384: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); - break; - case 512: - Machine>, FakeShare>(0, N, progname, - online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, true, - online_opts.live_prep, online_opts).run(); +#define X(L) \ + case L: \ + Machine>, FakeShare>(0, N, progname, \ + online_opts.memtype, gf2n::default_degree(), 0, 0, 0, 0, false, \ + online_opts.live_prep, online_opts).run(); \ break; + X(64) X(128) X(256) X(192) X(384) X(512) +#undef X default: cerr << "Not compiled for " << R << "-bit rings" << endl; } diff --git a/Machines/malicious-ccd-party.cpp b/Machines/malicious-ccd-party.cpp index a43542060..4ce84aea3 100644 --- a/Machines/malicious-ccd-party.cpp +++ b/Machines/malicious-ccd-party.cpp @@ -12,6 +12,7 @@ #include "GC/ShareSecret.hpp" #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" +#include "GC/CcdPrep.hpp" #include "Machines/ShamirMachine.hpp" #include "Machines/MalRep.hpp" diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp new file mode 100644 index 000000000..2120322f3 --- /dev/null +++ b/Machines/no-party.cpp @@ -0,0 +1,23 @@ +/* + * no-party.cpp + * + */ + +#include "Protocols/NoShare.h" + +#include "Processor/OnlineMachine.hpp" +#include "Processor/Machine.hpp" +#include "Protocols/Replicated.hpp" +#include "Math/gfp.hpp" +#include "Math/Z2k.hpp" + +int main(int argc, const char** argv) +{ + ez::ezOptionParser opt; + OnlineOptions::singleton = {opt, argc, argv}; + OnlineMachine machine(argc, argv, opt, OnlineOptions::singleton); + OnlineOptions::singleton.finalize(opt, argc, argv); + machine.start_networking(); + // use primes of length 65 to 128 for arithmetic computation + machine.run>, NoShare>(); +} diff --git a/Machines/sy-rep-field-party.cpp b/Machines/sy-rep-field-party.cpp index d047a04e8..1da856768 100644 --- a/Machines/sy-rep-field-party.cpp +++ b/Machines/sy-rep-field-party.cpp @@ -12,7 +12,6 @@ #include "Math/gfp.h" #include "Math/gf2n.h" #include "Tools/ezOptionParser.h" -#include "Processor/NoLivePrep.h" #include "GC/MaliciousCcdSecret.h" #include "Processor/FieldMachine.hpp" diff --git a/Machines/tinier-party.cpp b/Machines/tinier-party.cpp index 1a9f5fe54..82122bd11 100644 --- a/Machines/tinier-party.cpp +++ b/Machines/tinier-party.cpp @@ -4,9 +4,9 @@ */ #include "GC/TinierSecret.h" -#include "GC/TinierPrep.h" #include "GC/ShareParty.h" #include "GC/TinyMC.h" +#include "GC/VectorInput.h" #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" @@ -17,6 +17,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MAC_Check.hpp" diff --git a/Machines/tiny-party.cpp b/Machines/tiny-party.cpp index e24df7a16..7e72361eb 100644 --- a/Machines/tiny-party.cpp +++ b/Machines/tiny-party.cpp @@ -7,6 +7,7 @@ #include "GC/TinierSecret.h" #include "GC/ShareParty.h" #include "GC/TinyMC.h" +#include "GC/VectorInput.h" #include "GC/ShareParty.hpp" #include "GC/ShareSecret.hpp" @@ -17,6 +18,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "GC/TinyPrep.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.hpp" #include "Processor/Instruction.hpp" #include "Protocols/MAC_Check.hpp" diff --git a/Makefile b/Makefile index 6406b28e0..c44f62916 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ GC_SEMI = GC/SemiSecret.o GC/SemiPrep.o GC/square64.o OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) OT_EXE = ot.x ot-offline.x -COMMON = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o +COMMON = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT) YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp)) @@ -221,6 +221,7 @@ ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o rep4-ring-party.x: GC/Rep4Secret.o +no-party.x: Protocols/ShareInterface.o semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o GC/SemiSecret.o mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) @@ -238,6 +239,7 @@ static/mal-shamir-bmr-party.x: $(BMR) static/semi-bmr-party.x: $(BMR) static/real-bmr-party.x: $(BMR) static/bmr-program-party.x: $(BMR) +static/no-party.x: Protocols/ShareInterface.o ifeq ($(AVX_OT), 1) $(LIBSIMPLEOT): SimpleOT/Makefile diff --git a/Math/BitVec.cpp b/Math/BitVec.cpp deleted file mode 100644 index 89886bcb3..000000000 --- a/Math/BitVec.cpp +++ /dev/null @@ -1,7 +0,0 @@ -/* - * BitVec.cpp - * - */ - -#include "BitVec.h" - diff --git a/Math/FixedVec.h b/Math/FixedVec.h index e579b3f69..e2dfb38af 100644 --- a/Math/FixedVec.h +++ b/Math/FixedVec.h @@ -47,10 +47,6 @@ class FixedVec { return string(1, T::type_char()); } - static DataFieldType field_type() - { - return T::field_type(); - } template static FixedVec Mul(const FixedVec& a, const V& b) diff --git a/Math/Integer.h b/Math/Integer.h index 4f7506909..006d20e18 100644 --- a/Math/Integer.h +++ b/Math/Integer.h @@ -85,6 +85,7 @@ class IntBase : public ValueInterface T& operator&=(const IntBase& other) { return a &= other.a; } friend ostream& operator<<(ostream& s, const IntBase& x) { x.output(s, true); return s; } + friend istream& operator>>(istream& s, IntBase& x) { x.input(s, true); return s; } void randomize(PRNG& G); void almost_randomize(PRNG& G) { randomize(G); } diff --git a/Math/Z2k.h b/Math/Z2k.h index c22a2ad95..f70ef0ad7 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -198,6 +198,11 @@ class SignedZ2 : public Z2 extend(other); } + void to(bigint& res) const + { + res = *this; + } + bool negative() const { return this->a[this->N_WORDS - 1] & 1ll << ((K - 1) % (8 * sizeof(mp_limb_t))); @@ -401,12 +406,6 @@ void Z2::unpack(octetStream& o, int n) o.consume((octet*)a, N_BYTES); } -template -void to_gfp(Z2& res, const bigint& a) -{ - res = a; -} - template SignedZ2 abs(const SignedZ2& x) { @@ -435,10 +434,4 @@ ostream& operator<<(ostream& o, const SignedZ2& x) return o; } -template -void to_bigint(bigint& res, const SignedZ2& a) -{ - res = a; -} - #endif /* MATH_Z2K_H_ */ diff --git a/Math/bigint.h b/Math/bigint.h index 081a49c4b..e62b39f3a 100644 --- a/Math/bigint.h +++ b/Math/bigint.h @@ -196,6 +196,18 @@ bigint& bigint::operator=(const gfp_& x) return *this; } +template +void to_bigint(bigint& res, const T& other) +{ + other.to(res); +} + +template +void to_gfp(T& res, const bigint& a) +{ + res = a; +} + string to_string(const bigint& x); /********************************** diff --git a/Math/field_types.h b/Math/field_types.h index e02cd1b8f..9f54d3afa 100644 --- a/Math/field_types.h +++ b/Math/field_types.h @@ -15,8 +15,6 @@ enum Dtype DATA_SQUARE, DATA_BIT, DATA_INVERSE, - DATA_BITTRIPLE, - DATA_BITGF2NTRIPLE, DATA_DABIT, N_DTYPE }; diff --git a/Math/gf2n.h b/Math/gf2n.h index 4c51abc30..8eba20aeb 100644 --- a/Math/gf2n.h +++ b/Math/gf2n.h @@ -144,8 +144,6 @@ class gf2n_short : public ValueInterface { a=x.a^y.a; } // = x * y void mul(const gf2n_short& x,const gf2n_short& y); - // x * y when one of x,y is a bit - void mul_by_bit(const gf2n_short& x, const gf2n_short& y) { a = x.a * y.a; } gf2n_short lazy_add(const gf2n_short& x) const { return *this + x; } gf2n_short lazy_mul(const gf2n_short& x) const { return *this * x; } diff --git a/Math/gf2nlong.h b/Math/gf2nlong.h index 77467f69e..a4d734324 100644 --- a/Math/gf2nlong.h +++ b/Math/gf2nlong.h @@ -180,8 +180,6 @@ class gf2n_long : public ValueInterface { a=x.a^y.a; } // = x * y gf2n_long& mul(const gf2n_long& x,const gf2n_long& y); - // x * y when one of x,y is a bit - void mul_by_bit(const gf2n_long& x, const gf2n_long& y) { a = x.a.a * y.a.a; } gf2n_long lazy_add(const gf2n_long& x) const { return *this + x; } gf2n_long lazy_mul(const gf2n_long& x) const { return *this * x; } diff --git a/Math/gfp.h b/Math/gfp.h index 255815fa3..5b493ce74 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -68,7 +68,7 @@ class gfp_ : public ValueInterface // must be negative static const int N_BITS = -1; - static const int MAX_EDABITS = 40 > MAX_N_BITS - 40 ? 40 : MAX_N_BITS - 40; + static const int MAX_EDABITS = MAX_N_BITS; template static void init(bool mont = true) @@ -229,6 +229,11 @@ class gfp_ : public ValueInterface void convert_destroy(bigint& x) { a.convert_destroy(x, ZpD); } + void to(bigint& res) const + { + res = *this; + } + // Convert representation to and from a bigint number friend void to_bigint(bigint& ans,const gfp_& x,bool reduce=true) { x.a.template to_bigint(ans, x.ZpD, reduce); } diff --git a/Math/gfp.hpp b/Math/gfp.hpp index 59d40ab6e..acc587007 100644 --- a/Math/gfp.hpp +++ b/Math/gfp.hpp @@ -171,16 +171,9 @@ void gfp_::reqbl(int n) } template -bool gfp_::allows(Dtype type) +bool gfp_::allows(Dtype) { - switch(type) - { - case DATA_BITGF2NTRIPLE: - case DATA_BITTRIPLE: - return false; - default: - return true; - } + return true; } template diff --git a/Math/gfpvar.cpp b/Math/gfpvar.cpp index 7b7360fdb..8b065195f 100644 --- a/Math/gfpvar.cpp +++ b/Math/gfpvar.cpp @@ -11,6 +11,7 @@ const true_type gfpvar::invertible; const true_type gfpvar::prime_field; +const false_type gfpvar::characteristic_two; Zp_Data gfpvar::ZpD; diff --git a/Math/modp.h b/Math/modp.h index 6cc08fa57..9bf942052 100644 --- a/Math/modp.h +++ b/Math/modp.h @@ -94,6 +94,10 @@ class modp_ template friend void to_modp(modp_& ans,int x,const Zp_Data& ZpD); template friend void to_modp(modp_& ans,const mpz_class& x,const Zp_Data& ZpD); + modp_ add(const modp_& other, const Zp_Data& ZpD) const; + modp_ sub(const modp_& other, const Zp_Data& ZpD) const; + modp_ mul(const modp_& other, const Zp_Data& ZpD) const; + friend void Add(modp_& ans,const modp_& x,const modp_& y,const Zp_Data& ZpD) { ZpD.Add(ans.x, x.x, y.x); } template friend void Sub(modp_& ans,const modp_& x,const modp_& y,const Zp_Data& ZpD); diff --git a/Math/modp.hpp b/Math/modp.hpp index 0e9d98b3e..e4570f940 100644 --- a/Math/modp.hpp +++ b/Math/modp.hpp @@ -8,6 +8,30 @@ * The following functions remain the same in Real and Montgomery rep * ***********************************************************************/ +template +modp_ modp_::add(const modp_& other, const Zp_Data& ZpD) const +{ + modp_ res; + Add(res, *this, other, ZpD); + return res; +} + +template +modp_ modp_::sub(const modp_& other, const Zp_Data& ZpD) const +{ + modp_ res; + Sub(res, *this, other, ZpD); + return res; +} + +template +modp_ modp_::mul(const modp_& other, const Zp_Data& ZpD) const +{ + modp_ res; + Mul(res, *this, other, ZpD); + return res; +} + template void modp_::randomize(PRNG& G, const Zp_Data& ZpD) { diff --git a/Math/mpn_fixed.h b/Math/mpn_fixed.h index cd891869c..b55a6b7ec 100644 --- a/Math/mpn_fixed.h +++ b/Math/mpn_fixed.h @@ -135,7 +135,7 @@ mp_limb_t mpn_add_fixed_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const m inline mp_limb_t mpn_sub_n_borrow(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y, int n) { -#if !defined(__clang__) || (__GNUC__ < 7) || !defined(__x86_64__) +#if (!defined(__clang__) && (__GNUC__ < 7)) || !defined(__x86_64__) // GCC 6 can't handle the code below return mpn_sub_n(res, x, y, n); #else diff --git a/Networking/Player.cpp b/Networking/Player.cpp index e92d5219e..37807273e 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -216,17 +216,23 @@ template MultiPlayer::MultiPlayer(const Names& Nms, int id) : Player(Nms), send_to_self_socket(0) { - setup_sockets(Nms.names, Nms.ports, id, *Nms.server); + if (Nms.num_players() > 1) + setup_sockets(Nms.names, Nms.ports, id, *Nms.server); + else + sockets.resize(Nms.num_players()); } template<> MultiPlayer::~MultiPlayer() { - /* Close down the sockets */ - for (auto socket : sockets) - close_client_socket(socket); - close_client_socket(send_to_self_socket); + if (num_players() > 1) + { + /* Close down the sockets */ + for (auto socket : sockets) + close_client_socket(socket); + close_client_socket(send_to_self_socket); + } } template @@ -290,7 +296,6 @@ void MultiPlayer::setup_sockets(const vector& names,const vector& os) const; + // receive from a specific player void receive_player(int i,octetStream& o) const; virtual void receive_player_no_stats(int i,octetStream& o) const = 0; virtual void receive_player(int i,FlexBuffer& buffer) const; // Communication relative to my number + // send to all other players by offset void send_relative(const vector& o) const; + // send to other player specified by offset void send_relative(int offset, const octetStream& o) const; + // receive from all other players by offset void receive_relative(vector& o) const; + // receive from other palyer specified by offset void receive_relative(int offset, octetStream& o) const; // exchange data with minimal memory usage + // exchange information with one other party void exchange(int other, const octetStream& to_send, octetStream& ot_receive) const; virtual void exchange_no_stats(int other, const octetStream& to_send, octetStream& ot_receive) const = 0; void exchange(int other, octetStream& o) const; + // exchange with one other partiy specified by offset void exchange_relative(int offset, octetStream& o) const; + // send information to party while receiving from another by offset void pass_around(octetStream& o, int offset = 1) const { pass_around(o, o, offset); } void pass_around(octetStream& to_send, octetStream& to_receive, int offset) const; virtual void pass_around_no_stats(const octetStream& to_send, @@ -198,6 +212,7 @@ class Player : public PlayerBase * - Assumes o[player_no] contains the thing broadcast by me */ virtual void unchecked_broadcast(vector& o) const; + // broadcast with eventual verification virtual void Broadcast_Receive(vector& o) const; virtual void Broadcast_Receive_no_stats(vector& o) const = 0; @@ -209,8 +224,10 @@ class Player : public PlayerBase // send something different to all void send_receive_all(const vector& to_send, vector& to_receive) const; + // specified senders only send something different to all void send_receive_all(const vector& senders, const vector& to_send, vector& to_receive) const; + // send something different only one specified channels void send_receive_all(const vector>& channels, const vector& to_send, vector& to_receive) const; @@ -218,6 +235,7 @@ class Player : public PlayerBase const vector& to_send, vector& to_receive) const = 0; + // specified senders broadcast information virtual void partial_broadcast(const vector& senders, vector& os) const; virtual void partial_broadcast(const vector&, const vector&, @@ -238,8 +256,6 @@ class MultiPlayer : public Player void setup_sockets(const vector& names,const vector& ports,int id_base,ServerSocket& server); - map socket_players; - T socket_to_send(int player) const { return player == player_no ? send_to_self_socket : sockets[player]; } friend class CryptoPlayer; diff --git a/Processor/DataPositions.cpp b/Processor/DataPositions.cpp index fc8ebf185..9eaa81e56 100644 --- a/Processor/DataPositions.cpp +++ b/Processor/DataPositions.cpp @@ -10,7 +10,7 @@ const char* DataPositions::field_names[] = { "int", "gf2n", "bit" }; -const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2, 3, 3 }; +const int DataPositions::tuple_size[N_DTYPE] = { 3, 2, 1, 2 }; DataPositions::DataPositions(int num_players) { diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 404527ed2..36257af74 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -94,8 +94,10 @@ class Preprocessing : public PrepBase bool do_count; - void count(Dtype dtype, int n = 1) { usage.files[T::field_type()][dtype] += do_count * n; } - void count_input(int player) { usage.inputs[player][T::field_type()] += do_count; } + void count(Dtype dtype, int n = 1) + { usage.files[T::clear::field_type()][dtype] += do_count * n; } + void count_input(int player) + { usage.inputs[player][T::clear::field_type()] += do_count; } template void get_edabits(bool strict, size_t size, T* a, @@ -203,6 +205,7 @@ class Sub_Data_Files : public Preprocessing Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir, DataPositions& usage, int thread_num = -1); + Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num = -1); Sub_Data_Files(const Names& N, const string& prep_data_dir, DataPositions& usage, int thread_num = -1) : Sub_Data_Files(N.my_num(), N.num_players(), prep_data_dir, usage, thread_num) @@ -269,6 +272,7 @@ class Data_Files Data_Files(Machine& machine, SubProcessor* procp = 0, SubProcessor* proc2 = 0); + Data_Files(const Names& N); ~Data_Files(); DataPositions tellg(); @@ -331,7 +335,7 @@ template inline void Preprocessing::get_three(Dtype dtype, T& a, T& b, T& c) { // count bit triples in get_triple() - if (T::field_type() != DATA_GF2) + if (T::clear::field_type() != DATA_GF2) count(dtype); get_three_no_count(dtype, a, b, c); } @@ -361,14 +365,14 @@ template inline void Preprocessing::get(vector& S, DataTag tag, const vector& regs, int vector_size) { - usage.count(T::field_type(), tag, vector_size); + usage.count(T::clear::field_type(), tag, vector_size); get_no_count(S, tag, regs, vector_size); } template array Preprocessing::get_triple(int n_bits) { - if (T::field_type() == DATA_GF2) + if (T::clear::field_type() == DATA_GF2) count(DATA_TRIPLE, n_bits); return get_triple_no_count(n_bits); } @@ -376,7 +380,7 @@ array Preprocessing::get_triple(int n_bits) template array Preprocessing::get_triple_no_count(int n_bits) { - assert(T::field_type() != DATA_GF2 or T::default_length == 1 or + assert(T::clear::field_type() != DATA_GF2 or T::default_length == 1 or T::default_length == n_bits or not do_count); array res; get(DATA_TRIPLE, res.data()); diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index dca9a9079..702937951 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -28,6 +28,15 @@ Preprocessing* Preprocessing::get_new( machine.template prep_dir_prefix(), usage); } +template +Sub_Data_Files::Sub_Data_Files(const Names& N, DataPositions& usage, + int thread_num) : + Sub_Data_Files(N, + OnlineOptions::singleton.prep_dir_prefix(N.num_players()), usage, + thread_num) +{ +} + template int Sub_Data_Files::tuple_length(int dtype) @@ -117,6 +126,15 @@ Data_Files::Data_Files(Machine& machine, SubProcessor< { } +template +Data_Files::Data_Files(const Names& N) : + usage(N.num_players()), + DataFp(*new Sub_Data_Files(N, usage)), + DataF2(*new Sub_Data_Files(N, usage)) +{ +} + + template Data_Files::~Data_Files() { @@ -148,7 +166,7 @@ Sub_Data_Files::~Sub_Data_Files() template void Sub_Data_Files::seekg(DataPositions& pos) { - DataFieldType field_type = T::field_type(); + DataFieldType field_type = T::clear::field_type(); for (int dtype = 0; dtype < N_DTYPE; dtype++) if (T::clear::allows(Dtype(dtype))) buffers[dtype].seekg(pos.files[field_type][dtype]); diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index a1d9aacda..95bcd029a 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -89,7 +89,6 @@ class DummyProtocol : public ProtocolBase void init_mul(SubProcessor* = 0) { - throw not_implemented(); } typename T::clear prepare_mul(const T&, const T&, int = 0) { @@ -223,9 +222,6 @@ class NotImplementedInput static void input_mixed(SubProcessor, vector, int, int) { } - static void raw_input(SubProcessor, vector, int) - { - } void reset_all(Player& P) { (void) P; diff --git a/Processor/ExternalClients.cpp b/Processor/ExternalClients.cpp index d85807b1e..5ea480474 100644 --- a/Processor/ExternalClients.cpp +++ b/Processor/ExternalClients.cpp @@ -23,6 +23,8 @@ ExternalClients::~ExternalClients() { delete it->second; } + if (ctx) + delete ctx; } void ExternalClients::start_listening(int portnum_base) diff --git a/Processor/Input.h b/Processor/Input.h index 0558e7969..1e5aa1035 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -49,7 +49,7 @@ class InputBase void reset_all(Player& P); virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; - virtual void add_other(int player) = 0; + virtual void add_other(int player, int n_bits = -1) = 0; void add_from_all(const clear& input); virtual void send_mine() = 0; @@ -85,7 +85,7 @@ class Input : public InputBase void reset(int player); void add_mine(const open_type& input, int n_bits = -1); - void add_other(int player); + void add_other(int player, int n_bits = -1); void send_mine(); diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 044872e72..9272535bc 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -103,7 +103,7 @@ void Input::add_mine(const open_type& input, int n_bits) } template -void Input::add_other(int player) +void Input::add_other(int player, int) { open_type t; shares.at(player).push_back({}); diff --git a/Processor/Instruction.h b/Processor/Instruction.h index ce0d951c8..cfcac3b38 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -23,6 +23,7 @@ class ArithmeticProcessor; */ enum { + CISC = 0, // Load/store LDI = 0x1, LDSI = 0x2, @@ -101,6 +102,7 @@ enum MATMULS = 0xAA, MATMULSM = 0xAB, CONV2DS = 0xAC, + CHECK = 0xAF, // Data access TRIPLE = 0x50, BIT = 0x51, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index 711fd3dd5..8949f28e7 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -14,16 +14,8 @@ #include "GC/Instruction.h" #include "GC/instructions.h" -//#include "Processor/Processor.hpp" #include "Processor/Binary_File_IO.hpp" #include "Processor/PrivateOutput.hpp" -//#include "Processor/Input.hpp" -//#include "Processor/Beaver.hpp" -//#include "Protocols/Shamir.hpp" -//#include "Protocols/ShamirInput.hpp" -//#include "Protocols/Replicated.hpp" -//#include "Protocols/MaliciousRepMC.hpp" -//#include "Protocols/ShamirMC.hpp" #include "Math/bigint.hpp" #include @@ -73,7 +65,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case TRIPLE: case ANDC: case XORC: - case XORCB: case ORC: case SHLC: case SHRC: @@ -88,13 +79,9 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case GMULM: case GDIVC: case GTRIPLE: - case GBITTRIPLE: - case GBITGF2NTRIPLE: case GANDC: case GXORC: case GORC: - case GMULBITC: - case GMULBITM: case LTC: case GTC: case EQC: @@ -169,12 +156,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CLOSECLIENTCONNECTION: r[0]=get_int(s); break; - // instructions with 3 registers + 1 integer operand - r[0]=get_int(s); - r[1]=get_int(s); - r[2]=get_int(s); - n = get_int(s); - break; // instructions with 2 registers + 1 integer operand case ADDCI: case ADDCBI: @@ -275,6 +256,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case CRASH: case STARTGRIND: case STOPGRIND: + case CHECK: break; // instructions with 5 register operands case PRINTFLOATPLAIN: @@ -320,7 +302,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) break; case CONV2DS: get_ints(r, s, 3); - get_vector(11, start, s); + get_vector(12, start, s); break; // read from file, input is opcode num_args, @@ -374,6 +356,11 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case GPROTECTMEMC: case PROTECTMEMINT: throw runtime_error("memory protection not supported any more"); + case GBITTRIPLE: + case GBITGF2NTRIPLE: + case GMULBITC: + case GMULBITM: + throw runtime_error("GF(2^n) bit operations not supported any more"); case GBITDEC: case GBITCOM: num_var_args = get_int(s) - 2; @@ -390,6 +377,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) break; case PREP: case GPREP: + case CISC: // subtract extra argument num_var_args = get_int(s) - 1; s.read((char*)r, sizeof(r)); @@ -417,6 +405,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) break; case XORM: case ANDM: + case XORCB: n = get_int(s); get_ints(r, s, 3); break; @@ -555,6 +544,7 @@ int BaseInstruction::get_reg_type() const case GUSE_PREP: case USE_EDABIT: case RUN_TAPE: + case CISC: // those use r[] not for registers return NONE; case LDI: @@ -667,7 +657,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const case MATMULSM: return r[0] + start[0] * start[2]; case CONV2DS: - return r[0] + start[0] * start[1]; + return r[0] + start[0] * start[1] * start[11]; case OPEN: skip = 2; break; @@ -901,18 +891,6 @@ inline void Instruction::execute(Processor& Proc) const to_gfp(Proc.temp.ansp, Proc.temp.aa2 = mpz_fdiv_ui(Proc.temp.aa.get_mpz_t(), n)); Proc.write_Cp(r[0],Proc.temp.ansp); break; - case GMULBITC: - Proc.get_C2_ref(r[0]).mul_by_bit(Proc.read_C2(r[1]),Proc.read_C2(r[2])); - break; - case GMULBITM: - Proc.get_S2_ref(r[0]).mul_by_bit(Proc.read_S2(r[1]),Proc.read_C2(r[2])); - break; - case GBITTRIPLE: - Proc2.DataF.get_three(DATA_BITTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); - break; - case GBITGF2NTRIPLE: - Proc2.DataF.get_three(DATA_BITGF2NTRIPLE, Proc.get_S2_ref(r[0]),Proc.get_S2_ref(r[1]),Proc.get_S2_ref(r[2])); - break; case SQUARE: Procp.DataF.get_two(DATA_SQUARE, Proc.get_Sp_ref(r[0]),Proc.get_Sp_ref(r[1])); break; @@ -1016,6 +994,16 @@ inline void Instruction::execute(Processor& Proc) const case TRUNC_PR: Proc.Procp.protocol.trunc_pr(start, size, Proc.Procp); return; + case CHECK: + { + CheckJob job; + if (BaseMachine::thread_num == 0) + BaseMachine::s().queues.distribute(job, 0); + Proc.check(); + if (BaseMachine::thread_num == 0) + BaseMachine::s().queues.wrap_up(job); + return; + } case JMP: Proc.PC += (signed int) n; break; @@ -1130,7 +1118,7 @@ inline void Instruction::execute(Processor& Proc) const { octetStream os; os.store(int(sint::open_type::type_char())); - sint::open_type::specification(os); + sint::specification(os); os.Send(Proc.external_clients.get_socket(client_handle)); } Proc.write_Ci(r[0], client_handle); @@ -1194,6 +1182,9 @@ inline void Instruction::execute(Processor& Proc) const case GPREP: Proc2.DataF.get(Proc.Proc2.get_S(), r, start, size); return; + case CISC: + Procp.protocol.cisc(Procp, *this); + return; default: printf("Case of opcode=0x%x not implemented yet\n",opcode); throw invalid_opcode(opcode); diff --git a/Processor/IntInput.hpp b/Processor/IntInput.hpp index 97bc7c0b5..5f0a745dd 100644 --- a/Processor/IntInput.hpp +++ b/Processor/IntInput.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROCESSOR_INTINPUT_HPP_ +#define PROCESSOR_INTINPUT_HPP_ + #include "IntInput.h" template @@ -13,3 +16,5 @@ void IntInput::read(std::istream& in, const int*) { in >> items[0]; } + +#endif diff --git a/Processor/Machine.h b/Processor/Machine.h index 4a68f798c..2cd0c664b 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -95,9 +95,6 @@ class Machine : public BaseMachine template string prep_dir_prefix(); - // Only for Player-Demo.cpp - Machine(Names& N = *(new Names())): N(N) {} - void reqbl(int n); typename sint::bit_type::mac_key_type get_bit_mac_key() { return alphabi; } diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 17da8a026..d7a1fb10d 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -62,7 +62,7 @@ Machine::Machine(int my_number, Names& playerNames, sint::read_or_generate_mac_key(prep_dir_prefix(), *P, alphapi); sgf2n::read_or_generate_mac_key(prep_dir_prefix(), *P, alpha2i); sint::bit_type::part_type::read_or_generate_mac_key( - prep_dir_prefix(), *P, alphabi); + prep_dir_prefix(), *P, alphabi); #ifdef DEBUG_MAC cerr << "MAC Key p = " << alphapi << endl; @@ -411,10 +411,7 @@ template template string Machine::prep_dir_prefix() { - int lgp = opts.lgp; - if (opts.prime) - lgp = numBits(opts.prime); - return get_prep_sub_dir(PREP_DIR, N.num_players(), lgp); + return opts.prep_dir_prefix(N.num_players()); } template diff --git a/Processor/NoLivePrep.h b/Processor/NoLivePrep.h deleted file mode 100644 index 514283535..000000000 --- a/Processor/NoLivePrep.h +++ /dev/null @@ -1,41 +0,0 @@ -/* - * NoLivePrep.h - * - */ - -#ifndef PROCESSOR_NOLIVEPREP_H_ -#define PROCESSOR_NOLIVEPREP_H_ - -#include "Tools/Exceptions.h" -#include "Data_Files.h" - -template class SubProcessor; -class DataPositions; - -template -class NoLivePrep : public Sub_Data_Files -{ -public: - static void basic_setup(Player&) - { - } - static void teardown() - { - } - - NoLivePrep(SubProcessor* proc, DataPositions& usage) : Sub_Data_Files(0, 0, "", usage, 0) - { - (void) proc; - throw not_implemented(); - } - template - NoLivePrep(DataPositions& usage, U& _) : NoLivePrep(0, usage) - { - (void) _; - } - NoLivePrep(DataPositions& usage) : NoLivePrep(0, usage) - { - } -}; - -#endif /* PROCESSOR_NOLIVEPREP_H_ */ diff --git a/Processor/NoProtocol.h b/Processor/NoProtocol.h deleted file mode 100644 index 036be1d70..000000000 --- a/Processor/NoProtocol.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * NoProtocol.h - * - */ - -#ifndef PROCESSOR_NOPROTOCOL_H_ -#define PROCESSOR_NOPROTOCOL_H_ - -#include "Protocols/Replicated.h" - -template -class NoProtocol : public ProtocolBase -{ -public: - NoProtocol(Player&) - { - - } - - void init_mul(SubProcessor*) - { - throw not_implemented(); - } - typename T::clear prepare_mul(const T&, const T&, int n = -1) - { - (void) n; - throw not_implemented(); - } - void exchange() - { - throw not_implemented(); - } - T finalize_mul(int n = -1) - { - (void) n; - throw not_implemented(); - } -}; - - - -#endif /* PROCESSOR_NOPROTOCOL_H_ */ diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index 0e33de641..5275e69a3 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -37,7 +37,8 @@ int OfflineMachine::run() T::clear::init_default(this->online_opts.prime_length()); U::clear::init_field(U::clear::default_degree()); T::bit_type::mac_key_type::init_field(); - auto binary_mac_key = read_generate_write_mac_key(P); + auto binary_mac_key = read_generate_write_mac_key< + typename T::bit_type::part_type>(P); GC::ShareThread thread(playerNames, OnlineOptions::singleton, P, binary_mac_key, usage); @@ -63,12 +64,13 @@ void OfflineMachine::generate() typename T::LivePrep preprocessing(0, generated); SubProcessor processor(output, preprocessing, P); - auto& domain_usage = usage.files[T::field_type()]; + auto& domain_usage = usage.files[T::clear::field_type()]; for (unsigned i = 0; i < domain_usage.size(); i++) { auto my_usage = domain_usage[i]; Dtype dtype = Dtype(i); - string filename = Sub_Data_Files::get_filename(playerNames, dtype); + string filename = Sub_Data_Files::get_filename(playerNames, dtype, + T::clear::field_type() == DATA_GF2 ? 0 : -1); if (my_usage > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -101,7 +103,7 @@ void OfflineMachine::generate() for (int i = 0; i < P.num_players(); i++) { - auto n_inputs = usage.inputs[i][T::field_type()]; + auto n_inputs = usage.inputs[i][T::clear::field_type()]; string filename = Sub_Data_Files::get_input_filename(playerNames, i); if (n_inputs > 0) { @@ -120,7 +122,7 @@ void OfflineMachine::generate() remove(filename.c_str()); } - if (T::field_type() == DATA_INT) + if (T::clear::field_type() == DATA_INT) { int max_n_bits = 0; for (auto& x : usage.edabits) @@ -132,7 +134,7 @@ void OfflineMachine::generate() int total = usage.edabits[{false, n_bits}] + usage.edabits[{true, n_bits}]; string filename = Sub_Data_Files::get_edabit_filename(playerNames, - n_bits, P.my_num()); + n_bits); if (total > 0) { ofstream out(filename, ios::binary); diff --git a/Processor/Online-Thread.h b/Processor/Online-Thread.h index dc05c31be..b0965ae0d 100644 --- a/Processor/Online-Thread.h +++ b/Processor/Online-Thread.h @@ -26,7 +26,7 @@ class thread_info static void* Main_Func(void *ptr); - static void purge_preprocessing(Machine& machine); + static void purge_preprocessing(const Names& N); template static void print_usage(ostream& o, const vector& regs, diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index 466d1b336..74fc0a2d3 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -8,6 +8,7 @@ #include "Networking/CryptoPlayer.h" #include "Protocols/ShuffleSacrifice.h" #include "Protocols/LimitedPrep.h" +#include "FHE/FFT.h" #include "Processor/Processor.hpp" #include "Processor/Instruction.hpp" @@ -95,7 +96,7 @@ void thread_info::Sub_Main_Func() } // Allocate memory for first program before starting the clock - processor = new Processor(tinfo->thread_num,P,*MC2,*MCp,machine,progs[0]); + processor = new Processor(tinfo->thread_num,P,*MC2,*MCp,machine,progs.at(thread_num > 0)); auto& Proc = *processor; bool flag=true; @@ -224,6 +225,19 @@ void thread_info::Sub_Main_Func() job.end); queues->finished(job); } + else if (job.type == CHECK_JOB) + { + Proc.check(); + queues->finished(job); + } + else if (job.type == FFT_JOB) + { + for (int i = job.begin; i < job.end; i++) + FFT_Iter2_body(*(vector*) job.output, + *(vector*) job.input, i, job.length, + *(Zp_Data*) job.supply); + queues->finished(job); + } else { // RUN PROGRAM #ifdef DEBUG_THREADS @@ -240,6 +254,9 @@ void thread_info::Sub_Main_Func() // Execute the program progs[program].execute(Proc); + // prevent mangled output + cout.flush(); + actual_usage.increase(Proc.DataF.get_usage()); if (progs[program].usage_unknown()) @@ -256,18 +273,8 @@ void thread_info::Sub_Main_Func() } } - // protocol check before last MAC check - Proc.Procp.protocol.check(); - Proc.Proc2.protocol.check(); - - // MACCheck - MC2->Check(P); - MCp->Check(P); - Proc.share_thread.MC->Check(P); - - //cout << num << " : Checking broadcast" << endl; - P.Check_Broadcast(); - //cout << num << " : Broadcast checked "<< endl; + // final check + Proc.check(); wait_timer.start(); queues->next(); @@ -320,31 +327,35 @@ void thread_info::Sub_Main_Func() template void* thread_info::Main_Func(void* ptr) { -#ifndef INSECURE - try -#endif - { - ((thread_info*)ptr)->Sub_Main_Func(); - } -#ifndef INSECURE - catch (...) - { + auto& ti = *(thread_info*)(ptr); +#ifdef INSECURE + ti.Sub_Main_Func(); +#else + if (ti.machine->opts.live_prep) + ti.Sub_Main_Func(); + else + try + { + ti.Sub_Main_Func(); + } + catch (...) + { thread_info* ti = (thread_info*)ptr; - ti->purge_preprocessing(*ti->machine); + ti->purge_preprocessing(ti->machine->get_N()); throw; - } + } #endif return 0; } template -void thread_info::purge_preprocessing(Machine& machine) +void thread_info::purge_preprocessing(const Names& N) { cerr << "Purging preprocessed data because something is wrong" << endl; try { - Data_Files df(machine); + Data_Files df(N); df.purge(); } catch(...) diff --git a/Processor/OnlineMachine.h b/Processor/OnlineMachine.h index 352884a14..e0a70a4b5 100644 --- a/Processor/OnlineMachine.h +++ b/Processor/OnlineMachine.h @@ -28,13 +28,13 @@ class OnlineMachine int nplayers; - void start_networking(); - public: template OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers = 0, V = {}); + void start_networking(); + template int run(); diff --git a/Processor/OnlineMachine.hpp b/Processor/OnlineMachine.hpp index 6b72b34f3..049388e8f 100644 --- a/Processor/OnlineMachine.hpp +++ b/Processor/OnlineMachine.hpp @@ -249,9 +249,8 @@ int OnlineMachine::run() #ifndef INSECURE catch(...) { - Machine machine(playerNames); - machine.live_prep = false; - thread_info::purge_preprocessing(machine); + if (not online_opts.live_prep) + thread_info::purge_preprocessing(playerNames); throw; } #endif diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index 372e9b4ca..8b34adf0f 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -8,6 +8,7 @@ #include "Tools/ezOptionParser.h" #include "Math/bigint.h" +#include "Math/Setup.h" class OnlineOptions { @@ -32,10 +33,21 @@ class OnlineOptions OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, bool variable_prime_length = false); + ~OnlineOptions() {} + void finalize(ez::ezOptionParser& opt, int argc, const char** argv); int prime_length(); int prime_limbs(); + + template + string prep_dir_prefix(int nplayers) + { + int lgp = this->lgp; + if (prime) + lgp = numBits(prime); + return get_prep_sub_dir(PREP_DIR, nplayers, lgp); + } }; #endif /* PROCESSOR_ONLINEOPTIONS_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index dc152486a..fa0a5fe18 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -213,6 +213,8 @@ class Processor : public ArithmeticProcessor void write_Sp(int i,const sint & x) { Procp.S[i]=x; } + void check(); + void dabit(const Instruction& instruction); void edabit(const Instruction& instruction, bool strict = false); diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index bb751975d..e88fe8795 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -146,6 +146,24 @@ void Processor::reset(const Program& program,int arg) Procb.reset(program); } +template +void Processor::check() +{ + // protocol check before last MAC check + Procp.protocol.check(); + Proc2.protocol.check(); + share_thread.protocol->check(); + + // MACCheck + MC2.Check(P); + MCp.Check(P); + share_thread.MC->Check(P); + + //cout << num << " : Checking broadcast" << endl; + P.Check_Broadcast(); + //cout << num << " : Broadcast checked "<< endl; +} + template void Processor::dabit(const Instruction& instruction) { @@ -555,49 +573,77 @@ void SubProcessor::conv2ds(const Instruction& instruction) int n_channels_in = args[8]; int padding_h = args[9]; int padding_w = args[10]; - int r0 = instruction.get_r(0); - int r1 = instruction.get_r(1); + int batch_size = args[11]; + size_t r0 = instruction.get_r(0); + size_t r1 = instruction.get_r(1); int r2 = instruction.get_r(2); - int lengths[output_h][output_w]; + int lengths[batch_size][output_h][output_w]; memset(lengths, 0, sizeof(lengths)); + int filter_stride_h = 1; + int filter_stride_w = 1; + if (stride_h < 0) + { + filter_stride_h = -stride_h; + stride_h = 1; + } + if (stride_w < 0) + { + filter_stride_w = -stride_w; + stride_w = 1; + } - for (int out_y = 0; out_y < output_h; out_y++) - for (int out_x = 0; out_x < output_w; out_x++) - { - int in_x_origin = (out_x * stride_w) - padding_w; - int in_y_origin = (out_y * stride_h) - padding_h; - - for (int filter_y = 0; filter_y < weights_h; filter_y++) + for (int i_batch = 0; i_batch < batch_size; i_batch ++) + { + size_t base = r1 + i_batch * inputs_w * inputs_h * n_channels_in; + assert(base + inputs_w * inputs_h * n_channels_in <= S.size()); + T* input_base = &S[base]; + for (int out_y = 0; out_y < output_h; out_y++) + for (int out_x = 0; out_x < output_w; out_x++) { - int in_y = in_y_origin + filter_y; - if ((0 <= in_y) and (in_y < inputs_h)) - for (int filter_x = 0; filter_x < weights_w; filter_x++) - { - int in_x = in_x_origin + filter_x; - if ((0 <= in_x) and (in_x < inputs_w)) + int in_x_origin = (out_x * stride_w) - padding_w; + int in_y_origin = (out_y * stride_h) - padding_h; + + for (int filter_y = 0; filter_y < weights_h; filter_y++) + { + int in_y = in_y_origin + filter_y * filter_stride_h; + if ((0 <= in_y) and (in_y < inputs_h)) + for (int filter_x = 0; filter_x < weights_w; filter_x++) { - for (int in_c = 0; in_c < n_channels_in; in_c++) - protocol.prepare_dotprod( - S[r1 + (in_y * inputs_w + in_x) * - n_channels_in + in_c], - S[r2 + (filter_y * weights_w + filter_x) * - n_channels_in + in_c]); - lengths[out_y][out_x] += n_channels_in; + int in_x = in_x_origin + filter_x * filter_stride_w; + if ((0 <= in_x) and (in_x < inputs_w)) + { + T* pixel_base = &input_base[(in_y * inputs_w + + in_x) * n_channels_in]; + T* weight_base = &S[r2 + + (filter_y * weights_w + filter_x) + * n_channels_in]; + for (int in_c = 0; in_c < n_channels_in; in_c++) + protocol.prepare_dotprod(pixel_base[in_c], + weight_base[in_c]); + lengths[i_batch][out_y][out_x] += n_channels_in; + } } - } - } + } - protocol.next_dotprod(); - } + protocol.next_dotprod(); + } + } protocol.exchange(); - for (int out_y = 0; out_y < output_h; out_y++) - for (int out_x = 0; out_x < output_w; out_x++) - { - S[r0 + out_y * output_w + out_x] = protocol.finalize_dotprod( - lengths[out_y][out_x]); - } + for (int i_batch = 0; i_batch < batch_size; i_batch ++) + { + size_t base = r0 + i_batch * output_h * output_w; + assert(base + output_h * output_w <= S.size()); + T* output_base = &S[base]; + for (int out_y = 0; out_y < output_h; out_y++) + for (int out_x = 0; out_x < output_w; out_x++) + { + output_base[out_y * output_w + out_x] = + protocol.finalize_dotprod( + lengths[i_batch][out_y][out_x]); + } + } } template diff --git a/Processor/Program.h b/Processor/Program.h index 727f90709..55c9b8c45 100644 --- a/Processor/Program.h +++ b/Processor/Program.h @@ -43,7 +43,7 @@ class Program bool usage_unknown() const { return unknown_usage; } - int num_reg(RegType reg_type) const + unsigned num_reg(RegType reg_type) const { return max_reg[reg_type]; } unsigned direct_mem(RegType reg_type) const diff --git a/Processor/ThreadJob.h b/Processor/ThreadJob.h index 65199a13e..fe4f0b6c6 100644 --- a/Processor/ThreadJob.h +++ b/Processor/ThreadJob.h @@ -7,6 +7,7 @@ #define PROCESSOR_THREADJOB_H_ #include "Data_Files.h" +#include "Math/modp.h" enum ThreadJobType { @@ -20,6 +21,8 @@ enum ThreadJobType EDABIT_SACRIFICE_JOB, PERSONAL_TRIPLE_JOB, TRIPLE_SACRIFICE_JOB, + CHECK_JOB, + FFT_JOB, NO_JOB }; @@ -159,4 +162,26 @@ class TripleSacrificeJob : public ThreadJob } }; +class CheckJob : public ThreadJob +{ +public: + CheckJob() + { + type = CHECK_JOB; + } +}; + +class FftJob : public ThreadJob +{ +public: + FftJob(vector& ioput, vector& alpha2, int m, const Zp_Data& PrD) + { + type = FFT_JOB; + output = &ioput; + input = &alpha2; + length = m; + supply = &PrD; + } +}; + #endif /* PROCESSOR_THREADJOB_H_ */ diff --git a/Programs/Source/mnist_B.mpc b/Programs/Source/mnist_B.mpc new file mode 100644 index 000000000..a62be674f --- /dev/null +++ b/Programs/Source/mnist_B.mpc @@ -0,0 +1,73 @@ +import ml +import math + +#ml.report_progress = True + +program.options_from_args() + +approx = 3 + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 1000 +elif 'debug' in program.args: + N = 10 + n_test = 10 +elif 'debug20' in program.args: + N = 20 + n_test = 20 +elif 'debug100' in program.args: + N = 100 + n_test = 100 +elif 'gisette' in program.args: + print('Compiling for 4/9') + N = 11791 + n_test = 1991 +else: + N = 12665 + n_test = 2115 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +assert batch_size <= N + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +layers = [ + ml.FixConv2d([N, 28, 28, 1], (16, 5, 5, 1), (16,), [N, 24, 24, 16], (1, 1)), + ml.MaxPool([N, 24, 24, 16]), + ml.Relu([N, 12, 12, 16]), + ml.FixConv2d([N, 12, 12, 16], (16, 5, 5, 16), (16,), [N, 8, 8, 16], (1, 1)), + ml.MaxPool([N, 8, 8, 16]), + ml.Relu([N, 4, 4, 16]), + ml.Dense(N, 256, 100), + ml.Relu([N, 100]), + ml.Dense(N, 100, 1), + ml.Output(N) +] + +layers[-1].Y.input_from(0) +layers[0].X.input_from(0) + +Y = sint.Array(n_test) +X = sfix.Matrix(n_test, n_features) +Y.input_from(0) +X.input_from(0) + +sgd = ml.SGD(layers, 1, report_loss=True) +sgd.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/mnist_D.mpc b/Programs/Source/mnist_D.mpc new file mode 100644 index 000000000..49f8f06fd --- /dev/null +++ b/Programs/Source/mnist_D.mpc @@ -0,0 +1,60 @@ +import ml +import math + +#ml.report_progress = True + +program.options_from_args() + +approx = 3 + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 1000 +elif 'debug' in program.args: + N = 10 + n_test = 10 +elif 'gisette' in program.args: + print('Compiling for 4/9') + N = 11791 + n_test = 1991 +else: + N = 12665 + n_test = 2115 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +assert batch_size <= N + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +layers = [ + ml.FixConv2d([N, 28, 28, 1], (5, 5, 5, 1), (5,), [N, 14, 14, 5], (2, 2)), + ml.Relu([N, 14, 14, 5]), + ml.Dense(N, 980, 1), + ml.Output(N, approx=approx)] + +layers[-1].Y.input_from(0) +layers[0].X.input_from(0) + +Y = sint.Array(n_test) +X = sfix.Matrix(n_test, n_features) +Y.input_from(0) +X.input_from(0) + +sgd = ml.SGD(layers, 1, report_loss=True) +sgd.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index b40219d14..4a8065df0 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -6,6 +6,7 @@ import util #ml.report_progress = True program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) if 'profile' in program.args: print('Compiling for profiling') @@ -14,6 +15,8 @@ if 'profile' in program.args: elif 'debug' in program.args: N = 100 n_test = 100 +elif 'debug5000' in program.args: + N = n_test = 5000 else: N = 60000 n_test = 10000 @@ -31,7 +34,8 @@ try: except: batch_size = N -assert batch_size <= N +N = min(N, 10000) +ml.Layer.back_batch_size = batch_size try: ml.set_n_threads(int(program.args[3])) @@ -40,6 +44,9 @@ except: n_inner = 128 +if 'fc512' in program.args: + n_inner = 512 + if 'norelu' in program.args: activation = 'id' else: @@ -48,35 +55,32 @@ else: if 'nearest' in program.args: sfix.round_nearest = True -if 'double' in program.args: - sfix.set_precision(32, 63) - cfix.set_precision(32, 63) -elif 'triple' in program.args: - sfix.set_precision(48, 91) - cfix.set_precision(48, 91) -elif 'quadruple' in program.args: - sfix.set_precision(64, 127) - cfix.set_precision(64, 127) -elif 'sextuple' in program.args: - sfix.set_precision(96, 191) - cfix.set_precision(96, 191) -elif 'octuple' in program.args: - sfix.set_precision(128, 255) - cfix.set_precision(128, 255) - if program.options.ring: assert sfix.f * 4 == int(program.options.ring) debug_ml = ('debug_ml' in program.args) * 2 ** (sfix.f / 2) if '1dense' in program.args: - layers = [ml.Dense(N, n_features, 10, debug=debug_ml)] + layers = [ml.Dense(n_examples, n_features, 10, debug=debug_ml)] else: - layers = [ml.Dense(N, n_features, n_inner, activation=activation, debug=debug_ml), + layers = [ml.Dense(n_examples, n_features, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, n_inner, activation=activation, debug=debug_ml), ml.Dense(N, n_inner, 10, debug=debug_ml)] -layers += [ml.MultiOutput.from_args(program, N, 10)] +if 'dropout' in program.args: + for i in range(len(layers) - 1, 0, -1): + layers.insert(i, ml.Dropout(N, n_inner)) + +if 'dropout-late' in program.args: + layers.insert(-1, ml.Dropout(N, n_inner)) + +if 'dropout-early' in program.args: + layers.insert(0, ml.Dropout(n_examples, n_features)) + +if 'dropout-early.25' in program.args: + layers.insert(0, ml.Dropout(n_examples, n_features, alpha=.25)) + +layers += [ml.MultiOutput.from_args(program, n_examples, 10)] layers[-1].cheaper_loss = 'mse' in program.args @@ -94,14 +98,8 @@ if not ('no_acc' in program.args and 'no_loss' in program.args): Y.input_from(0) X.input_from(0) -if 'always_acc' in program.args: - n_part_epochs = 1 -else: - n_part_epochs = 10 - -sgd = ml.SGD(layers, n_part_epochs, report_loss=True, debug=debug_ml) +sgd = ml.Optimizer.from_args(program, layers) #sgd.print_update_average = True -sgd.print_losses = 'print_losses' in program.args if 'faster' in program.args: sgd.gamma = MemValue(cfix(.1)) @@ -109,5 +107,5 @@ if 'faster' in program.args: if 'slower' in program.args: sgd.gamma = MemValue(cfix(.001)) -sgd.run_by_args(program, int(math.ceil(n_epochs / n_part_epochs)), batch_size, - X, Y) +sgd.run_by_args(program, n_epochs, batch_size, + X, Y, acc_batch_size=N) diff --git a/Programs/Source/mnist_full_B.mpc b/Programs/Source/mnist_full_B.mpc new file mode 100644 index 000000000..334072b43 --- /dev/null +++ b/Programs/Source/mnist_full_B.mpc @@ -0,0 +1,73 @@ +import ml +import math +import re +import util + +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 100 +elif 'debug' in program.args: + N = 100 + n_test = 100 +elif 'debug1000' in program.args: + N = 1000 + n_test = 1000 +elif 'debug5000' in program.args: + N = 5000 + n_test = 5000 +else: + N = 60000 + n_test = 10000 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +if 'savemem' in program.args: + N = batch_size +else: + N = min(N, 1000) + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +layers = [ + ml.FixConv2d([n_examples, 28, 28, 1], (16, 5, 5, 1), (16,), [n_examples, 24, 24, 16], + (1, 1), 'VALID'), + ml.MaxPool([N, 24, 24, 16]), + ml.Relu([N, 12, 12, 16]), + ml.FixConv2d([N, 12, 12, 16], (16, 5, 5, 16), (16,), [N, 8, 8, 16], (1, 1), 'VALID'), + ml.MaxPool([N, 8, 8, 16]), + ml.Relu([N, 4, 4, 16]), + ml.Dense(N, 256, 100), + ml.Relu([N, 100]), + ml.Dense(N, 100, 10), + ml.MultiOutput(n_examples, 10) +] + +Y = sint.Matrix(n_test, 10) +X = sfix.Matrix(n_test, n_features) + +if not ('no_acc' in program.args and 'no_loss' in program.args): + layers[-1].Y.input_from(0) + layers[0].X.input_from(0) + Y.input_from(0) + X.input_from(0) + +optim = ml.Optimizer.from_args(program, layers) +optim.run_by_args(program, n_epochs, batch_size, X, Y, acc_batch_size=N) diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc new file mode 100644 index 000000000..fb6ad55f6 --- /dev/null +++ b/Programs/Source/mnist_full_C.mpc @@ -0,0 +1,88 @@ +import ml +import math +import re +import util + +program.options_from_args() +sfix.set_precision_from_args(program, adapt_ring=True) + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 100 +elif 'debug' in program.args: + N = 100 + n_test = 100 +elif 'debug1000' in program.args: + N = 1000 + n_test = 1000 +elif 'debug5000' in program.args: + N = 5000 + n_test = 5000 +else: + N = 60000 + n_test = 10000 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +if 'savemem' in program.args: + N = batch_size +else: + N = min(N, max(1000, batch_size)) + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +ml.Layer.back_batch_size = batch_size + +layers = [ + ml.FixConv2d([n_examples, 28, 28, 1], (20, 5, 5, 1), (20,), [n_examples, 24, 24, 20], (1, 1), 'VALID'), + ml.MaxPool([N, 24, 24, 20]), + ml.Relu([N, 12, 12, 20]), + ml.FixConv2d([N, 12, 12, 20], (50, 5, 5, 20), (50,), [N, 8, 8, 50], (1, 1), 'VALID'), + ml.MaxPool([N, 8, 8, 50]), + ml.Relu([N, 4, 4, 50]), + ml.Dense(N, 800, 500), + ml.Relu([N, 500]), + ml.Dense(N, 500, 10), + ml.MultiOutput(n_examples, 10) +] + +if 'dropout' in program.args or 'dropout2' in program.args: + layers.insert(8, ml.Dropout(N, 500)) +elif 'dropout.25' in program.args: + layers.insert(8, ml.Dropout(N, 500, alpha=0.25)) +elif 'dropout.125' in program.args: + layers.insert(8, ml.Dropout(N, 500, alpha=0.125)) + +if 'dropout2' in program.args: + layers.insert(6, ml.Dropout(N, 800, alpha=0.125)) +elif 'dropout1' in program.args: + layers.insert(6, ml.Dropout(N, 800, alpha=0.5)) + +print(layers) + +Y = sint.Matrix(n_test, 10) +X = sfix.Matrix(n_test, n_features) + +if not ('no_acc' in program.args and 'no_loss' in program.args): + layers[-1].Y.input_from(0) + layers[0].X.input_from(0) + Y.input_from(0) + X.input_from(0) + +optim = ml.Optimizer.from_args(program, layers) +optim.run_by_args(program, n_epochs, batch_size, X, Y, acc_batch_size=N) diff --git a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc new file mode 100644 index 000000000..f250de3ae --- /dev/null +++ b/Programs/Source/mnist_full_D.mpc @@ -0,0 +1,105 @@ +import ml +import math +import re +import util + +program.options_from_args() +sfix.set_precision_from_args(program, True) + +if 'profile' in program.args: + print('Compiling for profiling') + N = 1000 + n_test = 100 +elif 'debug' in program.args: + N = 100 + n_test = 100 +elif 'debug1000' in program.args: + N = 1000 + n_test = 1000 +elif 'debug5000' in program.args: + N = 5000 + n_test = 5000 +else: + N = 60000 + n_test = 10000 + +n_examples = N +n_features = 28 ** 2 + +try: + n_epochs = int(program.args[1]) +except: + n_epochs = 100 + +try: + batch_size = int(program.args[2]) +except: + batch_size = N + +assert batch_size <= N +ml.Layer.back_batch_size = batch_size + +try: + ml.set_n_threads(int(program.args[3])) +except: + pass + +if program.options.ring: + assert sfix.f * 4 == int(program.options.ring) + +if 'stride1' in program.args: + stride = (1, 1) +else: + stride = (2, 2) + +if 'valid' in program.args: + padding = 'VALID' + inner_dim = (28 - 4) // stride[0] +else: + padding = 'SAME' + inner_dim = 28 // stride[0] + +layers = [ + ml.FixConv2d([N, 28, 28, 1], (5, 5, 5, 1), (5,), + [N, inner_dim, inner_dim, 5], stride, padding), + ml.Relu([N, inner_dim, inner_dim, 5]), +] + +if 'maxpool' in program.args: + layers += [ml.MaxPool((N, inner_dim, inner_dim, 5))] + inner_dim //= 2 + +n_inner = inner_dim ** 2 * 5 + +dropout = 'dropout' in program.args + +if '1dense' in program.args: + if dropout: + layers += [ml.Dropout(N, n_inner)] + layers += [ml.Dense(N, n_inner, 10),] +elif '2dense' in program.args: + if dropout: + layers += [ml.Dropout(N, n_inner)] + layers += [ + ml.Dense(N, n_inner, 100), + ml.Relu([N, 100]), + ml.Dense(N, 100, 10), + ] + if dropout or 'dropout1' in program.args: + layers.insert(-1, ml.Dropout(N, 100)) +else: + raise Exception('need to specify number of dense layers') + +layers += [ml.MultiOutput(N, 10)] + +Y = sint.Matrix(n_test, 10) +X = sfix.Matrix(n_test, n_features) + +if not ('no_acc' in program.args and 'no_loss' in program.args): + layers[-1].Y.input_from(0) + layers[0].X.input_from(0) + Y.input_from(0) + X.input_from(0) + +optim = ml.Optimizer.from_args(program, layers) +optim.run_by_args(program, n_epochs, batch_size, X, Y) diff --git a/Programs/Source/prep_aes.mpc b/Programs/Source/prep_aes.mpc index 6f69c149f..35103da10 100644 --- a/Programs/Source/prep_aes.mpc +++ b/Programs/Source/prep_aes.mpc @@ -80,7 +80,7 @@ def ApplyBDEmbedding(x): def PreprocInverseEmbedding(x): in_bytes = x.bit_decompose(step=5) - out_bytes = [cgf2(0) for _ in range(8)] + out_bytes = [cgf2n(0) for _ in range(8)] out_bytes[7] = in_bytes[7] out_bytes[6] = in_bytes[6] + out_bytes[7] @@ -337,7 +337,7 @@ class SpdzBox(object): [0,1,0,0,1,0,1,0] ] to_add = [1,0,1,0,0,0,0,0] - self.addition_inv = [cgf2(_) for _ in to_add] + self.addition_inv = [cgf2n(_) for _ in to_add] self.forward_matrix = [ [1,0,0,0,1,1,1,1], [1,1,0,0,0,1,1,1], @@ -351,7 +351,7 @@ class SpdzBox(object): forward_add = [1,1,0,0,0,1,1,0] self.forward_add = Array(len(forward_add), cgf2) for i,x in enumerate(forward_add): - self.forward_add[i] = cgf2(x) + self.forward_add[i] = cgf2n(x) def __init__(self): constants = [ @@ -369,7 +369,7 @@ class SpdzBox(object): linear_transform = list() for row in self.forward_matrix: - result = cgf2(0) + result = cgf2n(0) for idx in range(len(row)): result = result + unembedded_x[idx] * row[idx] linear_transform.append(result) @@ -395,7 +395,7 @@ class SpdzBox(object): linear_transform = list() for row in self.matrix_inv: - result = cgf2(0) + result = cgf2n(0) for idx in range(len(row)): result = result + what_inv_bd[idx] * row[idx] linear_transform.append(result) diff --git a/Programs/Source/test_gc.mpc b/Programs/Source/test_gc.mpc index 130bb9689..268971d6f 100644 --- a/Programs/Source/test_gc.mpc +++ b/Programs/Source/test_gc.mpc @@ -9,7 +9,7 @@ test(sbits(3) - sbits(5), 3 ^ 5) test(sbit(1) * sbits(3), 3) #test(cbits(1) * cbits(3), 3) test(sbit(1) * 3, 3) -test(~sbits(1, n=64), 2**64 - 2) +test(~sbits.new(1, n=64), 2**64 - 2) test(sbits(5) & sbits(3), 5 & 3) test(sbits(3).equal(sbits(3)), 1) @@ -23,7 +23,7 @@ test(sbit(1).if_else(1, 2), 1) test(sbit(0).if_else(2, 1), 1) test(sbit(1).if_else(2, 1), 2) -test(sbits.compose((sbits(2, n=2), sbits(1, n=2)), 2), 6) +test(sbits.compose((sbits.new(2, n=2), sbits.new(1, n=2)), 2), 6) x = MemValue(sbits(1234)) program.curr_tape.start_new_basicblock() @@ -41,15 +41,15 @@ cbits(456).store_in_mem(1234) program.curr_tape.start_new_basicblock() test(cbits.load_mem(1234), 456) -test(sbits(1 << 63, n=64), 1 << 63) +test(sbits.new(1 << 63, n=64), 1 << 63) bits = sbits(0x1234, n=40).bit_decompose(40) test(sbits.bit_compose(bits), 0x1234) -test(sbits(5, n=4) ^ sbits(3, n=3), 6) -test(sbits(5, n=3) ^ sbits(3, n=4), 6) -test(sbits(13, n=4) ^ sbits(3, n=3), 14) -test(sbits(5, n=3) ^ sbits(11, n=4), 14) +test(sbits.new(5, n=4) ^ sbits.new(3, n=3), 6) +test(sbits.new(5, n=3) ^ sbits.new(3, n=4), 6) +test(sbits.new(13, n=4) ^ sbits.new(3, n=3), 14) +test(sbits.new(5, n=3) ^ sbits.new(11, n=4), 14) b = sbits.get_random_bit() test(b * (1 - b), 0) @@ -66,32 +66,32 @@ test(x, 0xa) test(y, 0xc) aa = [1, 2**63, 2**64 - 1] -a = sbitvec(sbits(x, n=64) for x in aa).elements() +a = sbitvec(sbits.new(x, n=64) for x in aa).elements() test(a[0], aa[0]) test(a[1], aa[1]) test(a[2], aa[2]) -a = sbitvec(sbits(x, n=64) for x in [1, 2**63, 2**64 - 1]).popcnt().elements() +a = sbitvec(sbits.new(x, n=64) for x in [1, 2**63, 2**64 - 1]).popcnt().elements() test(a[0], 1) test(a[1], 1) test(a[2], 64) -a = sbits(-1, n=64) +a = sbits.new(-1, n=64) test(a & a, 2**64 - 1) sbits.n = 64 -a = sbitvec(64 * [sbits(2**64 - 1, n=64)]).popcnt().elements() +a = sbitvec(64 * [sbits.new(2**64 - 1, n=64)]).popcnt().elements() test(a[0], 64) test(a[63], 64) -a = sbitintvec(sbits(x, n=64) for x in [2**63 - 1, 1]) -b = sbitintvec(sbits(x, n=64) for x in [1, -1]) +a = sbitintvec(sbits.new(x, n=64) for x in [2**63 - 1, 1]) +b = sbitintvec(sbits.new(x, n=64) for x in [1, -1]) c = (a + b).elements() test(c[0], 2**63) test(c[1], 0) -a = sbitintvec(sbits(x, n=64) for x in [1, 1, 2**63 - 1, 2**63]) -b = sbitintvec(sbits(x, n=64) for x in [1, 2, 2**63, 2**63 - 1]) +a = sbitintvec(sbits.new(x, n=64) for x in [1, 1, 2**63 - 1, 2**63]) +b = sbitintvec(sbits.new(x, n=64) for x in [1, 2, 2**63, 2**63 - 1]) c = sbitvec([a.less_than(b)]).v test(c[0], 0) test(c[1], 1) diff --git a/Programs/Source/tutorial.mpc b/Programs/Source/tutorial.mpc index c18a63a18..9d99c37da 100644 --- a/Programs/Source/tutorial.mpc +++ b/Programs/Source/tutorial.mpc @@ -27,9 +27,9 @@ test(a - b, -1) # Division can mean different things in different domains # and there has be a specified bit length in some, # so we use int_div() for integer division. -# k-bit division requires 4k-bit computation. +# k-bit division requires (4k+1)-bit computation. -test(b.int_div(a, 16), 2) +test(b.int_div(a, 15), 2) # comparisons produce 1 for true and 0 for false @@ -65,7 +65,7 @@ test(a[99], 99 * 98) # set the precision after the dot and in total -sfix.set_precision(16, 32) +sfix.set_precision(16, 31) # and the output precision in decimal digits diff --git a/Programs/Source/vickrey.mpc b/Programs/Source/vickrey.mpc index 2e7e0c7ea..c63a314c2 100644 --- a/Programs/Source/vickrey.mpc +++ b/Programs/Source/vickrey.mpc @@ -31,7 +31,7 @@ def f(_): for i in range(n_inputs): # i * 10 because inputs are all zero by default - bids[i] = Bid(i, value_type.get_raw_input_from(i % n_parties) + i * 10) + bids[i] = Bid(i, value_type.get_input_from(i % n_parties) + i * 10) #bids = [Bid(i, value_type(i * 10)) for i in range(n_parties)] def bid_sort(a, b): diff --git a/Protocols/FakeInput.h b/Protocols/FakeInput.h index ea3d118ea..32bfc3598 100644 --- a/Protocols/FakeInput.h +++ b/Protocols/FakeInput.h @@ -33,7 +33,7 @@ class FakeInput : public InputBase results.push_back(x); } - void add_other(int) + void add_other(int, int = -1) { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 83379c7c4..0c5fe9697 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -8,6 +8,7 @@ #include "Replicated.h" #include "Math/Z2k.h" +#include "Processor/Instruction.h" template class FakeProtocol : public ProtocolBase @@ -157,12 +158,53 @@ class FakeProtocol : public ProtocolBase res = ((source + r) >> n_shift) - (r >> n_shift); #else T r; - r.randomize_part(G, n_shift - 1); + r.randomize_part(G, n_shift); res = (source + r) >> n_shift; #endif #endif } } + + void cisc(SubProcessor& processor, const Instruction& instruction) + { + int r0 = instruction.get_r(0); + string tag((char*)&r0, 4); + auto& args = instruction.get_start(); + if (tag == string("LTZ\0", 4)) + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + assert(args[i] == 6); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + res = T(processor.get_S()[args[i + 3] + j]).get_bit( + args[i + 4] - 1); + } + } + } + else if (tag == "Trun") + { + for (size_t i = 0; i < args.size(); i += args[i]) + { + assert(i + args[i] <= args.size()); + assert(args[i] == 8); + int k = args[i + 4]; + int m = args[i + 5]; + int s = args[i + 7]; + assert((s == 0) or (s == 1)); + for (int j = 0; j < args[i + 1]; j++) + { + auto& res = processor.get_S()[args[i + 2] + j]; + res = ((T(processor.get_S()[args[i + 3] + j]) + + (T(s) << (k - 1))) >> m) - (T(s) << (k - m - 1)); + } + } + } + else + throw runtime_error("unknown CISC instruction: " + tag); + } }; #endif /* PROTOCOLS_FAKEPROTOCOL_H_ */ diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index 8b083a7f0..f36a7b754 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -8,7 +8,6 @@ #include "GC/FakeSecret.h" #include "ShareInterface.h" -#include "Processor/NoLivePrep.h" #include "FakeMC.h" #include "FakeProtocol.h" #include "FakePrep.h" diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp index c3465634e..cb3af5e9a 100644 --- a/Protocols/HighGearKeyGen.cpp +++ b/Protocols/HighGearKeyGen.cpp @@ -27,7 +27,10 @@ void PartSetup::key_and_mac_generation(Player& P, } X(5, 3) X(4, 3) X(3, 2) if (not done) - throw runtime_error("not compiled for choice of parameters"); + throw runtime_error( + "not compiled for choice of parameters, add X(" + + to_string(n_limbs[0]) + ", " + to_string(n_limbs[1]) + + ") at " + __FILE__ + ":" + to_string(__LINE__ - 5)); batch_size = backup; } diff --git a/Protocols/LowGearKeyGen.h b/Protocols/LowGearKeyGen.h index 396b68e1c..930aa19a4 100644 --- a/Protocols/LowGearKeyGen.h +++ b/Protocols/LowGearKeyGen.h @@ -46,8 +46,6 @@ class KeyGenProtocol template void binomial(vector_type& shares, T& prep); template - void hamming(vector_type& shares, T& prep); - template void secret_key(vector_type& shares, T& prep); vector_type schur_product(const vector_type& x, const vector_type& y); void output_to(int player, vector& opened, diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index e820e89c4..c1e7e8253 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -101,58 +101,13 @@ void KeyGenProtocol::binomial(vector_type& shares, T& prep) shares.fft(fftd); } -template -template -void KeyGenProtocol::hamming(vector_type& shares, T& prep) -{ - shares.resize(params.phi_m()); - int h = params.get_h(); - assert(h > 0); - assert(shares.size() / h * h == shares.size()); - int n_bits = log(shares.size() / h) / log(2); -// assert(size_t(h << n_bits) == shares.size()); - - for (auto& share : shares) - share = prep.get_bit(); - - auto& protocol = proc->protocol; - for (int i = 0; i < n_bits - 1; i++) - { - protocol.init_mul(proc); - for (auto& share : shares) - protocol.prepare_mul(share, prep.get_bit()); - protocol.exchange(); - for (auto& share : shares) - share = protocol.finalize_mul(); - } - - protocol.init_mul(proc); - auto one = share_type::constant(1, P.my_num(), MC->get_alphai()); - for (auto& share : shares) - protocol.prepare_mul(share, prep.get_bit() * 2 - one); - protocol.exchange(); - for (auto& share : shares) - share = protocol.finalize_mul(); - - shares.fft(fftd); -} - template template void KeyGenProtocol::secret_key(vector_type& shares, T& prep) { - assert(params.get_h() != 0); cerr << "Generate secret key by "; - if (params.get_h() > 0) - { - cerr << "Hamming weight" << endl; - hamming(shares, prep); - } - else - { - cerr << "binomial" << endl; - binomial(shares, prep); - } + cerr << "binomial" << endl; + binomial(shares, prep); } template @@ -246,7 +201,7 @@ void LowGearKeyGen::run(PairwiseSetup& setup) others_ciphertexts.resize(EC.proof.U, machine.pk.get_params()); Verifier verifier(EC.proof, setup.FieldD); verifier.NIZKPoK(others_ciphertexts, ciphertexts, - cleartexts, machine.pk, false); + cleartexts, machine.pk); machine.enc_alphas.clear(); for (int i = 0; i < P.num_players(); i++) @@ -268,7 +223,7 @@ void LowGearKeyGen::run(PairwiseSetup& setup) #endif timers["Verifying"].start(); verifier.NIZKPoK(others_ciphertexts, ciphertexts, - cleartexts, machine.other_pks[player], false); + cleartexts, machine.other_pks[player]); timers["Verifying"].stop(); machine.enc_alphas.at(player) = others_ciphertexts.at(0); } diff --git a/Protocols/MAC_Check.h b/Protocols/MAC_Check.h index d17eeef99..cf36c98c3 100644 --- a/Protocols/MAC_Check.h +++ b/Protocols/MAC_Check.h @@ -132,6 +132,10 @@ class MAC_Check_Z2k : public Tree_MAC_Check virtual ~MAC_Check_Z2k() {}; }; +template +using MAC_Check_Z2k_ = MAC_Check_Z2k; + template void add_openings(vector& values, const Player& P, int sum_players, int last_sum_players, int send_player, TreeSum& MC); diff --git a/Protocols/MaliciousRepPrep.h b/Protocols/MaliciousRepPrep.h index b9b744a0a..3e2c8d1c6 100644 --- a/Protocols/MaliciousRepPrep.h +++ b/Protocols/MaliciousRepPrep.h @@ -52,6 +52,8 @@ class MaliciousRepPrep : public MaliciousBitOnlyRepPrep public: MaliciousRepPrep(SubProcessor* proc, DataPositions& usage); MaliciousRepPrep(DataPositions& usage, int = 0); + template + MaliciousRepPrep(DataPositions& usage, GC::ShareThread&, int = 0); }; template diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index e968ae0b0..43b2292fa 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -27,6 +27,14 @@ MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage, int) : { } +template +template +MaliciousRepPrep::MaliciousRepPrep(DataPositions& usage, + GC::ShareThread&, int) : + MaliciousRepPrep(0, usage) +{ +} + template MaliciousRepPrepWithBits::MaliciousRepPrepWithBits(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MaliciousShamirPO.hpp b/Protocols/MaliciousShamirPO.hpp index 1e867d176..92b25cf79 100644 --- a/Protocols/MaliciousShamirPO.hpp +++ b/Protocols/MaliciousShamirPO.hpp @@ -39,9 +39,9 @@ typename T::clear MaliciousShamirPO::finalize(const T& secret) for (int i = 0; i < P.num_players(); i++) { if (i == P.my_num()) - shares[i] = secret; + shares[0] = secret; else - shares[i].unpack(to_receive[i]); + shares[P.get_offset(i)].unpack(to_receive[i]); } return MC.reconstruct(shares); diff --git a/Protocols/NoLivePrep.h b/Protocols/NoLivePrep.h new file mode 100644 index 000000000..c53ec7e84 --- /dev/null +++ b/Protocols/NoLivePrep.h @@ -0,0 +1,53 @@ +/* + * NoLivePrep.h + * + */ + +#ifndef PROCESSOR_NOLIVEPREP_H_ +#define PROCESSOR_NOLIVEPREP_H_ + +#include "Tools/Exceptions.h" +#include "Protocols/ReplicatedPrep.h" + +template class SubProcessor; +class DataPositions; + +// preprocessing facility +template +class NoLivePrep : public BufferPrep +{ +public: + // global setup for encryption keys if needed + static void basic_setup(Player&) + { + } + + // destruct global setup + static void teardown() + { + } + + NoLivePrep(SubProcessor*, DataPositions& usage) : + BufferPrep(usage) + { + } + + // access to protocol instance if needed + void set_protocol(typename T::Protocol&) + { + } + + // buffer batch of multiplication triples in this->triples + void buffer_triples() + { + throw runtime_error("no triples"); + } + + // buffer batch of random bit shares in this->bits + void buffer_bits() + { + throw runtime_error("no bits"); + } +}; + +#endif /* PROCESSOR_NOLIVEPREP_H_ */ diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h new file mode 100644 index 000000000..b99ce4e3e --- /dev/null +++ b/Protocols/NoProtocol.h @@ -0,0 +1,116 @@ +/* + * NoProtocol.h + * + */ + +#ifndef PROCESSOR_NOPROTOCOL_H_ +#define PROCESSOR_NOPROTOCOL_H_ + +#include "Protocols/Replicated.h" +#include "Protocols/MAC_Check_Base.h" + +// opening facility +template +class NoOutput : public MAC_Check_Base +{ +public: + NoOutput(const typename T::mac_key_type& mac_key, int = 0, int = 0) : + MAC_Check_Base(mac_key) + { + } + + // open shares in this->shares and put clear values in this->values + void exchange(const Player&) + { + throw runtime_error("no opening"); + } +}; + +// multiplication protocol +template +class NoProtocol : public ProtocolBase +{ +public: + Player& P; + + static int get_n_relevant_players() + { + throw runtime_error("no number of relevant players"); + return -1; + } + + NoProtocol(Player& P) : + P(P) + { + } + + // prepare next round of multiplications + void init_mul(SubProcessor*) + { + } + + // schedule multiplication + typename T::clear prepare_mul(const T&, const T&, int = -1) + { + throw runtime_error("no multiplication preparation"); + } + + // execute protocol + void exchange() + { + throw runtime_error("no multiplication exchange"); + } + + // return next product + T finalize_mul(int = -1) + { + throw runtime_error("no multiplication finalization"); + } +}; + +// private input facility +template +class NoInput : public InputBase +{ +public: + NoInput(SubProcessor&, typename T::MAC_Check&) + { + } + + // prepare next round of inputs from specific party + void reset(int) + { + } + + // schedule private input from me + void add_mine(const typename T::open_type&, int = -1) + { + throw runtime_error("no input from me"); + } + + // schedule private from someone else + void add_other(int, int = -1) + { + throw runtime_error("no input from others"); + } + + // send my inputs + void send_mine() + { + throw runtime_error("no sending of my inputs"); + } + + // return share corresponding to my next input + T finalize_mine() + { + throw runtime_error("no finalizing for my inputs"); + } + + // construct share corresponding to someone else's input + void finalize_other(int, T&, octetStream&, int = -1) + { + throw runtime_error("no finalizing of someone else's input"); + } +}; + +#endif /* PROCESSOR_NOPROTOCOL_H_ */ diff --git a/Protocols/NoShare.h b/Protocols/NoShare.h new file mode 100644 index 000000000..e44006400 --- /dev/null +++ b/Protocols/NoShare.h @@ -0,0 +1,187 @@ +/* + * NoShare.h + * + */ + +#ifndef PROTOCOLS_NOSHARE_H_ +#define PROTOCOLS_NOSHARE_H_ + +#include "ShareInterface.h" +#include "Math/bigint.h" +#include "Math/gfp.h" +#include "GC/NoShare.h" + +#include "NoLivePrep.h" +#include "NoProtocol.h" + +template +class NoShare : public ShareInterface +{ + typedef NoShare This; + +public: + // type for clear values in relevant domain + typedef T clear; + typedef clear open_type; + + // needs to be defined even if protocol doesn't use MACs + typedef clear mac_key_type; + + // disable binary computation + typedef GC::NoShare bit_type; + + // opening facility + typedef NoOutput MAC_Check; + typedef MAC_Check Direct_MC; + + // multiplication protocol + typedef NoProtocol Protocol; + + // preprocessing facility + typedef NoLivePrep LivePrep; + + // private input facility + typedef NoInput Input; + + // default private output facility (using input tuples) + typedef ::PrivateOutput PrivateOutput; + + // description used for debugging output + static string type_string() + { + return "no share"; + } + + // used for preprocessing storage location + static string type_short() + { + return "no"; + } + + // size in bytes + // must match assign/pack/unpack and machine-readable input/output + static int size() + { + throw runtime_error("no size"); + return -1; + } + + // maximum number of corrupted parties + // only used in virtual machine instruction + static int threshold(int) + { + throw runtime_error("no threshold"); + return -1; + } + + // serialize computation domain for client communication + static void specification(octetStream& os) + { + T::specification(os); + } + + // constant secret sharing + static This constant(const clear&, int, const mac_key_type&) + { + throw runtime_error("no constant sharing"); + return {}; + } + + // share addition + This operator+(const This&) + { + throw runtime_error("no share addition"); + return {}; + } + + // share subtraction + This operator-(const This&) + { + throw runtime_error("no share subtraction"); + return {}; + } + + This& operator+=(const This& other) + { + *this = *this + other; + return *this; + } + + This& operator-=(const This& other) + { + *this = *this - other; + return *this; + } + + // private-public multiplication + This operator*(const clear&) const + { + throw runtime_error("no private-public multiplication"); + return {}; + } + + // private-public division + This operator/(const clear&) const + { + throw runtime_error("no private-public division"); + return {}; + } + + // multiplication by power of two + This operator<<(int) const + { + throw runtime_error("no right shift"); + return {}; + } + + // assignment from byte string + // must match unpack + void assign(const char*) + { + throw runtime_error("no assignment"); + } + + // serialization + // must use the number of bytes given by size() + void pack(octetStream&, bool = false) const + { + throw runtime_error("no packing"); + } + + // serialization + // must use the number of bytes given by size() + void unpack(octetStream& os, bool = false) + { + assign((char*)os.consume(size())); + } + + // serialization + // must use the number of bytes given by size() for human=false + void input(istream& is, bool human) + { + if (human) + throw runtime_error("no human-readable input"); + else + { + char buf[size()]; + is.read(buf, size()); + assign(buf); + } + } + + // serialization + // must use the number of bytes given by size() for human=false + void output(ostream& os, bool human) const + { + if (human) + throw runtime_error("no human-readable output"); + else + { + octetStream buf; + pack(buf); + os.write((char*)buf.get_data(), size()); + } + } +}; + +#endif /* PROTOCOLS_NOSHARE_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index dc367fd7c..d115b4c5c 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -37,6 +37,11 @@ class RepShare : public FixedVec, public ShareInterface return 1; } + static void specification(octetStream& os) + { + T::specification(os); + } + RepShare() { } @@ -46,12 +51,10 @@ class RepShare : public FixedVec, public ShareInterface { } - void mul_by_bit(const This& x, const T& y) + void pack(octetStream& os, T) const { - (void) x, (void) y; - throw runtime_error("multiplication by bit not implemented"); + pack(os, false); } - void pack(octetStream& os, bool full = true) const { if (full) diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index 7cdf66235..d9225d5c7 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -47,7 +47,6 @@ class Rep3Share2 : public Rep3Share> { auto& P = protocol.P; int my_num = P.my_num(); - assert(n_bits <= 64); int unit = GC::Clear::N_BITS; for (int k = 0; k < DIV_CEIL(n_inputs, unit); k++) { @@ -57,29 +56,37 @@ class Rep3Share2 : public Rep3Share> switch (regs.size() / n_bits) { case 3: - for (int i = 0; i < n_bits; i++) - dest.at(regs.at(3 * i + my_num) + k) = {}; - - for (int i = 0; i < 2; i++) + for (int l = 0; l < n_bits; l += unit) { - square64 square; - - for (int j = 0; j < m; j++) - square.rows[j] = Integer(source[j + start][i]).get(); + int base = l; + int n_left = min(n_bits - base, unit); + for (int i = base; i < base + n_left; i++) + dest.at(regs.at(3 * i + my_num) + k) = {}; - square.transpose(m, n_bits); - - for (int j = 0; j < n_bits; j++) + for (int i = 0; i < 2; i++) { - auto &dest_reg = dest.at( - regs.at(3 * j + ((my_num + 2 - i) % 3)) + k); - dest_reg[1 - i] = 0; - dest_reg[i] = square.rows[j]; + square64 square; + + for (int j = 0; j < m; j++) + square.rows[j] = source[j + start][i].get_limb( + l / unit); + + square.transpose(m, n_left); + + for (int j = 0; j < n_left; j++) + { + auto& dest_reg = dest.at( + regs.at(3 * (base + j) + ((my_num + 2 - i) % 3)) + + k); + dest_reg[1 - i] = 0; + dest_reg[i] = square.rows[j]; + } } } break; case 2: { + assert(n_bits <= 64); ReplicatedInput input(P); input.reset_all(P); if (P.my_num() == 0) diff --git a/Protocols/Rep4Input.h b/Protocols/Rep4Input.h index d8c66eff7..f1bc29af9 100644 --- a/Protocols/Rep4Input.h +++ b/Protocols/Rep4Input.h @@ -29,7 +29,7 @@ class Rep4Input : public InputBase void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); - void add_other(int player); + void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); diff --git a/Protocols/Rep4Input.hpp b/Protocols/Rep4Input.hpp index 5375923a4..5600b45c7 100644 --- a/Protocols/Rep4Input.hpp +++ b/Protocols/Rep4Input.hpp @@ -44,7 +44,7 @@ void Rep4Input::add_mine(const typename T::open_type& input, int) } template -void Rep4Input::add_other(int player) +void Rep4Input::add_other(int player, int) { auto& prot = protocol; T res; diff --git a/Protocols/Rep4Share.h b/Protocols/Rep4Share.h index 98aa54908..5e197804d 100644 --- a/Protocols/Rep4Share.h +++ b/Protocols/Rep4Share.h @@ -7,7 +7,6 @@ #define PROTOCOLS_REP4SHARE_H_ #include "Rep3Share.h" -#include "Processor/NoLivePrep.h" template class Rep4MC; template class Rep4; diff --git a/Protocols/Rep4Share2k.h b/Protocols/Rep4Share2k.h index a394b1ae7..d902cc760 100644 --- a/Protocols/Rep4Share2k.h +++ b/Protocols/Rep4Share2k.h @@ -7,7 +7,6 @@ #define PROTOCOLS_REP4SHARE2K_H_ #include "Rep4Share.h" -#include "Processor/NoLivePrep.h" #include "Processor/DummyProtocol.h" #include "GC/square64.h" diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 83729d793..4f746c485 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -50,6 +50,7 @@ class ProtocolBase vector random; int trunc_pr_counter; + int rounds, trunc_rounds; public: typedef T share_type; @@ -90,6 +91,9 @@ class ProtocolBase virtual void stop_exchange() {} virtual void check() {} + + virtual void cisc(SubProcessor&, const Instruction&) + { throw runtime_error("CISC instructions not implemented"); } }; template diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 3f604afdb..c078d08bd 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -21,7 +21,7 @@ template ProtocolBase::ProtocolBase() : - trunc_pr_counter(0), counter(0) + trunc_pr_counter(0), rounds(0), trunc_rounds(0), counter(0) { } @@ -67,10 +67,10 @@ template ProtocolBase::~ProtocolBase() { #ifdef VERBOSE_COUNT - if (counter) - cerr << "Number of " << T::type_string() << " multiplications: " << counter << endl; - if (trunc_pr_counter) - cerr << "Number of probabilistic truncations: " << trunc_pr_counter << endl; + if (counter or rounds) + cerr << "Number of " << T::type_string() << " multiplications: " << counter << " in " << rounds << " rounds" << endl; + if (trunc_pr_counter or trunc_rounds) + cerr << "Number of probabilistic truncations: " << trunc_pr_counter << " in " << trunc_rounds << " rounds" << endl; #endif } @@ -189,12 +189,14 @@ void Replicated::exchange() { if (os[0].get_length() > 0) P.pass_around(os[0], os[1], 1); + this->rounds++; } template void Replicated::start_exchange() { P.send_relative(1, os[0]); + this->rounds++; } template @@ -381,6 +383,7 @@ template void Replicated::trunc_pr(const vector& regs, int size, U& proc) { + this->trunc_rounds++; ::trunc_pr(regs, size, proc); } diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index d1e588334..29dbbf142 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -25,7 +25,7 @@ class PrepLessInput : public InputBase virtual void reset(int player) = 0; virtual void add_mine(const typename T::open_type& input, int n_bits = -1) = 0; - virtual void add_other(int player) = 0; + virtual void add_other(int player, int n_bits = - 1) = 0; virtual void send_mine() = 0; virtual void finalize_other(int player, T& target, octetStream& o, int n_bits = -1) = 0; @@ -73,7 +73,7 @@ class ReplicatedInput : public PrepLessInput void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); - void add_other(int player); + void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 1b527a732..741d2c490 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -40,7 +40,7 @@ inline void ReplicatedInput::add_mine(const typename T::open_type& input, int } template -void ReplicatedInput::add_other(int player) +void ReplicatedInput::add_other(int player, int) { expect[player] = true; } diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 357677182..2576158d4 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -20,6 +20,11 @@ template void bits_from_random(vector& bits, typename T::Protocol& protocol); +namespace GC +{ +template class ShareThread; +} + template class BufferPrep : public Preprocessing { @@ -322,6 +327,12 @@ class ReplicatedPrep : public virtual ReplicatedRingPrep, { } + template + ReplicatedPrep(DataPositions& usage, GC::ShareThread&, int = 0) : + ReplicatedPrep(0, usage) + { + } + void buffer_squares() { ReplicatedRingPrep::buffer_squares(); } void buffer_bits(); }; diff --git a/Protocols/Semi2k.h b/Protocols/Semi2k.h index 927bb6c17..646c955e3 100644 --- a/Protocols/Semi2k.h +++ b/Protocols/Semi2k.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SEMI2K_H_ #include "SPDZ.h" +#include "Processor/TruncPrTuple.h" template class Semi2k : public SPDZ @@ -23,6 +24,37 @@ class Semi2k : public SPDZ { res.randomize_part(G, n_bits); } + + void trunc_pr(const vector& regs, int size, + SubProcessor& proc) + { + if (this->P.num_players() > 2) + throw runtime_error("probabilistic truncation " + "only implemented for two players"); + + assert(regs.size() % 4 == 0); + this->trunc_pr_counter += size * regs.size() / 4; + typedef typename T::open_type open_type; + + vector> infos; + for (size_t i = 0; i < regs.size(); i += 4) + infos.push_back({regs, i}); + + for (auto& info : infos) + { + if (not info.big_gap()) + throw runtime_error("bit length too large"); + if (this->P.my_num()) + for (int i = 0; i < size; i++) + proc.get_S_ref(info.dest_base + i) = -open_type( + -open_type(proc.get_S()[info.source_base + i]) + >> info.m); + else + for (int i = 0; i < size; i++) + proc.get_S_ref(info.dest_base + i) = + proc.get_S()[info.source_base + i] >> info.m; + } + } }; #endif /* PROTOCOLS_SEMI2K_H_ */ diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index 595957c18..4587b7f31 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -8,7 +8,6 @@ #include "Protocols/Beaver.h" #include "Processor/DummyProtocol.h" -#include "Processor/NoLivePrep.h" #include "ShareInterface.h" #include diff --git a/Protocols/ShamirInput.h b/Protocols/ShamirInput.h index 5a4a1ca58..4567ae11e 100644 --- a/Protocols/ShamirInput.h +++ b/Protocols/ShamirInput.h @@ -29,7 +29,7 @@ class IndividualInput : public PrepLessInput } void reset(int player); - void add_other(int player); + void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); @@ -61,6 +61,11 @@ class ShamirInput : public IndividualInput { } + ShamirInput(ShamirMC&, Preprocessing&, Player& P) : + IndividualInput(0, P) + { + } + void add_mine(const typename T::open_type& input, int n_bits = -1); }; diff --git a/Protocols/ShamirInput.hpp b/Protocols/ShamirInput.hpp index 1443a7d6a..b4421b190 100644 --- a/Protocols/ShamirInput.hpp +++ b/Protocols/ShamirInput.hpp @@ -73,7 +73,7 @@ void ShamirInput::add_mine(const typename T::open_type& input, int n_bits) } template -void IndividualInput::add_other(int player) +void IndividualInput::add_other(int player, int) { (void) player; } diff --git a/Protocols/Share.h b/Protocols/Share.h index 97fddcc93..743a2c614 100644 --- a/Protocols/Share.h +++ b/Protocols/Share.h @@ -62,9 +62,6 @@ class Share_ : public ShareInterface static char type_char() { return T::type_char(); } - static DataFieldType field_type() - { return T::field_type(); } - static int threshold(int nplayers) { return T::threshold(nplayers); } @@ -72,6 +69,9 @@ class Share_ : public ShareInterface static void read_or_generate_mac_key(string directory, const Player& P, U& key); + static void specification(octetStream& os) + { T::specification(os); } + static Share_ constant(const clear& aa, int my_num, const typename V::Scalar& alphai) { return Share_(aa, my_num, alphai); } @@ -100,7 +100,6 @@ class Share_ : public ShareInterface /* Arithmetic Routines */ void mul(const Share_& S,const clear& aa); - void mul_by_bit(const Share_& S,const clear& aa); void add(const Share_& S1,const Share_& S2); void sub(const Share_& S1,const Share_& S2); @@ -210,10 +209,6 @@ class ArithmeticOnlyMascotShare : public Share super(share, mac) {} }; -// specialized mul by bit for gf2n -template <> -void Share_, SemiShare>::mul_by_bit(const Share_, SemiShare>& S,const gf2n& aa); - template Share_ operator*(const typename T::clear& y, const Share_& x) { Share_ res; res.mul(x, y); return res; } diff --git a/Protocols/Share.hpp b/Protocols/Share.hpp index a30d15d2e..62f90e91d 100644 --- a/Protocols/Share.hpp +++ b/Protocols/Share.hpp @@ -23,27 +23,10 @@ void Share_::read_or_generate_mac_key(string directory, const Player& P, } } -template -inline -void Share_::mul_by_bit(const Share_& S,const clear& aa) -{ - a.mul(S.a,aa); - mac.mul(S.mac,aa); -} - -template<> -inline -void Share_, SemiShare>::mul_by_bit( - const Share_, SemiShare>& S, const gf2n& aa) -{ - a.mul_by_bit(S.a,aa); - mac.mul_by_bit(S.mac,aa); -} - template inline void Share_::pack(octetStream& os, bool full) const { - a.pack(os); + a.pack(os, full); if (full) mac.pack(os); } @@ -51,7 +34,7 @@ inline void Share_::pack(octetStream& os, bool full) const template inline void Share_::unpack(octetStream& os, bool full) { - a.unpack(os); + a.unpack(os, full); if (full) mac.unpack(os); } diff --git a/Protocols/ShareInterface.cpp b/Protocols/ShareInterface.cpp new file mode 100644 index 000000000..c2222e879 --- /dev/null +++ b/Protocols/ShareInterface.cpp @@ -0,0 +1,8 @@ +/* + * ShareInterface.cpp + * + */ + +#include "ShareInterface.h" + +const int ShareInterface::default_length; diff --git a/Protocols/ShareInterface.h b/Protocols/ShareInterface.h index 08d5f2067..d4f3adaac 100644 --- a/Protocols/ShareInterface.h +++ b/Protocols/ShareInterface.h @@ -7,9 +7,15 @@ #define PROTOCOLS_SHAREINTERFACE_H_ #include +#include +#include using namespace std; +#include "Tools/Exceptions.h" + class Player; +class Instruction; +class ValueInterface; namespace GC { diff --git a/Protocols/SohoPrep.hpp b/Protocols/SohoPrep.hpp index b1223df17..331c26258 100644 --- a/Protocols/SohoPrep.hpp +++ b/Protocols/SohoPrep.hpp @@ -7,6 +7,7 @@ #include "FHEOffline/DataSetup.h" #include "ReplicatedPrep.hpp" +#include "FHEOffline/DataSetup.hpp" template PartSetup::FD>* SohoPrep::setup = 0; @@ -21,7 +22,7 @@ void SohoPrep::basic_setup(Player& P) setup = new PartSetup; MachineBase machine; setup->secure_init(P, machine, T::clear::length(), 0); - setup->key_and_mac_generation(P, machine, 1, true_type()); + read_or_generate_secrets(*setup, P, machine, 1, true_type()); T::clear::template init(); } diff --git a/Protocols/SpdzWiseInput.h b/Protocols/SpdzWiseInput.h index a0f77b77c..458fe02a1 100644 --- a/Protocols/SpdzWiseInput.h +++ b/Protocols/SpdzWiseInput.h @@ -32,15 +32,12 @@ class SpdzWiseInput : public InputBase void reset(int player); void add_mine(const typename T::open_type& input, int n_bits = -1); - void add_other(int player); + void add_other(int player, int n_bits = -1); void send_mine(); void exchange(); T finalize(int player, int n_bits = -1); T finalize_mine(); void finalize_other(int player, T& target, octetStream& o, int n_bits = -1); - - void start(int, int) { throw not_implemented(); } - void stop(int, const vector&) { throw not_implemented(); } }; #endif /* PROTOCOLS_SPDZWISEINPUT_H_ */ diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index 29d2708f7..ef7f549bf 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -47,9 +47,9 @@ void SpdzWiseInput::add_mine(const typename T::open_type& input, int n_bits) } template -void SpdzWiseInput::add_other(int player) +void SpdzWiseInput::add_other(int player, int n_bits) { - part_input.add_other(player); + part_input.add_other(player, n_bits); counters[player]++; } diff --git a/Protocols/SpdzWiseShare.h b/Protocols/SpdzWiseShare.h index 6386aa5be..101965cad 100644 --- a/Protocols/SpdzWiseShare.h +++ b/Protocols/SpdzWiseShare.h @@ -9,7 +9,6 @@ #include "Share.h" #include "SpdzWise.h" #include "Processor/DummyProtocol.h" -#include "Processor/NoProtocol.h" template class NoLivePrep; template class NotImplementedInput; @@ -57,6 +56,16 @@ class SpdzWiseShare : public Share_ static void read_or_generate_mac_key(string directory, Player& P, T& mac_key); + static open_type get_rec_factor(int i, int n) + { + return T::get_rec_factor(i, n); + } + + static void specification(octetStream& os) + { + T::specification(os); + } + SpdzWiseShare() { } @@ -70,6 +79,8 @@ class SpdzWiseShare : public Share_ super(share, mac) { } + + void pack(octetStream& os, open_type factor) const; }; template class MaliciousRep3Share; diff --git a/Protocols/SpdzWiseShare.hpp b/Protocols/SpdzWiseShare.hpp index 60dccba50..038556936 100644 --- a/Protocols/SpdzWiseShare.hpp +++ b/Protocols/SpdzWiseShare.hpp @@ -40,4 +40,10 @@ void SpdzWiseShare::read_or_generate_mac_key(string directory, Player& P, T& } } +template +void SpdzWiseShare::pack(octetStream& os, open_type factor) const +{ + this->get_share().pack(os, factor); +} + #endif /* PROTOCOLS_SPDZWISESHARE_HPP_ */ diff --git a/README.md b/README.md index d123a0538..abd2590e0 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,12 @@ sharing (with an honest majority). us, but you can also write an email to mp-spdz@googlegroups.com ([archive](https://groups.google.com/forum/#!forum/mp-spdz)). +#### Frequently Asked Questions + +[The documentation](https://mp-spdz.readthedocs.io/en/latest) contains +sections on a number of frequently asked topics as well as information +on how to solve common issues. + #### TL;DR (Binary Distribution on Linux or Source Distribution on macOS) This requires either a Linux distribution originally released 2014 or @@ -121,6 +127,14 @@ there are a few things to consider: communication, and it is the only one offering constant-communication dot products. +- Fixed-point multiplication: Three- and four-party replicated secret + sharing modulo a power of two allow a special probabilistic + truncation protocol (see [Dalskov et + al.](https://eprint.iacr.org/2019/131) and [Dalskov et + al.](https://eprint.iacr.org/2020/1330)). You can activate it by + adding `program.use_trunc_pr = True` at the beginning of your + high-level program. + - Minor variants: Some command-line options change aspects of the protocols such as: @@ -610,7 +624,7 @@ The following table shows all programs for honest-majority computation: | `malicious-rep-field-party.x` | Replicated | Mod prime | Y | 3 | `mal-rep-field.sh` | | `shamir-party.x` | Shamir | Mod prime | N | 3 or more | `shamir.sh` | | `malicious-shamir-party.x` | Shamir | Mod prime | Y | 3 or more | `mal-shamir.sh` | -| `sy-shamir-party.x` | SPDZ-wise Shamir | Mod prime | Y | 3 or more | `mal-shamir.sh` | +| `sy-shamir-party.x` | SPDZ-wise Shamir | Mod prime | Y | 3 or more | `sy-shamir.sh` | | `ccd-party.x` | CCD/Shamir | Binary | N | 3 or more | `ccd.sh` | | `malicious-cdd-party.x` | CCD/Shamir | Binary | Y | 3 or more | `mal-ccd.sh` | @@ -790,7 +804,7 @@ This sets up parameters for the online phase for 2 parties with a 128-bit prime Parameters can be customised by running -`Scripts/setup-online.sh ` +`Scripts/setup-online.sh []` #### To compile a program diff --git a/Scripts/emulate.sh b/Scripts/emulate.sh index fe840c7d2..9585c85ca 100755 --- a/Scripts/emulate.sh +++ b/Scripts/emulate.sh @@ -1,7 +1,7 @@ #!/bin/bash -test -e logs || mkdir logs +. $(dirname $0)/run-common.sh prog=${1%.sch} prog=${prog##*/} shift -$prefix ./emulate.x $prog $* 2>&1 | tee -a logs/emulate-$prog +$prefix ./emulate.x $prog $* 2>&1 | tee logs/emulate-$prog diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index c7a1e10b9..95e11499b 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -27,20 +27,24 @@ run_player() { port=$((RANDOM%10000+10000)) bin=$1 shift + prog=$1 + prog=${prog##*/} + prog=${prog%.sch} + shift if ! test -e $SPDZROOT/logs; then mkdir $SPDZROOT/logs fi if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then - params="$* -pn $port -h localhost" + params="$prog $* -pn $port -h localhost" if [[ ! ($bin =~ 'rep' || $bin =~ 'brain') ]]; then params="$params -N $players" fi else - params="$port localhost $*" + params="$port localhost $prog $*" fi rem=$(($players - 2)) - if test "$1"; then - log_prefix=$1- + if test "$prog"; then + log_prefix=$prog- fi for i in $(seq 0 $rem); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index 3c5413dd4..33dd3d6be 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -58,7 +58,9 @@ void BufferBase::prune() purge(); else if (file and file->tellg() != 0) { +#ifdef VERBOSE cerr << "Pruning " << filename << endl; +#endif string tmp_name = filename + ".new"; ofstream tmp(tmp_name.c_str()); tmp << file->rdbuf(); @@ -75,7 +77,9 @@ void BufferBase::purge() { if (file) { +#ifdef VERBOSE cerr << "Removing " << filename << endl; +#endif unlink(filename.c_str()); file->close(); file = 0; diff --git a/Tools/names.cpp b/Tools/names.cpp index a428f9bb5..062beb022 100644 --- a/Tools/names.cpp +++ b/Tools/names.cpp @@ -1,5 +1,5 @@ #include "Processor/Data_Files.h" const char* DataPositions::dtype_names[N_DTYPE + 1] = -{ "Triples", "Squares", "Bits", "Inverses", "BitTriples", "BitGF2NTriples", +{ "Triples", "Squares", "Bits", "Inverses", "daBits", "None" }; diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index 1f43ecf8d..d01814257 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -75,49 +75,6 @@ class FakeParams } }; -void make_bit_triples(const gf2n& key,int N,int ntrip,Dtype dtype,bool zero) -{ - PRNG G; - G.ReSeed(); - - ofstream* outf=new ofstream[N]; - gf2n a,b,c, one; - one.assign_one(); - vector > Sa(N),Sb(N),Sc(N); - /* Generate Triples */ - for (int i=0; i>(prep_data_prefix, N) - << DataPositions::dtype_names[dtype] << "-2-P" << i; - cout << "Opening " << filename.str() << endl; - outf[i].open(filename.str().c_str(),ios::out | ios::binary); - if (outf[i].fail()) { throw file_error(filename.str().c_str()); } - } - for (int i=0; i(key2,nplayers,ninv,zero,prep_data_prefix); if (T::clear::invertible) make_inverse(keyp,nplayers,ninv,zero,prep_data_prefix); - make_bit_triples(key2,nplayers,nbittrip,DATA_BITTRIPLE,zero); - make_bit_triples(key2,nplayers,nbitgf2ntrip,DATA_BITGF2NTRIPLE,zero); if (opt.isSet("-s")) { @@ -829,8 +784,8 @@ int FakeParams::generate() make_minimal>(keyt, nplayers, default_num / 64, zero); gf2n_short keytt; - generate_mac_keys>(keytt, nplayers, prep_data_prefix); - make_minimal>(keytt, nplayers, default_num / 64, zero); + generate_mac_keys>(keytt, nplayers, prep_data_prefix); + make_minimal>(keytt, nplayers, default_num / 64, zero); make_dabits(keyp, nplayers, default_num, zero, keytt); make_edabits(keyp, nplayers, default_num, zero, false_type(), keytt); diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e617937a1..77df49704 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -5,6 +5,7 @@ trigger: - master +- cnn pool: vmImage: 'ubuntu-18.04' diff --git a/compile.py b/compile.py index c209a5e40..aeac32e43 100755 --- a/compile.py +++ b/compile.py @@ -78,6 +78,8 @@ def main(): "(number of parties as argument)") parser.add_option("-C", "--CISC", action="store_true", dest="cisc", help="faster CISC compilation mode") + parser.add_option("-K", "--keep-cisc", action="store_true", dest="keep_cisc", + help="don't translate CISC instructions") parser.add_option("-v", "--verbose", action="store_true", dest="verbose", help="more verbose output") options,args = parser.parse_args() diff --git a/doc/Compiler.rst b/doc/Compiler.rst index ac57a07f9..f6cdc7a88 100644 --- a/doc/Compiler.rst +++ b/doc/Compiler.rst @@ -20,7 +20,7 @@ Compiler.types module ClientMessageType, __weakref__, __repr__, reg_type, int_type, clear_type, float_type, basic_type, default_type, unreduced_type, bit_type, dynamic_array, - squant, mov, load_mem, store_in_mem, + squant, mov, write_share_to_socket, Compiler.GC.types module @@ -55,12 +55,14 @@ Compiler.mpc\_math module .. autofunction:: asin .. autofunction:: cos .. autofunction:: exp2_fx +.. autofunction:: InvertSqrt .. autofunction:: log2_fx .. autofunction:: log_fx .. autofunction:: pow_fx .. autofunction:: sin .. autofunction:: sqrt .. autofunction:: tan +.. autofunction:: tanh Compiler.ml module ------------------------- diff --git a/doc/add-protocol.rst b/doc/add-protocol.rst new file mode 100644 index 000000000..e9cad665b --- /dev/null +++ b/doc/add-protocol.rst @@ -0,0 +1,82 @@ +Adding a Protocol +----------------- + +In order to illustrate how to create a virtual machine for a new +protocol, we have created one with blanks to be filled in. It is +defined in the following files: + +``Machines/no-party.cpp`` + Contains the main function. + +``Protocols/NoShare.h`` + Contains the :c:type:`NoShare` class, which is supposed to hold one + share. :c:type:`NoShare` takes the cleartext type as a template + parameter. + +``Protocols/NoProtocol.h`` + Contains a number of classes representing instances of protocols: + + :c:type:`NoInput` + Private input. + :c:type:`NoProtocol` + Multiplication protocol. + :c:type:`NoOutput` + Public output. + +``Protocols/NoLivePrep.h`` + Contains the :c:type:`NoLivePrep` class, representing a + preprocessing instance. + +The number of blanks can be overwhelming. We therefore recommend the +following approach to get started. If the desired protocol resembles +one that is already implemented, you can check its code for +inspiration. The main function of ``-party.x`` can be found +in ``Machines/-party.cpp``, which in turns contains the name +of the share class. For example ``replicated-ring-party.x`` is +implemented in ``Machines/replicated-ring-party.cpp``, which refers to +:c:func:`Rep3Share2` in ``Protocols/Rep3Share2.h``. There you will +find that it uses :c:func:`Replicated` for multiplication, which is +found in ``Protocols/Replicated.h``. + +1. Fill in the :c:func:`constant` static member function of + :c:type:`NoShare` as well as the :c:func:`exchange` member function + of c:type:`NoOutput`. Check out + :c:func:`DirectSemiMC::exchange_` in ``Protocols/SemiMC.hpp`` + for a simple example. It opens an additive secret sharing by + sending all shares to all other parties and then summing up the + received. Constant sharing and public output allows to execute the + following program:: + + print_ln('%s', sint(123).reveal()) + + This allows to check the correct execution of further + functionality. + +2. Fill in the operator functions in :c:type:`NoShare` and check + them:: + + print_ln('%s', (sint(2) + sint(3)).reveal()) + print_ln('%s', (sint(2) - sint(3)).reveal()) + print_ln('%s', (sint(2) * cint(3)).reveal()) + + Many protocols use these basic operations, which makes it + beneficial to check the correctness + +3. Fill in :c:type:`NoProtocol`. Alternatively, if the desired + protocol is based on Beaver multiplication, you can specify the + following in :c:type:`NoShare`:: + + typedef Beaver Protocol; + + Then add the desired triple generation to + :c:func:`NoLivePrep::buffer_triples()`. In + any case you should then be able to execute:: + + print_ln('%s', (sint(2) * sint(3)).reveal()) + +4. In order to execute many kinds of non-linear computation, random + bits are needed. After filling in + :c:func:`NoLivePrep::buffer_bits()`, you should be able to + execute:: + + print_ln('%s', (sint(2) < sint(3)).reveal() diff --git a/doc/conf.py b/doc/conf.py index bb803a8ac..909f6dab4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -21,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = u'MP-SPDZ' -copyright = u'2020, Data61' +copyright = u'2021, CSIRO\'s Data61' author = u'Marcel Keller' # The short X.Y version diff --git a/doc/gen-instructions.py b/doc/gen-instructions.py index 1478a12f5..db0c5a2d8 100755 --- a/doc/gen-instructions.py +++ b/doc/gen-instructions.py @@ -31,3 +31,5 @@ d = d.replace('\n', '') d = d.strip() out.writerow([':py:class:`%s <%s>`' % (name, n), hex(opcode), d]) + +del out diff --git a/doc/index.rst b/doc/index.rst index d56258bf2..e54c1c3b2 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -9,6 +9,11 @@ implemented protocols etc. see https://github.com/data61/MP-SPDZ. Compilation process ------------------- +The easiest way of using MP-SPDZ is using the ``compile.py`` as +described below. If you would like to run compilation directly from +Python, see ``Scripts/direct_compilation_example.py``. It contains all +the necessary setup steps. + After putting your code in ``Program/Source/.mpc``, run the compiler from the root directory as follows @@ -151,6 +156,11 @@ Reference instructions low-level networking + io + non-linear + preprocessing + add-protocol + troubleshooting Indices and tables diff --git a/doc/io.rst b/doc/io.rst new file mode 100644 index 000000000..5d952c282 --- /dev/null +++ b/doc/io.rst @@ -0,0 +1,82 @@ +Input/Output +------------ + +This section gives an overview over the input/output facilities. + + +Private Inputs from Computing Parties +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +All secret types have an input function +(e.g. :py:func:`Compiler.types.sint.get_input_from` or +:py:func:`Compiler.types.sfix.get_input_from`). Inputs are read as +whitespace-separated text in order (independent of the data type) from +``Player-Data/Input-P-``, where ``thread`` is ``0`` for +the main thread. You can change the prefix (``Player-Data/Input``) +using the ``-IF`` option on the virtual machine binary. You can also +use ``-I`` to read inputs from the command line. + + +Public Inputs +~~~~~~~~~~~~~ + +All types can be assigned a hard-coded value at compile time, e.g. +``sint(1)``. This is impractical for larger amounts of +data. :py:func:`~Compiler.library.foreach_enumerate` provides a +facility for this case. It uses +:py:class:`~Compiler.library.public_input` internally, which reads +from ``Programs/Public-Input/``. + + +Public Outputs +~~~~~~~~~~~~~~ + +By default, :py:func:`~Compiler.library.print_ln` and related +functions only output to the terminal on party 0. This allows to run +several parties in one terminal without spoiling the output. You can +use interactive mode with option ``-I`` in order to output on all +parties. Note that this also to reading inputs from the command line +unless you specify ``-IF`` as well. You can also specify a file prefix +with ``-OF``, so that outputs are written to +``-P-``. + + +Private Outputs to Computing Parties +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Some types provide a function to reveal a value only to a specific +party (e.g., :py:func:`Compiler.types.sint.reveal_to`). It can be used +conjunction with :py:func:`~Compiler.library.print_ln_to` in order to +output it. + + +Clients (Non-computing Parties) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +:py:func:`Compiler.types.sint.receive_from_client` and +:py:func:`Compiler.types.sint.write_shares_to_socket` allow +communicating securely with the clients. See `this example +`_ +covering both client code and server-side high-level code. + + +Secret Shares +~~~~~~~~~~~~~ + +:py:func:`Compiler.types.sint.read_from_file` and +:py:func:`Compiler.types.sint.write_to_file` allow reading and writing +secret shares to and from files. + +Another possibility for persistence between program runs is to use the +fact that the memory is stored in +``Player-Data/Memory--P`` at the end of a run. The +best way to use this is via the memory access functions like +:py:func:`~Compiler.types.sint.store_in_mem` and +:py:func:`~Compiler.types.sint.load_mem`. Make sure to only use +addresses below ``USER_MEM`` specified in ``Compiler/config.py`` to +avoid conflicts with the automatic allocation used for arrays +etc. Note also that all types based on +:py:class:`~Compiler.types.sint` (e.g., +:py:class:`~Compiler.types.sfix`) share the same memory, and that the +address is only a base address. This means that vectors will be +written to the memory starting at the given address. diff --git a/doc/networking.rst b/doc/networking.rst index 10140b7f6..6062d3602 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -28,3 +28,33 @@ individually setting ports: The hosts can be both hostnames and IP addresses. If not given, the ports default to base plus party number. + + +Internal Infrastructure +~~~~~~~~~~~~~~~~~~~~~~~ + +The internal networking infrastructure of MP-SPDZ reflects the needs +of the various multi-party computation. For example, some protocols +require a simultaneous broadcast from all parties whereas other +protocols require that every party sends different information to +different parties (include none at all). The infrastructure makes sure +to send and receive in parallel whenever possible. + +All communication is handled through two subclasses of ``Player`` +defined in ``Networking/Player.h``. ``PlainPlayer`` communicates in +cleartext while ``CryptoPlayer`` uses TLS encryption. The former uses +the same BSD socket for sending and receiving but the latter uses two +different connections for sending and receiving. This is because TLS +communication is never truly one-way due key renewals etc., so the +only way for simultaneous sending and receiving we found was to use +two connections in two threads. + +If you wish to use a different networking facility, we recommend to +subclass ``Player`` and fill in the virtual-only functions required by +the compiler (e.g., ``send_to_no_stats`` for sending to one other +party). Note that not all protocols require all functions, so you only +need to properly implement those you need. You can then replace uses +of ``PlainPlayer`` or ``CryptoPlayer`` by your own class. Furthermore, +you might need to extend the ``Names`` class to suit your purpose. By +default, ``Names`` manages one TCP port that a party is listening on +for connections. If this suits you, you don't need to change anything diff --git a/doc/non-linear.rst b/doc/non-linear.rst new file mode 100644 index 000000000..e85f2e24d --- /dev/null +++ b/doc/non-linear.rst @@ -0,0 +1,72 @@ +Non-linear Computation +---------------------- + +While the computation of addition and multiplication varies from +protocol, non-linear computation such as comparison in arithmetic +domains (modulus other than two) only comes in three flavors +throughout MP-SPDZ: + +Unknown prime modulus + This approach goes back to `Catrina and Saxena + `_. It crucially relies on + the use of secret random bits in the arithmetic domain. Enough + such bits allow to mask a secret value so that it is secure to + reveal the masked value. This can then be split in bits as it is + public. The public bits and the secret mask bits are then used to + compute a number of non-linear functions. The same idea has been + used to implement `fixed-point + `_ and + `floating-point `_ computation. + We call this method "unknown prime modulus" because it only + mandates a minimum modulus size for a given cleartext range, which + is roughly the cleartext bit length plus a statistical security + parameter. It has the downside that there is implicit enforcement + of the cleartext range. + +Known prime modulus + `Damgård et al. `_ have + proposed non-linear computation that involves an exact prime + modulus. We have implemented the refined bit decomposition by + `Nishide and Ohta + `_, which enables + further non-linear computation. Our assumption with this method is + that the cleartext space is slightly smaller the full range modulo + the prime. This allows for comparison by taking a difference and + extracting the most significant bit, which is different than the + above works that implement comparison between two positive numbers + modulo the prime. We also used an idea by `Makri et + al. `_, namely that a random + :math:`k`-bit number is indistinguishable from a random number + modulo :math:`p` if the latter is close enough to :math:`2^k`. + +Power-of-two modulus + In the context of non-linear computation, there are two important + differences to prime modulus setting: + + 1. Multiplication with a power of two effectively erases some of + the most significant bits. + + 2. There is no right shift using multiplication. Modulo a prime, + multiplying with a power of the inverse of two allows to + right-shift numbers with enough zeros as least significant + bits. + + Taking this differences into account, `Dalskov et + al. `_ have adapted the + mask-and-reveal approach above to the setting of computation + modulo a power of two. + + +Mixed-Circuit Computation +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Another approach to non-linear computation is switching to binary +computation for parts of the computation. MP-SPDZ implements protocols +proposed for particular security models by a number of works: `Demmler et +al. `_, `Mohassel and Rindal +`_, and `Dalskov et +al. `_. MP-SPDZ also implements +more general methods such as `daBits +`_ and `edaBits +`_ + diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst new file mode 100644 index 000000000..3dadcfae6 --- /dev/null +++ b/doc/preprocessing.rst @@ -0,0 +1,29 @@ +Preprocessing +------------- + +Many protocols in MP-SPDZ use preprocessing, that is, producing secret +shares that are independent of the actual data but help with the +computation. Due to the independence, this can be done in batches to +save communication rounds and even communication when using +homomorphic encryption that works with large vectors such as LWE-based +encryption. + +Generally, preprocessing is done on demand and per computation +threads. On demand means that batches of preprocessing data are +computed whenever there is none in storage, and a computation thread +is a thread created by control flow instructions such as +:py:func:`~Compiler.library.for_range_multithread`. + +The exceptions to the general rule are edaBit generation with +malicious security and AND triples with malicious security and honest +majority, both when use bucket size three. Bucket size three implies +batches of over a million to achieve 40-bit statistical security, and +in honest-majority binary computation the item size is 64, which makes +the actual batch size 64 million triples. In multithreaded programs, +the preprocessing is run centrally using the threads as helpers. + +The batching means that the cost in terms of time and communication +jump whenever another batch is generated. Note that, while some +protocols are flexible with the batch size and can thus be controlled +using ``-b``, others mandate a batch size, which can be as large as a +million. diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst new file mode 100644 index 000000000..10b5fb4d7 --- /dev/null +++ b/doc/troubleshooting.rst @@ -0,0 +1,100 @@ +Troubleshooting +--------------- + +This section shows how to solve some common issues. + + +Crash without error message or ``bad_alloc`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Some protocols require several gigabytes of memory, and the virtual +machine will crash if there is not enough RAM. You can reduce the +memory usage for some malicious protocols with ``-B 4``. The memory +usage for malicious protocols based on homomorphic encryption can also +be reduced by using ``-T``. Finally, every computation thread requires +separate resources, so consider reducing the number of threads with +:py:func:`~Compiler.library.for_range_multithreads` and similar. + + +List indices must be integers or slices +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You cannot access Python lists with runtime variables because the +lists only exists at compile time. Consider using +:py:class:`~Compiler.types.Array`. + + +``compile.py`` takes too long +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you Python loops (``for``), the are unrolled at compile-time, +resulting in potentially too much virtual machine code. Consider using +:py:func:`~Compiler.library.for_range` or similar. + + +Order of memory instructions not preserved +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +By default, the compiler runs optimizations that in some corner case +can introduce errors with memory accesses such as accessing an +:py:func:`~Compiler.types.Array`. If you encounter such errors, you +can fix this either with ``-M`` when compiling or placing +`break_point()` around memory accesses. + + +Odd timings +~~~~~~~~~~~ + +Many protocols use preprocessing, which means they execute expensive +computation to generates batches of information that can be used for +computation until the information is used up. An effect of this is +that computation can seem oddly slow or fast. For example, one +multiplication has a similar cost then some thousand multiplications +when using homomorphic encryption because one batch contains +information for more than than 10,000 multiplications. Only when a +second batch is necessary the cost shoots up. + + +Handshake failures +~~~~~~~~~~~~~~~~~~ + +If you run on different hosts, the certificates +(``Player-Data/*.pem``) must be the same on all of them. Also make +sure to run ``c_rehash Player-Data`` on all hosts. Finally, the +certificate generated by ``Scripts/setup-ssl.sh`` expire after a +month, so you might to regenerate them. + + +Not compiled for choice of parameters +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +HighGear and LowGear only support a limited choice of parameters +because they to be chosen when compiling the binaries. You can follow +the instructions in error message and recompile the binaries in order +fix this. + + +Illegal instruction +~~~~~~~~~~~~~~~~~~~ + +By default, the binaries are optimized for the machine they are +compiled on. If you try to run them an another one, make sure set +``ARCH`` in ``CONFIG`` accordingly. Furthermore, if you run on an x86 +processor without AVX (produced before 2011), you need to set +``AVX_OT = 0`` to run dishonest-majority protocols. + + +Computation used more preprocessing than expected +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This indicates an error in the internal accounting of +preprocessing. Please file a bug report. + + +``mac_fail`` +~~~~~~~~~~~~ + +This is a catch-all failure in protocols with malicious protocols that +can be caused by something being wrong at any level. Please file a bug +report with the specifics of your case. +