From e07d9bf2a3231fe95557106371ce25f3da32f5d6 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 11 Jan 2022 16:04:59 +1100 Subject: [PATCH] Maintenance. --- .gitmodules | 2 +- BMR/Party.cpp | 2 +- BMR/RealGarbleWire.hpp | 2 +- BMR/RealProgramParty.hpp | 8 +- BMR/Register.h | 3 + BMR/TrustedParty.cpp | 6 + BMR/TrustedParty.h | 3 +- CHANGELOG.md | 13 ++ Compiler/GC/types.py | 10 +- Compiler/comparison.py | 11 +- Compiler/compilerLib.py | 2 +- Compiler/exceptions.py | 5 +- Compiler/floatingpoint.py | 14 +- Compiler/instructions.py | 17 +- Compiler/instructions_base.py | 83 ++++++++- Compiler/library.py | 23 ++- Compiler/ml.py | 12 +- Compiler/non_linear.py | 43 +++-- Compiler/program.py | 28 ++- Compiler/types.py | 50 +++-- ECDSA/hm-ecdsa-party.hpp | 4 +- ECDSA/mascot-ecdsa-party.cpp | 2 + ECDSA/ot-ecdsa-party.hpp | 4 +- ECDSA/preprocessing.hpp | 10 +- ECDSA/sign.hpp | 20 +- ExternalIO/Client.h | 25 +++ ExternalIO/README.md | 59 ++---- FHE/FHE_Params.h | 3 - FHE/NTL-Subs.cpp | 7 +- FHE/NTL-Subs.h | 5 +- FHE/NoiseBounds.cpp | 1 - FHE/Ring_Element.cpp | 22 ++- FHE/Ring_Element.h | 1 + FHEOffline/PairwiseGenerator.cpp | 2 +- FHEOffline/SimpleGenerator.h | 2 +- GC/BitAdder.hpp | 2 +- GC/CcdPrep.h | 5 - GC/CcdPrep.hpp | 8 + GC/CcdShare.h | 1 + GC/FakeSecret.h | 3 + GC/Instruction.cpp | 2 +- GC/NoShare.h | 6 +- GC/PostSacriBin.cpp | 16 +- GC/PostSacriBin.h | 5 +- GC/RepPrep.hpp | 5 + GC/Secret.h | 3 + GC/SemiPrep.cpp | 13 +- GC/SemiPrep.h | 4 +- GC/ShareSecret.h | 2 + GC/ShareSecret.hpp | 10 +- GC/ShareThread.h | 3 - GC/ShareThread.hpp | 3 +- GC/Thread.h | 2 - GC/Thread.hpp | 7 - GC/ThreadMaster.hpp | 4 +- GC/TinierSharePrep.h | 2 - GC/TinierSharePrep.hpp | 18 +- GC/TinyPrep.hpp | 2 +- GC/VectorInput.h | 6 + GC/VectorProtocol.h | 7 +- GC/VectorProtocol.hpp | 18 +- GC/instructions.h | 2 +- License.txt | 2 +- Machines/Atlas.hpp | 16 ++ Machines/Rep.hpp | 1 + Machines/Rep4.hpp | 17 ++ Machines/SPDZ.hpp | 12 +- Machines/SPDZ2k.hpp | 11 +- Machines/Semi.hpp | 1 + Machines/Semi2k.hpp | 15 ++ Machines/ShamirMachine.hpp | 1 + Machines/Tinier.cpp | 23 +++ Machines/atlas-party.cpp | 7 +- Machines/emulate.cpp | 8 +- Machines/hemi-party.cpp | 1 + Machines/no-party.cpp | 1 + Machines/soho-party.cpp | 1 + Makefile | 28 +-- Math/BitVec.h | 20 +- Math/Setup.hpp | 5 +- Math/ValueInterface.h | 1 + Math/Z2k.h | 3 + Math/Zp_Data.cpp | 36 ++++ Math/Zp_Data.h | 33 +--- Math/gfp.h | 2 +- Networking/CryptoPlayer.cpp | 5 + Networking/Player.cpp | 49 +++-- Networking/Player.h | 17 +- Networking/Receiver.cpp | 8 + Networking/Sender.cpp | 12 +- Networking/Server.cpp | 46 +++-- Networking/Server.h | 7 +- Networking/ssl_sockets.h | 13 ++ OT/BaseOT.cpp | 18 ++ OT/NPartyTripleGenerator.h | 13 +- Processor/BaseMachine.cpp | 34 +++- Processor/BaseMachine.h | 15 +- Processor/Binary_File_IO.hpp | 2 +- Processor/Data_Files.h | 37 ++-- Processor/Data_Files.hpp | 18 +- Processor/DummyProtocol.h | 4 +- Processor/FieldMachine.h | 5 +- Processor/FieldMachine.hpp | 3 +- Processor/HonestMajorityMachine.cpp | 2 +- Processor/Input.h | 5 +- Processor/Input.hpp | 10 +- Processor/Instruction.hpp | 2 + Processor/Machine.h | 1 - Processor/Machine.hpp | 16 +- Processor/Memory.h | 3 + Processor/NoFilePrep.h | 22 +++ Processor/OfflineMachine.hpp | 6 +- Processor/Online-Thread.hpp | 15 +- Processor/OnlineOptions.cpp | 14 ++ Processor/OnlineOptions.h | 6 + Processor/OnlineOptions.hpp | 30 +++ Processor/PrepBase.cpp | 20 +- Processor/PrepBase.h | 6 +- Processor/Processor.h | 4 - Processor/Processor.hpp | 50 +---- Processor/RingMachine.h | 2 +- Processor/RingMachine.hpp | 3 +- Processor/ThreadQueue.cpp | 21 +++ Processor/ThreadQueue.h | 5 + Processor/TruncPrTuple.h | 37 +++- Programs/Source/keras_mnist_lenet_predict.mpc | 44 +++++ Protocols/Atlas.h | 11 +- Protocols/Atlas.hpp | 9 +- Protocols/Beaver.h | 11 +- Protocols/Beaver.hpp | 30 ++- Protocols/BrainShare.h | 2 + Protocols/FakeProtocol.h | 35 ++-- Protocols/FakeShare.h | 3 + Protocols/Hemi.h | 6 +- Protocols/Hemi.hpp | 22 ++- Protocols/HighGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.cpp | 2 +- Protocols/LowGearKeyGen.hpp | 2 +- Protocols/MAC_Check.hpp | 2 + Protocols/MalRepRingPrep.hpp | 41 ----- Protocols/MaliciousRep3Share.h | 1 + Protocols/MaliciousRepPO.h | 8 +- Protocols/MaliciousRepPO.hpp | 18 +- Protocols/MaliciousRepPrep.hpp | 5 +- Protocols/MamaPrep.hpp | 1 + Protocols/MascotPrep.h | 2 - Protocols/MascotPrep.hpp | 12 +- Protocols/NoProtocol.h | 4 +- Protocols/PostSacriRepRingShare.h | 2 + Protocols/PostSacrifice.h | 4 +- Protocols/PostSacrifice.hpp | 7 +- Protocols/ProtocolSet.h | 107 +++++++++++ Protocols/ProtocolSetup.h | 95 ++++++++++ Protocols/Rep3Share.h | 27 +++ Protocols/Rep3Share2k.h | 12 -- Protocols/Rep4.h | 12 +- Protocols/Rep4.hpp | 34 ++-- Protocols/Rep4Prep.hpp | 2 +- Protocols/Replicated.h | 20 +- Protocols/Replicated.hpp | 174 +++++++----------- Protocols/ReplicatedInput.h | 3 +- Protocols/ReplicatedInput.hpp | 2 +- Protocols/ReplicatedPO.h | 24 +++ Protocols/ReplicatedPO.hpp | 21 +++ Protocols/ReplicatedPrep.h | 25 ++- Protocols/ReplicatedPrep.hpp | 150 +++++++++------ Protocols/{Semi2k.h => Semi.h} | 23 ++- Protocols/Semi2kShare.h | 6 +- Protocols/SemiShare.h | 8 +- Protocols/Shamir.h | 16 +- Protocols/Shamir.hpp | 14 +- Protocols/ShuffleSacrifice.hpp | 16 +- Protocols/Spdz2kPrep.h | 3 - Protocols/Spdz2kPrep.hpp | 35 ++-- Protocols/SpdzWise.h | 11 +- Protocols/SpdzWise.hpp | 28 +-- Protocols/SpdzWiseInput.hpp | 3 +- Protocols/SpdzWisePrep.hpp | 13 +- Protocols/SpdzWiseRing.hpp | 2 +- Protocols/SquarePrep.h | 6 +- README.md | 6 +- Scripts/decompile.py | 16 ++ Scripts/memory-usage.py | 29 +++ Scripts/run-common.sh | 31 +--- Scripts/test_streaming.sh | 4 + Scripts/tldr.sh | 3 +- Tools/BitVector.cpp | 9 + Tools/BitVector.h | 9 +- Tools/Buffer.cpp | 13 +- Tools/Bundle.h | 2 +- Tools/TimerWithComm.cpp | 23 +++ Tools/TimerWithComm.h | 23 +++ Tools/benchmarking.cpp | 15 ++ Tools/benchmarking.h | 3 + Tools/octetStream.h | 4 +- Tools/random.cpp | 4 +- Tools/random.h | 2 + Utils/Fake-Offline.cpp | 2 +- Utils/binary-example.cpp | 140 ++++++++++++++ Utils/mixed-example.cpp | 137 ++++++++++++++ Utils/paper-example.cpp | 49 ++--- Utils/stream-fake-mascot-triples.cpp | 21 ++- Yao/YaoEvaluator.h | 3 - Yao/YaoGarbler.cpp | 5 - Yao/YaoGarbler.h | 2 - Yao/YaoWire.h | 4 + Yao/YaoWire.hpp | 20 ++ doc/Doxyfile | 2 +- doc/conf.py | 5 +- doc/index.rst | 14 +- doc/io.rst | 10 + doc/low-level.rst | 142 +++++--------- doc/networking.rst | 6 +- doc/non-linear.rst | 4 +- doc/preprocessing.rst | 64 ++++++- doc/troubleshooting.rst | 21 ++- 216 files changed, 2406 insertions(+), 1113 deletions(-) create mode 100644 Machines/Atlas.hpp create mode 100644 Machines/Rep4.hpp create mode 100644 Machines/Semi2k.hpp create mode 100644 Machines/Tinier.cpp create mode 100644 Processor/NoFilePrep.h create mode 100644 Processor/OnlineOptions.hpp create mode 100644 Programs/Source/keras_mnist_lenet_predict.mpc create mode 100644 Protocols/ProtocolSet.h create mode 100644 Protocols/ProtocolSetup.h create mode 100644 Protocols/ReplicatedPO.h create mode 100644 Protocols/ReplicatedPO.hpp rename Protocols/{Semi2k.h => Semi.h} (75%) create mode 100755 Scripts/decompile.py create mode 100755 Scripts/memory-usage.py create mode 100644 Tools/TimerWithComm.cpp create mode 100644 Tools/TimerWithComm.h create mode 100644 Tools/benchmarking.cpp create mode 100644 Utils/binary-example.cpp create mode 100644 Utils/mixed-example.cpp diff --git a/.gitmodules b/.gitmodules index 455a55143..32dca28be 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,7 +3,7 @@ url = https://github.com/mkskeller/SimpleOT [submodule "mpir"] path = mpir - url = git://github.com/wbhart/mpir.git + url = https://github.com/wbhart/mpir [submodule "Programs/Circuits"] path = Programs/Circuits url = https://github.com/mkskeller/bristol-fashion diff --git a/BMR/Party.cpp b/BMR/Party.cpp index 5ca1360ab..84ba909b3 100644 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -259,7 +259,7 @@ ProgramParty::~ProgramParty() reset(); if (P) { - cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl; + cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl; delete P; } delete[] eval_threads; diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp index 55adcbfb1..760a20b89 100644 --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -175,7 +175,7 @@ void GarbleInputter::exchange() assert(party.P != 0); assert(party.MC != 0); auto& protocol = party.shared_proc->protocol; - protocol.init_mul(party.shared_proc); + protocol.init_mul(); for (auto& tuple : tuples) protocol.prepare_mul(tuple.first->mask, T::constant(1, party.P->my_num(), party.mac_key) diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp index 0c97f9bd8..8e16c3077 100644 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -155,7 +155,7 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : while (next != GC::DONE_BREAK); MC->Check(*P); - data_sent = P->comm_stats.total_data() + prep->data_sent(); + data_sent = P->total_comm().sent; this->machine.write_memory(this->N.my_num()); } @@ -173,7 +173,8 @@ void RealProgramParty::garble() garble_jobs.clear(); garble_inputter->reset_all(*P); auto& protocol = *garble_protocol; - protocol.init_mul(shared_proc); + protocol.init(*prep, shared_proc->MC); + protocol.init_mul(); next = this->first_phase(program, garble_processor, this->garble_machine); @@ -181,7 +182,8 @@ void RealProgramParty::garble() protocol.exchange(); typename T::Protocol second_protocol(*P); - second_protocol.init_mul(shared_proc); + second_protocol.init(*prep, shared_proc->MC); + second_protocol.init_mul(); for (auto& job : garble_jobs) job.middle_round(*this, second_protocol); diff --git a/BMR/Register.h b/BMR/Register.h index d0a75e930..f348f7b7e 100644 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -293,6 +293,9 @@ class ProgramRegister : public Phase, public Register template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } // most BMR phases don't need actual input template diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp index 6bd1ba264..439bcfc73 100644 --- a/BMR/TrustedParty.cpp +++ b/BMR/TrustedParty.cpp @@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty() _received_gc_received = 0; n_received = 0; randomfd = open("/dev/urandom", O_RDONLY); + done_filling = false; +} + +BaseTrustedParty::~BaseTrustedParty() +{ + close(randomfd); } TrustedProgramParty::TrustedProgramParty(int argc, char** argv) : diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h index 24e8120de..260e7a516 100644 --- a/BMR/TrustedParty.h +++ b/BMR/TrustedParty.h @@ -20,7 +20,7 @@ class BaseTrustedParty : virtual public CommonFakeParty { vector msg_input_masks; BaseTrustedParty(); - virtual ~BaseTrustedParty() {} + virtual ~BaseTrustedParty(); /* From NodeUpdatable class */ virtual void NodeReady(); @@ -104,7 +104,6 @@ class TrustedProgramParty : public BaseTrustedParty { void add_all_keys(const Register& reg, bool external); }; - inline void BaseTrustedParty::add_keys(const Register& reg) { for(int p = 0; p < get_n_parties(); p++) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c9be9e5b..2b75d24f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. +## 0.2.9 (Jan 11, 2021) + +- Disassembler +- Run-time parameter for probabilistic truncation error +- Probabilistic truncation for some protocols computing modulo a prime +- Simplified C++ interface +- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757) +- More general scalar-vector multiplication +- Complete memory support for clear bits +- Extended clear bit functionality with Yao's garbled circuits +- Allow preprocessing information to be supplied via named pipes +- In-place operations for containers + ## 0.2.8 (Nov 4, 2021) - Tested on Apple laptop with ARM chip diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py index 5a65e73ac..53da15ba2 100644 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -112,10 +112,16 @@ def load_mem(cls, address, mem_type=None, size=None): return cls.load_dynamic_mem(address) else: for i in range(res.size): - cls.load_inst[util.is_constant(address)](res[i], address + i) + cls.mem_op(cls.load_inst, res[i], address + i) return res def store_in_mem(self, address): - self.store_inst[isinstance(address, int)](self, address) + self.mem_op(self.store_inst, self, address) + @staticmethod + def mem_op(inst, reg, address): + direct = isinstance(address, int) + if not direct: + address = regint.conv(address) + inst[direct](reg, address) @classmethod def new(cls, value=None, n=None): if util.is_constant(value): diff --git a/Compiler/comparison.py b/Compiler/comparison.py index f4cf89ad6..2f7ca81f5 100644 --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa): k: bit length of a """ + movs(s, program.non_linear.ltz(a, k, kappa)) + +def LtzRing(a, k): from .types import sint, _bitint from .GC.types import sbitvec if program.use_split(): summands = a.split_to_two_summands(k) carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] - movs(s, sint.conv(msb)) + return sint.conv(msb) return elif program.options.ring: from . import floatingpoint @@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa): a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) - movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))) - return - t = sint() - Trunc(t, a, k, k - 1, kappa, True) - subsfi(s, t, 0) + return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)) def LessThanZero(a, k, kappa): from . import types diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py index 64e76434c..b2898e21a 100644 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -82,7 +82,7 @@ def run(args, options): prog.finalize() if prog.req_num: - print('Program requires:') + print('Program requires at most:') for x in prog.req_num.pretty(): print(x) diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py index fd0265637..c68ecd317 100644 --- a/Compiler/exceptions.py +++ b/Compiler/exceptions.py @@ -12,4 +12,7 @@ class ArgumentError(CompilerError): """ Exception raised for errors in instruction argument parsing. """ def __init__(self, arg, msg): self.arg = arg - self.msg = msg \ No newline at end of file + self.msg = msg + +class VectorMismatch(CompilerError): + pass diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py index a15a62dd9..c596240b6 100644 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) - lts(d, c_dprime, r_prime, l, kappa) + d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa) if compute_modulo: b = c_dprime - r_prime + pow2m * d return b, pow2m @@ -629,12 +629,14 @@ def BITLT(a, b, bit_length): # - From the paper # Multiparty Computation for Interval, Equality, and Comparison without # Bit-Decomposition Protocol -def BitDecFull(a, maybe_mixed=False): +def BitDecFull(a, n_bits=None, maybe_mixed=False): from .library import get_program, do_while, if_, break_point from .types import sint, regint, longint, cint p = get_program().prime assert p bit_length = p.bit_length() + n_bits = n_bits or bit_length + assert n_bits <= bit_length logp = int(round(math.log(p, 2))) if abs(p - 2 ** logp) / p < 2 ** -get_program().security: # inspired by Rabbit (https://eprint.iacr.org/2021/119) @@ -677,12 +679,12 @@ def _(): czero = (c==0) q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t) fbar = [bbits[0].clear_type.conv(cint(x)) - for x in ((1< 0. + """ Crash runtime if the value in the register is not zero. :param: Crash condition (regint)""" code = base.opcodes['CRASH'] @@ -1275,7 +1275,7 @@ class prep(base.Instruction): field_type = 'modp' def add_usage(self, req_node): - req_node.increment((self.field_type, self.args[0]), 1) + req_node.increment((self.field_type, self.args[0]), self.get_size()) def has_var_args(self): return True @@ -2407,19 +2407,6 @@ def expand(self): subml(self.args[0], s[5], c[1]) -@base.gf2n -@base.vectorize -class lts(base.CISC): - """ Secret comparison $s_i = (s_j < s_k)$. """ - __slots__ = [] - arg_format = ['sw', 's', 's', 'int', 'int'] - - def expand(self): - from .types import sint - a = sint() - subs(a, self.args[1], self.args[2]) - comparison.LTZ(self.args[0], a, self.args[3], self.args[4]) - # placeholder for documentation class cisc: """ Meta instruction for emulation. This instruction is only generated diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 38fd97d29..fb2a67b89 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -4,6 +4,8 @@ import inspect import functools import copy +import sys +import struct from Compiler.exceptions import * from Compiler.config import * from Compiler import util @@ -299,11 +301,12 @@ def maybe_vectorized_instruction(*args, **kwargs): vectorized_name = 'v' + instruction.__name__ Vectorized_Instruction.__name__ = vectorized_name global_dict[vectorized_name] = Vectorized_Instruction + + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction - instruction.__doc__ = '' - # exclude GF(2^n) instructions from documentation - if instruction.code and instruction.code >> 8 == 1: - maybe_vectorized_instruction.__doc__ = '' + maybe_vectorized_instruction.arg_format = instruction.arg_format return maybe_vectorized_instruction @@ -389,8 +392,11 @@ def maybe_gf2n_instruction(*args, **kwargs): else: global_dict[GF2N_Instruction.__name__] = GF2N_Instruction + if 'sphinx.extension' in sys.modules: + return instruction + global_dict[instruction.__name__ + '_class'] = instruction_cls - instruction_cls.__doc__ = '' + maybe_gf2n_instruction.arg_format = instruction.arg_format return maybe_gf2n_instruction #return instruction @@ -661,6 +667,12 @@ def encode(cls, arg): assert arg.i >= 0 return int_to_bytes(arg.i) + def __init__(self, f): + self.i = struct.unpack('>I', f.read(4))[0] + + def __str__(self): + return self.reg_type + str(self.i) + class ClearModpAF(RegisterArgFormat): reg_type = RegType.ClearModp @@ -686,6 +698,12 @@ def check(cls, arg): def encode(cls, arg): return int_to_bytes(arg) + def __init__(self, f): + self.i = struct.unpack('>i', f.read(4))[0] + + def __str__(self): + return str(self.i) + class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): @@ -722,6 +740,13 @@ def check(cls, arg): def encode(cls, arg): return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg)) + def __init__(self, f): + tmp = f.read(16) + self.str = str(tmp[0:tmp.find(b'\0')], 'ascii') + + def __str__(self): + return self.str + ArgFormats = { 'c': ClearModpAF, 's': SecretModpAF, @@ -890,6 +915,54 @@ def __str__(self): def __repr__(self): return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' +class ParsedInstruction: + reverse_opcodes = {} + + def __init__(self, f): + cls = type(self) + from Compiler import instructions + from Compiler.GC import instructions as gc_inst + if not cls.reverse_opcodes: + for module in instructions, gc_inst: + for x, y in inspect.getmodule(module).__dict__.items(): + if inspect.isclass(y) and y.__name__[0] != 'v': + try: + cls.reverse_opcodes[y.code] = y + except AttributeError: + pass + read = lambda: struct.unpack('>I', f.read(4))[0] + full_code = read() + code = full_code % (1 << Instruction.code_length) + self.size = full_code >> Instruction.code_length + self.type = cls.reverse_opcodes[code] + t = self.type + name = t.__name__ + try: + n_args = len(t.arg_format) + self.var_args = False + except: + n_args = read() + self.var_args = True + try: + arg_format = iter(t.arg_format) + except: + if name == 'cisc': + arg_format = itertools.chain(['str'], itertools.repeat('int')) + else: + arg_format = itertools.repeat('int') + self.args = [ArgFormats[next(arg_format)](f) + for i in range(n_args)] + + def __str__(self): + name = self.type.__name__ + res = name + ' ' + if self.size > 1: + res = 'v' + res + str(self.size) + ', ' + if self.var_args: + res += str(len(self.args)) + ', ' + res += ', '.join(str(arg) for arg in self.args) + return res + class VarArgsInstruction(Instruction): def has_var_args(self): return True diff --git a/Compiler/library.py b/Compiler/library.py index 529608dc2..7bab1951a 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -219,6 +219,9 @@ def crash(condition=None): :param condition: crash if true (default: true) """ + if isinstance(condition, localint): + # allow crash on local values + condition = condition._v if condition == None: condition = regint(1) instructions.crash(regint.conv(condition)) @@ -1347,6 +1350,8 @@ def while_loop(loop_body, condition, arg, g=None): arg = regint(arg) def loop_fn(): result = loop_body(arg) + if isinstance(result, MemValue): + result = result.read() result.link(arg) cont = condition(result) return cont @@ -1531,6 +1536,8 @@ def decorator(body): def if_e(condition): """ Conditional execution with else block. + Use :py:class:`~Compiler.types.MemValue` to assign values that + live beyond. :param condition: regint/cint/int @@ -1538,12 +1545,13 @@ def if_e(condition): .. code:: + y = MemValue(0) @if_e(x > 0) def _(): - ... + y.write(1) @else_ def _(): - ... + y.write(0) """ try: condition = bool(condition) @@ -1647,11 +1655,18 @@ def get_player_id(): return res def listen_for_clients(port): - """ Listen for clients on specific port. """ + """ Listen for clients on specific port base. + + :param port: port base (int/regint/cint) + """ instructions.listen(regint.conv(port)) def accept_client_connection(port): - """ Listen for clients on specific port. """ + """ Accept client connection on specific port base. + + :param port: port base (int/regint/cint) + :returns: client id + """ res = regint() instructions.acceptclientconnection(res, regint.conv(port)) return res diff --git a/Compiler/ml.py b/Compiler/ml.py index 5ff1a3753..7e53a78f8 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1810,6 +1810,7 @@ def __init__(self, report_loss=None): self.print_loss_reduction = False self.i_epoch = MemValue(0) self.stopped_on_loss = MemValue(0) + self.stopped_on_low_loss = MemValue(0) @property def layers(self): @@ -1932,6 +1933,7 @@ def run(self, batch_size=None, stop_on_loss=0): """ Run training. :param batch_size: batch size (defaults to example size of first layer) + :param stop_on_loss: stop when loss falls below this (default: 0) """ if self.n_epochs == 0: return @@ -2013,6 +2015,7 @@ def _(j): if self.tol > 0: res *= (1 - (loss_sum >= 0) * \ (loss_sum < self.tol * n_per_epoch)).reveal() + self.stopped_on_low_loss.write(1 - res) return res def reveal_correctness(self, data, truth, batch_size): @@ -2138,6 +2141,7 @@ def _(): if depreciation: self.gamma.imul(depreciation) print_ln('reducing learning rate to %s', self.gamma) + return 1 - self.stopped_on_low_loss if 'model_output' in program.args: self.output_weights() @@ -2386,6 +2390,7 @@ def trainable_variables(self): return list(self.opt.thetas) def build(self, input_shape, batch_size=128): + data_input_shape = input_shape if self.opt != None and \ input_shape == self.opt.layers[0].X.sizes and \ batch_size <= self.batch_size and \ @@ -2458,9 +2463,10 @@ def build(self, input_shape, batch_size=128): else: raise Exception(layer[0] + ' not supported') if layers[-1].d_out == 1: - layers.append(Output(input_shape[0])) + layers.append(Output(data_input_shape[0])) else: - layers.append(MultiOutput(input_shape[0], layers[-1].d_out)) + layers.append( + MultiOutput(data_input_shape[0], layers[-1].d_out)) if self.optimizer[1]: raise Exception('use keyword arguments for optimizer') opt = self.optimizer[0] @@ -2504,7 +2510,7 @@ def fit(self, x, y, batch_size, epochs=1, validation_data=None): if x.total_size() != self.opt.layers[0].X.total_size(): raise Exception('sample data size mismatch') if y.total_size() != self.opt.layers[-1].Y.total_size(): - print (y, layers[-1].Y) + print (y, self.opt.layers[-1].Y) raise Exception('label size mismatch') if validation_data == None: validation_data = None, None diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py index 43e10c2e6..01cb4db58 100644 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -1,7 +1,7 @@ from .comparison import * from .floatingpoint import * from .types import * -from . import comparison +from . import comparison, program class NonLinear: kappa = None @@ -30,6 +30,15 @@ def mod2m(self, a, k, m, signed): def trunc_pr(self, a, k, m, signed=True): if isinstance(a, types.cint): return shift_two(a, m) + prog = program.Program.prog + if prog.use_trunc_pr: + if signed and prog.use_trunc_pr != -1: + a += (1 << (k - 1)) + res = sint() + trunc_pr(res, a, k, m) + if signed and prog.use_trunc_pr != -1: + res -= (1 << (k - m - 1)) + return res return self._trunc_pr(a, k, m, signed) def trunc_round_nearest(self, a, k, m, signed): @@ -44,6 +53,9 @@ def trunc(self, a, k, m, kappa, signed): return a return self._trunc(a, k, m, signed) + def ltz(self, a, k, kappa=None): + return -self.trunc(a, k, k - 1, kappa, True) + class Masking(NonLinear): def eqz(self, a, k): c, r = self._mask(a, k) @@ -100,42 +112,44 @@ def __init__(self, prime): def _mod2m(self, a, k, m, signed): if signed: a += cint(1) << (k - 1) - return sint.bit_compose(self.bit_dec(a, k, k, True)[:m]) + return sint.bit_compose(self.bit_dec(a, k, m, True)) def _trunc_pr(self, a, k, m, signed): # nearest truncation return self.trunc_round_nearest(a, k, m, signed) def _trunc(self, a, k, m, signed=None): - if signed: - a += cint(1) << (k - 1) - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) - if signed: - res -= cint(1) << (k - 1 - m) - return res + return TruncZeros(a - self._mod2m(a, k, m, signed), k, m, signed) def trunc_round_nearest(self, a, k, m, signed): a += cint(1) << (m - 1) if signed: a += cint(1) << (k - 1) k += 1 - res = sint.bit_compose(self.bit_dec(a, k, k, True)[m:]) + res = self._trunc(a, k, m, False) if signed: res -= cint(1) << (k - m - 2) return res def bit_dec(self, a, k, m, maybe_mixed=False): assert k < self.prime.bit_length() - bits = BitDecFull(a, maybe_mixed=maybe_mixed) - if len(bits) < m: - raise CompilerError('%d has fewer than %d bits' % (self.prime, m)) - return bits[:m] + bits = BitDecFull(a, m, maybe_mixed=maybe_mixed) + assert len(bits) == m + return bits def eqz(self, a, k): # always signed a += two_power(k) return 1 - types.sintbit.conv(KORL(self.bit_dec(a, k, k, True))) + def ltz(self, a, k, kappa=None): + if k + 1 < self.prime.bit_length(): + # https://dl.acm.org/doi/10.1145/3474123.3486757 + # "negative" values wrap around when doubling, thus becoming odd + return self.mod2m(2 * a, k + 1, 1, False) + else: + return super(KnownPrime, self).ltz(a, k, kappa) + class Ring(Masking): """ Non-linear functionality modulo a power of two known at compile time. """ @@ -172,3 +186,6 @@ def trunc_round_nearest(self, a, k, m, signed): return TruncRing(None, tmp + 1, k - m + 1, 1, signed) else: return super(Ring, self).trunc_round_nearest(a, k, m, signed) + + def ltz(self, a, k, kappa=None): + return LtzRing(a, k) diff --git a/Compiler/program.py b/Compiler/program.py index 19ce52480..5dad8e516 100644 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -578,6 +578,15 @@ def disable_memory_warnings(self): self.warn_about_mem.append(False) self.curr_block.warn_about_mem = False + @staticmethod + def read_tapes(schedule): + if not os.path.exists(schedule): + schedule = 'Programs/Schedules/%s.sch' % schedule + + lines = open(schedule).readlines() + for tapename in lines[2].split(' '): + yield tapename.strip() + class Tape: """ A tape contains a list of basic blocks, onto which instructions are added. """ def __init__(self, name, program): @@ -1109,7 +1118,20 @@ def require_bit_length(self, bit_length, t='p'): else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) - class Register(object): + @staticmethod + def read_instructions(tapename): + tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + while tape.peek(): + yield inst_base.ParsedInstruction(tape) + + class _no_truth(object): + __slots__ = [] + + def __bool__(self): + raise CompilerError('Cannot derive truth value from register, ' + "consider using 'compile.py -l'") + + class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. @@ -1233,10 +1255,6 @@ def is_clear(self): self.reg_type == RegType.ClearGF2N or \ self.reg_type == RegType.ClearInt - def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") - def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/types.py b/Compiler/types.py index 48bb27a1d..33df2e373 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,7 +127,7 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise CompilerError('Different vector sizes of operands: %d/%d' + raise VectorMismatch('Different vector sizes of operands: %d/%d' % (self.size, args[0].size)) set_global_vector_size(self.size) try: @@ -221,7 +221,7 @@ def inputmixed(*args): else: instructions.inputmixedreg(*(args[:-1] + (regint.conv(args[-1]),))) -class _number(object): +class _number(Tape._no_truth): """ Number functionality. """ def square(self): @@ -246,7 +246,11 @@ def __mul__(self, other): elif is_one(other): return self else: - return self.mul(other) + try: + return self.mul(other) + except VectorMismatch: + # try reverse multiplication + return NotImplemented __radd__ = __add__ __rmul__ = __mul__ @@ -320,7 +324,7 @@ def __abs__(self): def popcnt_bits(bits): return sum(bits) -class _int(object): +class _int(Tape._no_truth): """ Integer functionality. """ @staticmethod @@ -408,7 +412,7 @@ def half_adder(self, other): def long_one(): return 1 -class _bit(object): +class _bit(Tape._no_truth): """ Binary functionality. """ def bit_xor(self, other): @@ -474,7 +478,7 @@ def bit_xor(self, other): def bit_not(self): return self ^ 1 -class _structure(object): +class _structure(Tape._no_truth): """ Interface for type-dependent container types. """ MemValue = classmethod(lambda cls, value: MemValue(cls.conv(value))) @@ -591,7 +595,7 @@ def traverse(content, level): res.input_from(player) return res -class _vec(object): +class _vec(Tape._no_truth): def link(self, other): assert len(self.v) == len(other.v) for x, y in zip(self.v, other.v): @@ -726,7 +730,7 @@ def expand_to_vector(self, size=None): assert self.size == 1 res = type(self)(size=size) for i in range(size): - movs(res[i], self) + self.mov(res[i], self) return res class _clear(_register): @@ -1010,9 +1014,10 @@ def less_than(self, other, bit_length): if bit_length <= 64: return regint(self) < regint(other) else: + sint.require_bit_length(bit_length + 1) diff = self - other - diff += (1 << (bit_length - 1)) - shifted = diff >> (bit_length - 1) + diff += 1 << bit_length + shifted = diff >> bit_length res = 1 - regint(shifted & 1) return res @@ -1646,7 +1651,7 @@ def binary_output(self, player=None): player = -1 intoutput(player, self) -class localint(object): +class localint(Tape._no_truth): """ Local integer that must prevented from leaking into the secure computation. Uses regint internally. @@ -1669,7 +1674,7 @@ def output(self): __eq__ = lambda self, other: localint(self._v == other) __ne__ = lambda self, other: localint(self._v != other) -class personal(object): +class personal(Tape._no_truth): def __init__(self, player, value): assert value is not NotImplemented assert not isinstance(value, _secret) @@ -2003,9 +2008,11 @@ def mul(self, other): size or one size 1 for a value-vector multiplication. :param other: any compatible type """ - if isinstance(other, _secret) and (1 in (self.size, other.size)) \ + if isinstance(other, _register) and (1 in (self.size, other.size)) \ and (self.size, other.size) != (1, 1): x, y = (other, self) if self.size < other.size else (self, other) + if not isinstance(other, _secret): + return y.expand_to_vector(x.size) * x res = type(self)(size=x.size) mulrs(res, x, y) return res @@ -2221,11 +2228,13 @@ def get_raw_input_from(cls, player): @vectorized_classmethod def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType): """ Securely obtain shares of values input by a client. + This uses the triple-based input protocol introduced by + `Damgård et al. `_ :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) - + :returns: list of sint """ # send shares of a triple to client triples = list(itertools.chain(*(sint.get_random_triple() for i in range(n)))) @@ -2910,7 +2919,7 @@ def bit_decompose_embedding(self): sint.bit_type = sintbit sgf2n.bit_type = sgf2n -class _bitint(object): +class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False @@ -3521,6 +3530,7 @@ def from_int(cls, other): @classmethod def _new(cls, other, k=None, f=None): + assert not isinstance(other, (list, tuple)) res = cls(k=k, f=f) res.v = cint.conv(other) return res @@ -3567,6 +3577,8 @@ def __len__(self): return len(self.v) def __getitem__(self, index): + if isinstance(index, slice): + return [self._new(x, k=self.k, f=self.f) for x in self.v[index]] return self._new(self.v[index], k=self.k, f=self.f) @vectorize @@ -3608,7 +3620,6 @@ def add(self, other): else: return NotImplemented - @vectorize def mul(self, other): """ Clear fixed-point multiplication. @@ -4045,7 +4056,8 @@ def set_precision_from_args(cls, program, adapt_ring=False): 'for fixed-point computation') cls.round_nearest = True if adapt_ring and program.options.ring \ - and 'fix_ring' not in program.args: + and 'fix_ring' not in program.args \ + and 2 * cls.k > int(program.options.ring): need = 2 ** int(math.ceil(math.log(2 * cls.k, 2))) if need != int(program.options.ring): print('Changing computation modulus to 2^%d' % need) @@ -4489,7 +4501,7 @@ def for_mux(self, other): def __neg__(self): return self._new(-self.v + 2 * util.expand(self.Z, self.v.size)) -class _unreduced_squant(object): +class _unreduced_squant(Tape._no_truth): def __init__(self, v, params, res_params=None, n_summands=1): self.v = v self.params = params @@ -5011,7 +5023,7 @@ def reveal(self): :return: cfloat """ return cfloat(self.v.reveal(), self.p.reveal(), self.z.reveal(), self.s.reveal()) -class cfloat(object): +class cfloat(Tape._no_truth): """ Helper class for printing revealed sfloats. """ __slots__ = ['v', 'p', 'z', 's', 'nan'] diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp index a68f8e833..fc19e989b 100644 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -52,10 +52,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); pShare sk = typename T::Honest::Protocol(P).get_random(); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; DataPositions usage; diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp index 87573593b..920397cef 100644 --- a/ECDSA/mascot-ecdsa-party.cpp +++ b/ECDSA/mascot-ecdsa-party.cpp @@ -5,6 +5,8 @@ #define NO_MIXED_CIRCUITS +#define NO_SECURITY_CHECK + #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp index 58f35d4b0..569aa791f 100644 --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -113,10 +113,10 @@ void run(int argc, const char** argv) P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); sk_prep.get_two(DATA_INVERSE, sk, __); cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples; typename pShare::TriplePrep prep(0, usage); diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp index 334d5d1ba..0a5e0ab9c 100644 --- a/ECDSA/preprocessing.hpp +++ b/ECDSA/preprocessing.hpp @@ -41,8 +41,8 @@ void preprocessing(vector>& tuples, int buffer_size, timer.start(); Player& P = proc.P; auto& prep = proc.DataF; - size_t start = P.sent + prep.data_sent(); - auto stats = P.comm_stats + prep.comm_stats(); + size_t start = P.total_comm().sent; + auto stats = P.total_comm(); auto& extra_player = P; auto& protocol = proc.protocol; @@ -77,7 +77,7 @@ void preprocessing(vector>& tuples, int buffer_size, MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player); if (prep_mul) { - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < buffer_size; i++) protocol.prepare_mul(inv_ks[i], sk); protocol.start_exchange(); @@ -106,9 +106,9 @@ void preprocessing(vector>& tuples, int buffer_size, timer.stop(); cout << "Generated " << buffer_size << " tuples in " << timer.elapsed() << " seconds, throughput " << buffer_size / timer.elapsed() << ", " - << 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size + << 1e-3 * (P.total_comm().sent - start) / buffer_size << " kbytes per tuple" << endl; - (P.comm_stats + prep.comm_stats() - stats).print(true); + (P.total_comm() - stats).print(true); } template class T> diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp index 10991276a..5686349e2 100644 --- a/ECDSA/sign.hpp +++ b/ECDSA/sign.hpp @@ -61,8 +61,7 @@ EcSignature sign(const unsigned char* message, size_t length, (void) pk; Timer timer; timer.start(); - size_t start = P.sent; - auto stats = P.comm_stats; + auto stats = P.total_comm(); EcSignature signature; vector opened_R; if (opts.R_after_msg) @@ -71,7 +70,7 @@ EcSignature sign(const unsigned char* message, size_t length, auto& protocol = proc->protocol; if (proc) { - protocol.init_mul(proc); + protocol.init_mul(); protocol.prepare_mul(sk, tuple.a); protocol.start_exchange(); } @@ -91,9 +90,9 @@ EcSignature sign(const unsigned char* message, size_t length, auto rx = tuple.R.x(); signature.s = MC.open( tuple.a * hash_to_scalar(message, length) + prod * rx, P); + auto diff = (P.total_comm() - stats); cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending " - << (P.sent - start) << " bytes" << endl; - auto diff = (P.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(true); return signature; } @@ -139,11 +138,11 @@ void sign_benchmark(vector>& tuples, T sk, P.unchecked_broadcast(bundle); Timer timer; timer.start(); - auto stats = P.comm_stats; + auto stats = P.total_comm(); P256Element pk = MCc.open(sk, P); MCc.Check(P); cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl; - (P.comm_stats - stats).print(true); + (P.total_comm() - stats).print(true); for (size_t i = 0; i < min(10lu, tuples.size()); i++) { @@ -154,13 +153,12 @@ void sign_benchmark(vector>& tuples, T sk, Timer timer; timer.start(); auto& check_player = MCp.get_check_player(P); - auto stats = check_player.comm_stats; - auto start = check_player.sent; + auto stats = check_player.total_comm(); MCp.Check(P); MCc.Check(P); + auto diff = (check_player.total_comm() - stats); cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending " - << (check_player.sent - start) << " bytes" << endl; - auto diff = (check_player.comm_stats - stats); + << diff.sent << " bytes" << endl; diff.print(); } } diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h index 12ba1c938..5f8e76fd3 100644 --- a/ExternalIO/Client.h +++ b/ExternalIO/Client.h @@ -8,6 +8,9 @@ #include "Networking/ssl_sockets.h" +/** + * Client-side interface + */ class Client { vector plain_sockets; @@ -15,15 +18,37 @@ class Client ssl_service io_service; public: + /** + * Sockets for cleartext communication + */ vector sockets; + + /** + * Specification of computation domain + */ octetStream specification; + /** + * Start a new set of connections to computing parties. + * @param hostnames location of computing parties + * @param port_base port base + * @param my_client_id client identifier + */ Client(const vector& hostnames, int port_base, int my_client_id); ~Client(); + /** + * Securely input private values. + * @param values vector of integer-like values + */ template void send_private_inputs(const vector& values); + /** + * Securely receive output values. + * @param n number of values + * @returns vector of integer-like values + */ template vector receive_outputs(int n); }; diff --git a/ExternalIO/README.md b/ExternalIO/README.md index 36649f5c7..d4f99288b 100644 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -19,6 +19,8 @@ Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 ``` +`` can be any arithmetic protocol (e.g., `mascot`) but not a +binary protocol (e.g., `yao`). This should output that the winning id is 1. Note that the ids have to be incremental, and the client with the highest id has to input 1 as the last argument while the others have to input 0 there. Furthermore, @@ -32,54 +34,21 @@ different hosts, you will have to distribute the `*.pem` files. ### Connection Setup -**listen**(*int port_num*) - -Setup a socket server to listen for client connections. Runs in own thread so once created clients will be able to connect in the background. - -*port_num* - the port number to listen on. - -**acceptclientconnection**(*regint client_socket_id*, *int port_num*) - -Picks the first available client socket connection. Blocks if none available. - -*client_socket_id* - an identifier used to refer to the client socket. - -*port_num* - the port number identifies the socket server to accept connections on. +1. [Listen for clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.listen_for_clients) +2. [Accept client connections](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.library.accept_client_connection) +3. [Close client connections](https://mp-spdz.readthedocs.io/en/latest/instructions.html#Compiler.instructions.closeclientconnection) ### Data Exchange -Only the sint methods are documented here, equivalent methods are available for the other data types **cfix**, **cint** and **regint**. See implementation details in [types.py](../Compiler/types.py). - -*[sint inputs]* **sint.read_from_socket**(*regint client_socket_id*, *int number_of_inputs*) - -Read a share of an input from a client, blocking on the client send. - -*client_socket_id* - an identifier used to refer to the client socket. - -*number_of_inputs* - the number of inputs expected - -*[inputs]* - returned list of shares of private input. - -**sint.write_to_socket**(*regint client_socket_id*, *[sint values]*, *int message_type*) - -Write shares of values including macs to an external client. - -*client_socket_id* - an identifier used to refer to the client socket. - -*[values]* - list of shares of values to send to client. - -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. - -See also sint.write_shares_to_socket where macs can be explicitly included or excluded from the message. - -*[sint inputs]* **sint.receive_from_client**(*int number_of_inputs*, *regint client_socket_id*, *int message_type*) - -Receive shares of private inputs from a client, blocking on client send. This is an abstraction which first sends shares of random values to the client and then receives masked input from the client, using the input protocol introduced in [Confidential Benchmarking based on Multiparty Computation. Damgard et al.](http://eprint.iacr.org/2015/1006.pdf) - -*number_of_inputs* - the number of inputs expected +Only the `sint` methods used in the example are documented here, equivalent methods are available for other data types. See [the reference](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#module-Compiler.types). -*client_socket_id* - an identifier used to refer to the client socket. +1. [Public value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.regint.read_from_socket) +2. [Secret value from client](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.receive_from_client) +3. [Reveal secret value to clients](https://mp-spdz.readthedocs.io/en/latest/Compiler.html#Compiler.types.sint.reveal_to_clients) -*message_type* - optional integer which will be sent in first 4 bytes of message, to indicate message type to client. +## Client-Side Interface -*[inputs]* - returned list of shares of private input. +The example uses the `Client` class implemented in +`ExternalIO/Client.hpp` to handle the communication, see +https://mp-spdz.readthedocs.io/en/latest/io.html#reference for +documentation. diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h index ac56668a2..8ac400839 100644 --- a/FHE/FHE_Params.h +++ b/FHE/FHE_Params.h @@ -33,9 +33,6 @@ class FHE_Params int n_mults() const { return FFTData.size() - 1; } - // Rely on default copy assignment/constructor (not that they should - // ever be needed) - void set(const Ring& R,const vector& primes); void set(const vector& primes); void set_sec(int sec); diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp index cb5daa386..c6e294a63 100644 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -178,12 +178,6 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, return extra_slack; } - - - -/****************************************************************************** - * Here onwards needs NTL - ******************************************************************************/ @@ -345,6 +339,7 @@ ZZX Cyclotomic(int N) return F; } #else +// simplified version powers of two int phi_N(int N) { if (((N - 1) & N) != 0) diff --git a/FHE/NTL-Subs.h b/FHE/NTL-Subs.h index ab150d272..c0a2ecfea 100644 --- a/FHE/NTL-Subs.h +++ b/FHE/NTL-Subs.h @@ -1,8 +1,6 @@ #ifndef _NTL_Subs #define _NTL_Subs -/* All these routines use NTL on the inside */ - #include "FHE/Ring.h" #include "FHE/FFT_Data.h" #include "FHE/P2Data.h" @@ -47,7 +45,7 @@ class Parameters }; -// Main setup routine (need NTL if online_only is false) +// Main setup routine void generate_setup(int nparties, int lgp, int lg2, int sec, bool skip_2 = false, int slack = 0, bool round_up = false); @@ -60,7 +58,6 @@ int generate_semi_setup(int plaintext_length, int sec, int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1, bool round_up); -// Everything else needs NTL void init(Ring& Rg, int m, bool generate_poly); void init(P2Data& P2D,const Ring& Rg); diff --git a/FHE/NoiseBounds.cpp b/FHE/NoiseBounds.cpp index ae52fc62f..7ab8e5172 100644 --- a/FHE/NoiseBounds.cpp +++ b/FHE/NoiseBounds.cpp @@ -114,7 +114,6 @@ NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack, cout << "n: " << n << endl; cout << "sec: " << sec << endl; cout << "sigma: " << this->sigma << endl; - cout << "h: " << h << endl; cout << "B_clean size: " << numBits(B_clean) << endl; cout << "B_scale size: " << numBits(B_scale) << endl; cout << "B_KS size: " << numBits(B_KS) << endl; diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp index 9c2545ed8..812560a3a 100644 --- a/FHE/Ring_Element.cpp +++ b/FHE/Ring_Element.cpp @@ -401,19 +401,29 @@ void Ring_Element::change_rep(RepType r) bool Ring_Element::equals(const Ring_Element& a) const { - if (element.empty() and a.element.empty()) - return true; - else if (element.empty() or a.element.empty()) - throw not_implemented(); - if (rep!=a.rep) { throw rep_mismatch(); } if (*FFTD!=*a.FFTD) { throw pr_mismatch(); } + + if (is_zero() or a.is_zero()) + return is_zero() and a.is_zero(); + for (int i=0; i<(*FFTD).phi_m(); i++) { if (!areEqual(element[i],a.element[i],(*FFTD).get_prD())) { return false; } } return true; } +bool Ring_Element::is_zero() const +{ + if (element.empty()) + return true; + for (auto& x : element) + if (not ::isZero(x, FFTD->get_prD())) + return false; + return true; +} + + ConversionIterator Ring_Element::get_iterator() const { if (rep != polynomial) @@ -560,6 +570,8 @@ void Ring_Element::check(const FFT_Data& FFTD) const { if (&FFTD != this->FFTD) throw params_mismatch(); + if (is_zero()) + throw runtime_error("element is zero"); } diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h index 5cc93ca9a..5982bbe32 100644 --- a/FHE/Ring_Element.h +++ b/FHE/Ring_Element.h @@ -95,6 +95,7 @@ class Ring_Element void randomize(PRNG& G,bool Diag=false); bool equals(const Ring_Element& a) const; + bool is_zero() const; // This is a NOP in cases where we cannot do a FFT void change_rep(RepType r); diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp index ed5fb303e..dcbd29b52 100644 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -175,7 +175,7 @@ size_t PairwiseGenerator::report_size(ReportType type) template size_t PairwiseGenerator::report_sent() { - return P.sent; + return P.total_comm().sent; } template diff --git a/FHEOffline/SimpleGenerator.h b/FHEOffline/SimpleGenerator.h index 9cacad697..d5ee933af 100644 --- a/FHEOffline/SimpleGenerator.h +++ b/FHEOffline/SimpleGenerator.h @@ -71,7 +71,7 @@ class SimpleGenerator : public GeneratorBase void run(bool exhaust); size_t report_size(ReportType type); void report_size(ReportType type, MemoryUsage& res); - size_t report_sent() { return P.sent; } + size_t report_sent() { return P.total_comm().sent; } }; #endif /* FHEOFFLINE_SIMPLEGENERATOR_H_ */ diff --git a/GC/BitAdder.hpp b/GC/BitAdder.hpp index 9f8525971..437af179a 100644 --- a/GC/BitAdder.hpp +++ b/GC/BitAdder.hpp @@ -96,7 +96,7 @@ void BitAdder::add(vector >& res, b[j] = summands[i][1][input_begin + j]; } - protocol.init_mul(&proc); + protocol.init_mul(); for (size_t j = 0; j < n_items; j++) { res[begin + j][i] = a[j] + b[j] + carries[j]; diff --git a/GC/CcdPrep.h b/GC/CcdPrep.h index 8d232444c..ab02ea802 100644 --- a/GC/CcdPrep.h +++ b/GC/CcdPrep.h @@ -91,11 +91,6 @@ class CcdPrep : public BufferPrep (typename T::clear(tmp.get_bit(0)) << i); } } - - NamedCommStats comm_stats() - { - return part_prep.comm_stats(); - } }; } /* namespace GC */ diff --git a/GC/CcdPrep.hpp b/GC/CcdPrep.hpp index f9535350b..3124efc42 100644 --- a/GC/CcdPrep.hpp +++ b/GC/CcdPrep.hpp @@ -25,6 +25,14 @@ void CcdPrep::set_protocol(typename T::Protocol& protocol) { auto& thread = ShareThread::s(); assert(thread.MC); + + if (part_proc) + { + assert(&part_proc->MC == &thread.MC->get_part_MC()); + assert(&part_proc->P == &protocol.get_part().P); + return; + } + part_proc = new SubProcessor( thread.MC->get_part_MC(), part_prep, protocol.get_part().P); } diff --git a/GC/CcdShare.h b/GC/CcdShare.h index aececad0a..e890ce633 100644 --- a/GC/CcdShare.h +++ b/GC/CcdShare.h @@ -27,6 +27,7 @@ class CcdShare : public ShamirShare, public ShareSecret> typedef ShamirInput Input; typedef ShamirMC MAC_Check; + typedef Shamir Protocol; typedef This small_type; diff --git a/GC/FakeSecret.h b/GC/FakeSecret.h index 55c537de3..00e6c52c9 100644 --- a/GC/FakeSecret.h +++ b/GC/FakeSecret.h @@ -108,6 +108,9 @@ class FakeSecret : public ShareInterface, public BitVec template static void convcbit2s(GC::Processor&, const BaseInstruction&) { throw runtime_error("convcbit2s not implemented"); } + template + static void andm(GC::Processor&, const BaseInstruction&) + { throw runtime_error("andm not implemented"); } static FakeSecret input(GC::Processor& processor, const InputArgs& args); static FakeSecret input(int from, word input, int n_bits); diff --git a/GC/Instruction.cpp b/GC/Instruction.cpp index 3fe0cc588..6be1eb1ab 100644 --- a/GC/Instruction.cpp +++ b/GC/Instruction.cpp @@ -84,7 +84,7 @@ void Instruction::parse(istream& s, int pos) ostringstream os; os << "Code not defined for instruction " << showbase << hex << opcode << dec << endl; os << "This virtual machine executes binary circuits only." << endl; - os << "Try compiling with '-B' or use only sbit* types." << endl; + os << "Use 'compile.py -B'." << endl; throw Invalid_Instruction(os.str()); break; } diff --git a/GC/NoShare.h b/GC/NoShare.h index f60eccd75..c435ec3f8 100644 --- a/GC/NoShare.h +++ b/GC/NoShare.h @@ -7,6 +7,7 @@ #define GC_NOSHARE_H_ #include "Processor/DummyProtocol.h" +#include "Processor/Instruction.h" #include "Protocols/ShareInterface.h" class InputArgs; @@ -148,11 +149,14 @@ class NoShare : public ShareInterface static void trans(Processor&, Integer, const vector&) { fail(); } + static void andm(GC::Processor&, const BaseInstruction&) { fail(); } + static NoShare constant(const GC::Clear&, int, mac_key_type, int = -1) { fail(); return {}; } NoShare() {} - NoShare(int) { fail(); } + template + NoShare(T) { fail(); } void load_clear(Integer, Integer) { fail(); } void random_bit() { fail(); } diff --git a/GC/PostSacriBin.cpp b/GC/PostSacriBin.cpp index 81341cf08..742480600 100644 --- a/GC/PostSacriBin.cpp +++ b/GC/PostSacriBin.cpp @@ -9,6 +9,7 @@ #include "Protocols/Replicated.hpp" #include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "ShareSecret.hpp" namespace GC @@ -28,24 +29,19 @@ PostSacriBin::~PostSacriBin() } } -void PostSacriBin::init_mul(SubProcessor* proc) -{ - assert(proc != 0); - init_mul(proc->DataF, proc->MC); -} - -void PostSacriBin::init_mul(Preprocessing&, T::MC&) +void PostSacriBin::init_mul() { if ((int) inputs.size() >= OnlineOptions::singleton.batch_size) check(); honest.init_mul(); } -PostSacriBin::T::clear PostSacriBin::prepare_mul(const T& x, const T& y, int n) +void PostSacriBin::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; honest.prepare_mul(x, y, n); inputs.push_back({{x.mask(n), y.mask(n)}}); - return {}; } void PostSacriBin::exchange() @@ -55,6 +51,8 @@ void PostSacriBin::exchange() PostSacriBin::T PostSacriBin::finalize_mul(int n) { + if (n == -1) + n = T::default_length; auto res = honest.finalize_mul(n); outputs.push_back({res, n}); return res; diff --git a/GC/PostSacriBin.h b/GC/PostSacriBin.h index 50baa9c5d..8f1643a76 100644 --- a/GC/PostSacriBin.h +++ b/GC/PostSacriBin.h @@ -38,9 +38,8 @@ class PostSacriBin : public ReplicatedBase, PostSacriBin(Player& P); ~PostSacriBin(); - void init_mul(Preprocessing&, T::MC&); - void init_mul(SubProcessor* proc); - T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); diff --git a/GC/RepPrep.hpp b/GC/RepPrep.hpp index 1c91fd395..f83fbdaf4 100644 --- a/GC/RepPrep.hpp +++ b/GC/RepPrep.hpp @@ -3,6 +3,9 @@ * */ +#ifndef GC_REPPREP_HPP_ +#define GC_REPPREP_HPP_ + #include "RepPrep.h" #include "ShareThread.h" #include "Processor/OnlineOptions.h" @@ -98,3 +101,5 @@ void RepPrep::buffer_inputs(int player) } } /* namespace GC */ + +#endif diff --git a/GC/Secret.h b/GC/Secret.h index 14f6638af..c4b6e8eb1 100644 --- a/GC/Secret.h +++ b/GC/Secret.h @@ -126,6 +126,9 @@ class Secret template static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { T::convcbit2s(processor, instruction); } + template + static void andm(Processor& processor, const BaseInstruction& instruction) + { T::andm(processor, instruction); } Secret(); Secret(const Integer& x) { *this = x; } diff --git a/GC/SemiPrep.cpp b/GC/SemiPrep.cpp index 9fc3f4918..9eed3b316 100644 --- a/GC/SemiPrep.cpp +++ b/GC/SemiPrep.cpp @@ -24,12 +24,15 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) : void SemiPrep::set_protocol(Beaver& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } (void) protocol; params.set_passive(); triple_generator = new SemiSecret::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, {}, &protocol.P); triple_generator->multi_threaded = false; @@ -61,12 +64,4 @@ void SemiPrep::buffer_bits() } } -NamedCommStats SemiPrep::comm_stats() -{ - if (triple_generator) - return triple_generator->comm_stats(); - else - return {}; -} - } /* namespace GC */ diff --git a/GC/SemiPrep.h b/GC/SemiPrep.h index 97214c28d..737cfb986 100644 --- a/GC/SemiPrep.h +++ b/GC/SemiPrep.h @@ -44,6 +44,8 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer get_triple_no_count(int n_bits) { + if (n_bits == -1) + n_bits = SemiSecret::default_length; return ShiftableTripleBuffer::get_triple_no_count(n_bits); } @@ -51,8 +53,6 @@ class SemiPrep : public BufferPrep, ShiftableTripleBuffer static void convcbit2s(Processor& processor, const BaseInstruction& instruction) { processor.convcbit2s(instruction); } + static void andm(Processor& processor, const BaseInstruction& instruction) + { processor.andm(instruction); } static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } diff --git a/GC/ShareSecret.hpp b/GC/ShareSecret.hpp index 1a508828b..23c86cb28 100644 --- a/GC/ShareSecret.hpp +++ b/GC/ShareSecret.hpp @@ -47,7 +47,7 @@ void ShareSecret::invert(int n, const U& x) { U ones; ones.load_clear(64, -1); - static_cast(*this) = U(x ^ ones) & get_mask(n); + reinterpret_cast(*this) = U(x + ones) & get_mask(n); } template @@ -92,8 +92,12 @@ template void ShareSecret::store_clear_in_dynamic(Memory& mem, const vector& accesses) { + auto& thread = ShareThread::s(); + assert(thread.P); + assert(thread.MC); for (auto access : accesses) - mem[access.address] = access.value; + mem[access.address] = U::constant(access.value, thread.P->my_num(), + thread.MC->get_alphai()); } template @@ -330,7 +334,7 @@ void ShareSecret::random_bit() template U& GC::ShareSecret::operator=(const U& other) { - U& real_this = static_cast(*this); + U& real_this = reinterpret_cast(*this); real_this = other; return real_this; } diff --git a/GC/ShareThread.h b/GC/ShareThread.h index 5f995e808..42c5e3bd6 100644 --- a/GC/ShareThread.h +++ b/GC/ShareThread.h @@ -58,9 +58,6 @@ class StandaloneShareThread : public ShareThread, public Thread void pre_run(); void post_run() { ShareThread::post_run(); } - - NamedCommStats comm_stats() - { return Thread::comm_stats() + this->DataF.comm_stats(); } }; template diff --git a/GC/ShareThread.hpp b/GC/ShareThread.hpp index 14d496115..07085040b 100644 --- a/GC/ShareThread.hpp +++ b/GC/ShareThread.hpp @@ -63,6 +63,7 @@ void ShareThread::pre_run(Player& P, typename T::mac_key_type mac_key) protocol = new typename T::Protocol(*this->P); MC = this->new_mc(mac_key); DataF.set_protocol(*this->protocol); + this->protocol->init(DataF, *MC); } template @@ -85,7 +86,7 @@ void ShareThread::and_(Processor& processor, { auto& protocol = this->protocol; processor.check_args(args, 4); - protocol->init_mul(DataF, *this->MC); + protocol->init_mul(); T x_ext, y_ext; for (size_t i = 0; i < args.size(); i += 4) { diff --git a/GC/Thread.h b/GC/Thread.h index 659c070a0..6631ad723 100644 --- a/GC/Thread.h +++ b/GC/Thread.h @@ -55,8 +55,6 @@ class Thread void join_tape(); void finish(); - - virtual NamedCommStats comm_stats(); }; template diff --git a/GC/Thread.hpp b/GC/Thread.hpp index 5487c41b2..d0b515cbf 100644 --- a/GC/Thread.hpp +++ b/GC/Thread.hpp @@ -96,13 +96,6 @@ void Thread::finish() pthread_join(thread, 0); } -template -NamedCommStats Thread::comm_stats() -{ - assert(P); - return P->comm_stats; -} - } /* namespace GC */ diff --git a/GC/ThreadMaster.hpp b/GC/ThreadMaster.hpp index 060e9f118..c6c9dcaac 100644 --- a/GC/ThreadMaster.hpp +++ b/GC/ThreadMaster.hpp @@ -95,11 +95,11 @@ void ThreadMaster::run() post_run(); - NamedCommStats stats = P->comm_stats; + NamedCommStats stats = P->total_comm(); ExecutionStats exe_stats; for (auto thread : threads) { - stats += thread->P->comm_stats; + stats += thread->P->total_comm(); exe_stats += thread->processor.stats; delete thread; } diff --git a/GC/TinierSharePrep.h b/GC/TinierSharePrep.h index 34beaf6fb..4e316e38c 100644 --- a/GC/TinierSharePrep.h +++ b/GC/TinierSharePrep.h @@ -44,8 +44,6 @@ class TinierSharePrep : public PersonalPrep ~TinierSharePrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; } diff --git a/GC/TinierSharePrep.hpp b/GC/TinierSharePrep.hpp index 57e759b9c..e136ec446 100644 --- a/GC/TinierSharePrep.hpp +++ b/GC/TinierSharePrep.hpp @@ -8,7 +8,7 @@ #include "TinierSharePrep.h" -#include "PersonalPrep.hpp" +#include "PersonalPrep.h" namespace GC { @@ -39,14 +39,17 @@ template void TinierSharePrep::set_protocol(typename T::Protocol& protocol) { if (triple_generator) + { + assert(&triple_generator->get_player() == &protocol.P); return; + } params.generateMACs = true; params.amplify = false; params.check = false; auto& thread = ShareThread::s(); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), protocol.P.N, -1, + BaseMachine::fresh_ot_setup(protocol.P), protocol.P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &protocol.P); triple_generator->multi_threaded = false; @@ -84,17 +87,6 @@ void GC::TinierSharePrep::buffer_bits() BufferPrep::get_random_from_inputs(thread.P->num_players())); } -template -NamedCommStats TinierSharePrep::comm_stats() -{ - NamedCommStats res; - if (triple_generator) - res += triple_generator->comm_stats(); - if (real_triple_generator) - res += real_triple_generator->comm_stats(); - return res; -} - } #endif diff --git a/GC/TinyPrep.hpp b/GC/TinyPrep.hpp index 2b8a11b79..897b3b482 100644 --- a/GC/TinyPrep.hpp +++ b/GC/TinyPrep.hpp @@ -16,7 +16,7 @@ void TinierSharePrep::init_real(Player& P) assert(real_triple_generator == 0); auto& thread = ShareThread::s(); real_triple_generator = new typename T::whole_type::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), P.N, -1, + BaseMachine::fresh_ot_setup(P), P.N, -1, OnlineOptions::singleton.batch_size, 1, params, thread.MC->get_alphai(), &P); real_triple_generator->multi_threaded = false; diff --git a/GC/VectorInput.h b/GC/VectorInput.h index c17cd93d4..44c9591b9 100644 --- a/GC/VectorInput.h +++ b/GC/VectorInput.h @@ -36,6 +36,8 @@ class VectorInput : public InputBase void add_mine(const typename T::open_type& input, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_mine(input.get_bit(i)); input_lengths.push_back(n_bits); @@ -43,6 +45,8 @@ class VectorInput : public InputBase void add_other(int player, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; for (int i = 0; i < n_bits; i++) part_input.add_other(player); } @@ -69,6 +73,8 @@ class VectorInput : public InputBase void finalize_other(int player, T& target, octetStream&, int n_bits) { + if (n_bits == -1) + n_bits = T::default_length; target.resize_regs(n_bits); for (int i = 0; i < n_bits; i++) part_input.finalize_other(player, target.get_reg(i), diff --git a/GC/VectorProtocol.h b/GC/VectorProtocol.h index 3f7e203c5..94ef19893 100644 --- a/GC/VectorProtocol.h +++ b/GC/VectorProtocol.h @@ -21,9 +21,10 @@ class VectorProtocol : public ProtocolBase VectorProtocol(Player& P); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void finalize_mult(T& res, int n = -1); T finalize_mul(int n = -1); diff --git a/GC/VectorProtocol.hpp b/GC/VectorProtocol.hpp index cae461812..e72e0d148 100644 --- a/GC/VectorProtocol.hpp +++ b/GC/VectorProtocol.hpp @@ -18,26 +18,26 @@ VectorProtocol::VectorProtocol(Player& P) : } template -void VectorProtocol::init_mul(SubProcessor* proc) +void VectorProtocol::init(Preprocessing& prep, + typename T::MAC_Check& MC) { - assert(proc); - init_mul(proc->DataF, proc->MC); + part_protocol.init(prep.get_part(), MC.get_part_MC()); } template -void VectorProtocol::init_mul(Preprocessing& prep, - typename T::MAC_Check& MC) +void VectorProtocol::init_mul() { - part_protocol.init_mul(prep.get_part(), MC.get_part_MC()); + part_protocol.init_mul(); } template -typename T::clear VectorProtocol::prepare_mul(const T& x, +void VectorProtocol::prepare_mul(const T& x, const T& y, int n) { + if (n == -1) + n = T::default_length; for (int i = 0; i < n; i++) part_protocol.prepare_mul(x.get_reg(i), y.get_reg(i), 1); - return {}; } template @@ -57,6 +57,8 @@ T VectorProtocol::finalize_mul(int n) template void VectorProtocol::finalize_mult(T& res, int n) { + if (n == -1) + n = T::default_length; res.resize_regs(n); for (int i = 0; i < n; i++) res.get_reg(i) = part_protocol.finalize_mul(1); diff --git a/GC/instructions.h b/GC/instructions.h index fc278d441..66ae46d22 100644 --- a/GC/instructions.h +++ b/GC/instructions.h @@ -46,6 +46,7 @@ X(NOTCB, processor.notcb(INST)) \ X(ANDRS, T::andrs(PROC, EXTRA)) \ X(ANDS, T::ands(PROC, EXTRA)) \ + X(ANDM, T::andm(PROC, instruction)) \ X(ADDCB, C0 = PC1 + PC2) \ X(ADDCBI, C0 = PC1 + int(IMM)) \ X(MULCBI, C0 = PC1 * int(IMM)) \ @@ -76,7 +77,6 @@ #define COMBI_INSTRUCTIONS BIT_INSTRUCTIONS \ X(INPUTB, T::inputb(PROC, Proc, EXTRA)) \ X(INPUTBVEC, T::inputbvec(PROC, Proc, EXTRA)) \ - X(ANDM, processor.andm(instruction)) \ X(CONVSINT, S0.load_clear(IMM, Proc.read_Ci(REG1))) \ X(CONVCINT, C0 = Proc.read_Ci(REG1)) \ X(CONVCBIT, Proc.write_Ci(R0, PC1.get())) \ diff --git a/License.txt b/License.txt index 3a9eb2ae0..ccaafe01e 100644 --- a/License.txt +++ b/License.txt @@ -1,5 +1,5 @@ CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) -Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. +Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material. Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. diff --git a/Machines/Atlas.hpp b/Machines/Atlas.hpp new file mode 100644 index 000000000..045b69b9e --- /dev/null +++ b/Machines/Atlas.hpp @@ -0,0 +1,16 @@ +/* + * Atlas.hpp + * + */ + +#ifndef MACHINES_ATLAS_HPP_ +#define MACHINES_ATLAS_HPP_ + +#include "Protocols/AtlasShare.h" +#include "Protocols/AtlasPrep.h" +#include "GC/AtlasSecret.h" + +#include "ShamirMachine.hpp" +#include "Protocols/Atlas.hpp" + +#endif /* MACHINES_ATLAS_HPP_ */ diff --git a/Machines/Rep.hpp b/Machines/Rep.hpp index d37c385c5..a480860f8 100644 --- a/Machines/Rep.hpp +++ b/Machines/Rep.hpp @@ -4,6 +4,7 @@ */ #include "Protocols/MalRepRingPrep.h" +#include "Protocols/ReplicatedPrep2k.h" #include "Processor/Data_Files.hpp" #include "Processor/Instruction.hpp" diff --git a/Machines/Rep4.hpp b/Machines/Rep4.hpp new file mode 100644 index 000000000..83ad1cff5 --- /dev/null +++ b/Machines/Rep4.hpp @@ -0,0 +1,17 @@ +/* + * Rep4.hpp + * + */ + +#ifndef MACHINES_REP4_HPP_ +#define MACHINES_REP4_HPP_ + +#include "GC/Rep4Secret.h" +#include "Protocols/Rep4Share2k.h" +#include "Protocols/Rep4Prep.h" +#include "Protocols/Rep4.hpp" +#include "Protocols/Rep4MC.hpp" +#include "Protocols/Rep4Input.hpp" +#include "Protocols/Rep4Prep.hpp" + +#endif /* MACHINES_REP4_HPP_ */ diff --git a/Machines/SPDZ.hpp b/Machines/SPDZ.hpp index 02ad9b983..a221b087a 100644 --- a/Machines/SPDZ.hpp +++ b/Machines/SPDZ.hpp @@ -21,13 +21,15 @@ #include "GC/TinierSecret.h" #include "GC/TinyMC.h" #include "GC/VectorInput.h" +#include "GC/VectorProtocol.h" -#include "GC/ShareParty.hpp" +#include "GC/ShareParty.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/ShareSecret.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/ShareSecret.h" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" #include "Math/gfp.hpp" diff --git a/Machines/SPDZ2k.hpp b/Machines/SPDZ2k.hpp index 672a29b4e..6cb02779d 100644 --- a/Machines/SPDZ2k.hpp +++ b/Machines/SPDZ2k.hpp @@ -23,9 +23,10 @@ #include "Protocols/MascotPrep.hpp" #include "Protocols/Spdz2kPrep.hpp" -#include "GC/ShareParty.hpp" -#include "GC/ShareSecret.hpp" +#include "GC/ShareParty.h" +#include "GC/ShareSecret.h" #include "GC/Secret.hpp" -#include "GC/TinyPrep.hpp" -#include "GC/TinierSharePrep.hpp" -#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.h" +#include "GC/CcdPrep.h" + +#include "GC/VectorProtocol.hpp" diff --git a/Machines/Semi.hpp b/Machines/Semi.hpp index 36c9d8c50..1a0931467 100644 --- a/Machines/Semi.hpp +++ b/Machines/Semi.hpp @@ -18,3 +18,4 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" diff --git a/Machines/Semi2k.hpp b/Machines/Semi2k.hpp new file mode 100644 index 000000000..56f86d9ba --- /dev/null +++ b/Machines/Semi2k.hpp @@ -0,0 +1,15 @@ +/* + * Semi2.hpp + * + */ + +#ifndef MACHINES_SEMI2K_HPP_ +#define MACHINES_SEMI2K_HPP_ + +#include "Protocols/Semi2kShare.h" +#include "Protocols/SemiPrep2k.h" + +#include "Semi.hpp" +#include "Protocols/RepRingOnlyEdabitPrep.hpp" + +#endif /* MACHINES_SEMI2K_HPP_ */ diff --git a/Machines/ShamirMachine.hpp b/Machines/ShamirMachine.hpp index 080332aea..7697c5124 100644 --- a/Machines/ShamirMachine.hpp +++ b/Machines/ShamirMachine.hpp @@ -27,6 +27,7 @@ #include "Protocols/Beaver.hpp" #include "Protocols/Spdz2kPrep.hpp" #include "Protocols/ReplicatedPrep.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/VectorProtocol.hpp" #include "GC/Secret.hpp" diff --git a/Machines/Tinier.cpp b/Machines/Tinier.cpp new file mode 100644 index 000000000..99ad1c5c1 --- /dev/null +++ b/Machines/Tinier.cpp @@ -0,0 +1,23 @@ +/* + * Tinier.cpp + * + */ + +#include "GC/TinyMC.h" +#include "GC/TinierSecret.h" +#include "GC/VectorInput.h" + +#include "GC/ShareParty.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ShareSecret.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/PersonalPrep.hpp" + +//template class GC::ShareParty>; +template class GC::CcdPrep>; +template class Preprocessing>; +template class GC::TinierSharePrep>; +template class GC::ShareSecret>; +template class TripleShuffleSacrifice>; diff --git a/Machines/atlas-party.cpp b/Machines/atlas-party.cpp index 6e754c7ff..2df033e60 100644 --- a/Machines/atlas-party.cpp +++ b/Machines/atlas-party.cpp @@ -3,12 +3,7 @@ * */ -#include "Protocols/AtlasShare.h" -#include "Protocols/AtlasPrep.h" -#include "GC/AtlasSecret.h" - -#include "ShamirMachine.hpp" -#include "Protocols/Atlas.hpp" +#include "Atlas.hpp" int main(int argc, const char** argv) { diff --git a/Machines/emulate.cpp b/Machines/emulate.cpp index 8525b0671..f26f5f324 100644 --- a/Machines/emulate.cpp +++ b/Machines/emulate.cpp @@ -10,11 +10,13 @@ #include "Processor/RingOptions.h" #include "Processor/Machine.hpp" +#include "Processor/OnlineOptions.hpp" #include "Math/Z2k.hpp" #include "Protocols/Replicated.hpp" #include "Protocols/ShuffleSacrifice.hpp" #include "Protocols/ReplicatedPrep.hpp" #include "Protocols/FakeShare.hpp" +#include "Protocols/MalRepRingPrep.hpp" int main(int argc, const char** argv) { @@ -22,7 +24,7 @@ int main(int argc, const char** argv) Names N; ez::ezOptionParser opt; RingOptions ring_opts(opt, argc, argv); - online_opts = {opt, argc, argv}; + online_opts = {opt, argc, argv, FakeShare>()}; opt.parse(argc, argv); opt.syntax = string(argv[0]) + " "; @@ -44,9 +46,7 @@ int main(int argc, const char** argv) #ifdef ROUND_NEAREST_IN_EMULATION cerr << "Using nearest rounding instead of probabilistic truncation" << endl; #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - cerr << "Using risky truncation" << endl; -#endif + online_opts.set_trunc_error(opt); #endif int R = ring_opts.ring_size_from_opts_or_schedule(progname); diff --git a/Machines/hemi-party.cpp b/Machines/hemi-party.cpp index 471862dab..934c15dcd 100644 --- a/Machines/hemi-party.cpp +++ b/Machines/hemi-party.cpp @@ -24,6 +24,7 @@ #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" #include "Protocols/Hemi.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Machines/no-party.cpp b/Machines/no-party.cpp index 2120322f3..ce542de18 100644 --- a/Machines/no-party.cpp +++ b/Machines/no-party.cpp @@ -8,6 +8,7 @@ #include "Processor/OnlineMachine.hpp" #include "Processor/Machine.hpp" #include "Protocols/Replicated.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "Math/gfp.hpp" #include "Math/Z2k.hpp" diff --git a/Machines/soho-party.cpp b/Machines/soho-party.cpp index 6f7c70a3a..7ecc450da 100644 --- a/Machines/soho-party.cpp +++ b/Machines/soho-party.cpp @@ -22,6 +22,7 @@ #include "Protocols/MAC_Check.hpp" #include "Protocols/SemiMC.hpp" #include "Protocols/Beaver.hpp" +#include "Protocols/MalRepRingPrep.hpp" #include "GC/ShareSecret.hpp" #include "GC/SemiHonestRepPrep.h" #include "Math/gfp.hpp" diff --git a/Makefile b/Makefile index 9d634c0e9..e40528b8c 100644 --- a/Makefile +++ b/Makefile @@ -25,6 +25,8 @@ MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT) VM = $(MINI_OT) $(SHAREDLIB) COMMON = $(SHAREDLIB) +TINIER = Machines/Tinier.o $(OT) +SPDZ = Machines/SPDZ.o $(TINIER) LIB = libSPDZ.a @@ -117,7 +119,7 @@ sy: sy-rep-field-party.x sy-rep-ring-party.x sy-shamir-party.x ecdsa: $(patsubst ECDSA/%.cpp,%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) Fake-ECDSA.x ecdsa-static: static-dir $(patsubst ECDSA/%.cpp,static/%.x,$(wildcard ECDSA/*-ecdsa-party.cpp)) -$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(OT) $(GC) +$(LIBRELEASE): Protocols/MalRepRingOptions.o $(PROCESSOR) $(COMMONOBJS) $(TINIER) $(GC) $(AR) -csr $@ $^ CFLAGS += -fPIC @@ -203,16 +205,16 @@ ps-rep-bin-party.x: GC/PostSacriBin.o semi-bin-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o tiny-party.x: $(OT) tinier-party.x: $(OT) -spdz2k-party.x: $(OT) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) +spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp)) semi-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o semi2k-party.x: $(OT) GC/SemiSecret.o GC/SemiPrep.o GC/square64.o hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT) -cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(OT) -lowgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o -highgear-party.x: $(FHEOFFLINE) $(OT) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o +cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +chaigear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o $(TINIER) +lowgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o +highgear-party.x: $(FHEOFFLINE) $(TINIER) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o atlas-party.x: GC/AtlasSecret.o static/hemi-party.x: $(FHEOBJS) static/soho-party.x: $(FHEOBJS) @@ -220,10 +222,10 @@ static/cowgear-party.x: $(FHEOBJS) static/chaigear-party.x: $(FHEOBJS) static/lowgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/LowGearKeyGen.o static/highgear-party.x: $(FHEOBJS) Protocols/CowGearOptions.o Protocols/HighGearKeyGen.o -mascot-party.x: Machines/SPDZ.o $(OT) -static/mascot-party.x: Machines/SPDZ.o -Player-Online.x: Machines/SPDZ.o $(OT) -mama-party.x: $(OT) +mascot-party.x: $(SPDZ) +static/mascot-party.x: $(SPDZ) +Player-Online.x: $(SPDZ) +mama-party.x: $(TINIER) ps-rep-ring-party.x: Protocols/MalRepRingOptions.o malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o sy-rep-ring-party.x: Protocols/MalRepRingOptions.o @@ -236,8 +238,10 @@ emulate.x: GC/FakeSecret.o semi-bmr-party.x: GC/SemiPrep.o GC/SemiSecret.o $(OT) real-bmr-party.x: $(OT) paper-example.x: $(VM) $(OT) $(FHEOFFLINE) -mascot-offline.x: $(VM) $(OT) -cowgear-offline.x: $(OT) $(FHEOFFLINE) +binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o +mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/SemiSecret.o GC/AtlasSecret.o Machines/Tinier.o +mascot-offline.x: $(VM) $(TINIER) +cowgear-offline.x: $(TINIER) $(FHEOFFLINE) static/rep-bmr-party.x: $(BMR) static/mal-rep-bmr-party.x: $(BMR) static/shamir-bmr-party.x: $(BMR) diff --git a/Math/BitVec.h b/Math/BitVec.h index f9e874d14..f0d60a1b9 100644 --- a/Math/BitVec.h +++ b/Math/BitVec.h @@ -26,6 +26,7 @@ class BitVec_ : public IntBase static const false_type invertible; static const true_type characteristic_two; + static const true_type binary; static char type_char() { return 'B'; } static string type_short() { return "B"; } @@ -64,8 +65,21 @@ class BitVec_ : public IntBase void pack(octetStream& os) const { os.store_int(this->a); } void unpack(octetStream& os) { this->a = os.get_int(); } - void pack(octetStream& os, int n) const { os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); } - void unpack(octetStream& os, int n) { this->a = os.get_int(DIV_CEIL(n, 8)); } + void pack(octetStream& os, int n) const + { + if (n == -1) + pack(os); + else + os.store_int(super::mask(n).get(), DIV_CEIL(n, 8)); + } + + void unpack(octetStream& os, int n) + { + if (n == -1) + unpack(os); + else + this->a = os.get_int(DIV_CEIL(n, 8)); + } static BitVec_ unpack_new(octetStream& os, int n = n_bits) { @@ -81,5 +95,7 @@ template const false_type BitVec_::invertible; template const true_type BitVec_::characteristic_two; +template +const true_type BitVec_::binary; #endif /* MATH_BITVEC_H_ */ diff --git a/Math/Setup.hpp b/Math/Setup.hpp index 6545d67ec..91cafaea5 100644 --- a/Math/Setup.hpp +++ b/Math/Setup.hpp @@ -36,8 +36,9 @@ void read_setup(const string& dir_prefix, int lgp = -1) { if (lgp > 0) { - cerr << "No modulus found in " << filename << ", generating " << lgp - << "-bit prime" << endl; + if (OnlineOptions::singleton.verbose) + cerr << "No modulus found in " << filename << ", generating " + << lgp << "-bit prime" << endl; T::init_default(lgp); } else diff --git a/Math/ValueInterface.h b/Math/ValueInterface.h index d15af24c8..07807cb23 100644 --- a/Math/ValueInterface.h +++ b/Math/ValueInterface.h @@ -20,6 +20,7 @@ class ValueInterface static const false_type characteristic_two; static const false_type prime_field; static const false_type invertible; + static const false_type binary; template static void init(bool mont = true) { (void) mont; } diff --git a/Math/Z2k.h b/Math/Z2k.h index 3e6530442..ad32cbf16 100644 --- a/Math/Z2k.h +++ b/Math/Z2k.h @@ -47,6 +47,7 @@ class Z2 : public ValueInterface static int size_in_limbs() { return N_WORDS; } static int size_in_bits() { return size() * 8; } static int length() { return size_in_bits(); } + static int n_bits() { return N_BITS; } static int t() { return 0; } static char type_char() { return 'R'; } @@ -100,6 +101,8 @@ class Z2 : public ValueInterface int bit_length() const; + Z2 mask(int) const { return *this; } + Z2 operator+(const Z2& other) const; Z2 operator-(const Z2& other) const; diff --git a/Math/Zp_Data.cpp b/Math/Zp_Data.cpp index 63c279a26..17fcdf24c 100644 --- a/Math/Zp_Data.cpp +++ b/Math/Zp_Data.cpp @@ -86,6 +86,42 @@ void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t { inline_mpn_copyi(z,ans+t,t); } } +void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x, + const mp_limb_t* y) const +{ + switch (t) + { +#ifdef __BMI2__ +#define CASE(N) \ + case N: \ + Mont_Mult_(z, x, y); \ + break; + CASE(1) + CASE(2) +#if MAX_MOD_SZ >= 4 + CASE(3) + CASE(4) +#endif +#if MAX_MOD_SZ >= 5 + CASE(5) +#endif +#if MAX_MOD_SZ >= 6 + CASE(6) +#endif +#if MAX_MOD_SZ >= 10 + CASE(7) + CASE(8) + CASE(9) + CASE(10) +#endif +#undef CASE +#endif + default: + Mont_Mult_variable(z, x, y); + break; + } +} + ostream& operator<<(ostream& s,const Zp_Data& ZpD) diff --git a/Math/Zp_Data.h b/Math/Zp_Data.h index 96deb7951..f30e71037 100644 --- a/Math/Zp_Data.h +++ b/Math/Zp_Data.h @@ -40,6 +40,7 @@ class Zp_Data template void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; + void Mont_Mult_switch(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const; void Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y, int t) const; void Mont_Mult_variable(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const { Mont_Mult(z, x, y, t); } @@ -242,37 +243,11 @@ inline void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* { if (not cpu_has_bmi2()) return Mont_Mult_variable(z, x, y); - switch (t) - { #ifdef __BMI2__ -#define CASE(N) \ - case N: \ - Mont_Mult_(z, x, y); \ - break; - CASE(1) - CASE(2) -#if MAX_MOD_SZ >= 4 - CASE(3) - CASE(4) -#endif -#if MAX_MOD_SZ >= 5 - CASE(5) -#endif -#if MAX_MOD_SZ >= 6 - CASE(6) -#endif -#if MAX_MOD_SZ >= 10 - CASE(7) - CASE(8) - CASE(9) - CASE(10) -#endif -#undef CASE + return Mont_Mult_switch(z, x, y); +#else + return Mont_Mult_variable(z, x, y); #endif - default: - Mont_Mult_variable(z, x, y); - break; - } } inline void Zp_Data::Mont_Mult_max(mp_limb_t* z, const mp_limb_t* x, diff --git a/Math/gfp.h b/Math/gfp.h index 7b257b5fa..bde43025e 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -11,7 +11,6 @@ using namespace std; #include "Math/Bit.h" #include "Math/Setup.h" #include "Tools/random.h" -#include "GC/NoShare.h" #include "Processor/OnlineOptions.h" #include "Math/modp.hpp" @@ -101,6 +100,7 @@ class gfp_ : public ValueInterface static int size() { return t() * sizeof(mp_limb_t); } static int size_in_bits() { return 8 * size(); } static int length() { return ZpD.pr_bit_length; } + static int n_bits() { return length() - 1; } static void reqbl(int n); diff --git a/Networking/CryptoPlayer.cpp b/Networking/CryptoPlayer.cpp index c2e1403b9..9d8da6514 100644 --- a/Networking/CryptoPlayer.cpp +++ b/Networking/CryptoPlayer.cpp @@ -5,6 +5,7 @@ #include "CryptoPlayer.h" #include "Math/Setup.h" +#include "Tools/Bundle.h" void check_ssl_file(string filename) { @@ -124,12 +125,14 @@ CryptoPlayer::~CryptoPlayer() void CryptoPlayer::send_to_no_stats(int other, const octetStream& o) const { + assert(other != my_num()); senders[other]->request(o); senders[other]->wait(o); } void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const { + assert(other != my_num()); receivers[other]->request(o); receivers[other]->wait(o); } @@ -137,6 +140,7 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, octetStream& to_receive) const { + assert(other != my_num()); if (&to_send == &to_receive) { MultiPlayer::exchange_no_stats(other, to_send, to_receive); @@ -153,6 +157,7 @@ void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send, void CryptoPlayer::pass_around_no_stats(const octetStream& to_send, octetStream& to_receive, int offset) const { + assert(get_player(offset) != my_num()); if (&to_send == &to_receive) { MultiPlayer::pass_around_no_stats(to_send, to_receive, offset); diff --git a/Networking/Player.cpp b/Networking/Player.cpp index 61c8fd65c..cd92df541 100644 --- a/Networking/Player.cpp +++ b/Networking/Player.cpp @@ -14,12 +14,14 @@ using namespace std; -void Names::init(int player,int pnb,int my_port,const char* servername) +void Names::init(int player, int pnb, int my_port, const char* servername, + bool setup_socket) { player_no=player; portnum_base=pnb; setup_names(servername, my_port); - setup_server(); + if (setup_socket) + setup_server(); } Names::Names(int player, int nplayers, const string& servername, int pnb, @@ -124,7 +126,7 @@ void Names::setup_names(const char *servername, int my_port) my_port = default_port(player_no); int socket_num; - int pn = portnum_base - 1; + int pn = portnum_base; set_up_client_socket(socket_num, servername, pn); octetStream("P" + to_string(player_no)).Send(socket_num); #ifdef DEBUG_NETWORKING @@ -132,15 +134,11 @@ void Names::setup_names(const char *servername, int my_port) #endif // Send my name - octet my_name[512]; - memset(my_name,0,512*sizeof(octet)); sockaddr_in address; socklen_t size = sizeof address; getsockname(socket_num, (sockaddr*)&address, &size); - char* name = inet_ntoa(address.sin_addr); - // max length of IP address with ending 0 - strncpy((char*)my_name, name, 16); - send(socket_num,my_name,512); + char* my_name = inet_ntoa(address.sin_addr); + octetStream(my_name).Send(socket_num); send(socket_num,(octet*)&my_port,4); #ifdef DEBUG_NETWORKING fprintf(stderr, "My Name = %s\n",my_name); @@ -158,9 +156,10 @@ void Names::setup_names(const char *servername, int my_port) names.resize(nplayers); ports.resize(nplayers); for (i=0; iinit(); } +void Names::set_server(ServerSocket* socket) +{ + assert(not server); + server = socket; +} + Names::Names(const Names& other) { @@ -201,6 +206,7 @@ Player::Player(const Names& Nms) : { nplayers=Nms.nplayers; player_no=Nms.player_no; + thread_stats.resize(nplayers); } @@ -243,6 +249,10 @@ MultiPlayer::~MultiPlayer() Player::~Player() { +#ifdef VERBOSE + for (auto& x : thread_stats) + x.print(); +#endif } PlayerBase::~PlayerBase() @@ -685,7 +695,7 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const { TimeScope ts(comm_stats["Sending one-to-one"].add(o)); P.send_to_no_stats(other_player, o); - sent += o.get_length(); + comm_stats.sent += o.get_length(); } void RealTwoPartyPlayer::receive(octetStream& o) const @@ -729,12 +739,13 @@ void RealTwoPartyPlayer::exchange(octetStream& o) const void VirtualTwoPartyPlayer::send_receive_player(vector& o) const { TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0])); - sent += o[0].get_length(); + comm_stats.sent += o[0].get_length(); P.exchange_no_stats(other_player, o[0], o[1]); } VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) : - TwoPartyPlayer(P.my_num()), P(P), other_player(other_player) + TwoPartyPlayer(P.my_num()), P(P), other_player(other_player), comm_stats( + P.thread_stats.at(other_player)) { } @@ -814,5 +825,13 @@ void NamedCommStats::print(bool newline) cerr << endl; } +NamedCommStats Player::total_comm() const +{ + auto res = comm_stats; + for (auto& x : thread_stats) + res += x; + return res; +} + template class MultiPlayer; template class MultiPlayer ; diff --git a/Networking/Player.h b/Networking/Player.h index 033aa3bd1..9c90dbd1f 100644 --- a/Networking/Player.h +++ b/Networking/Player.h @@ -35,6 +35,7 @@ class Names friend class Player; friend class PlainPlayer; friend class RealTwoPartyPlayer; + friend class Server; vector names; vector ports; @@ -51,6 +52,8 @@ class Names void setup_server(); + void set_server(ServerSocket* socket); + public: static const int DEFAULT_PORT = -1; @@ -62,8 +65,10 @@ class Names * @param my_port my port number (`DEFAULT_PORT` for default, * which is base port number plus player number) * @param servername location of server + * @param setup_socket whether to start listening */ - void init(int player,int pnb,int my_port,const char* servername); + void init(int player, int pnb, int my_port, const char* servername, + bool setup_socket = true); Names(int player,int pnb,int my_port,const char* servername) : Names() { init(player,pnb,my_port,servername); } @@ -172,11 +177,12 @@ class PlayerBase protected: int player_no; -public: size_t& sent; - mutable Timer timer; mutable NamedCommStats comm_stats; +public: + mutable Timer timer; + PlayerBase(int player_no) : player_no(player_no), sent(comm_stats.sent) {} virtual ~PlayerBase(); @@ -205,6 +211,8 @@ class Player : public PlayerBase public: const Names& N; + mutable vector thread_stats; + Player(const Names& Nms); virtual ~Player(); @@ -358,6 +366,8 @@ class Player : public PlayerBase virtual void request_receive(int i, octetStream& o) const { (void)i; (void)o; } virtual void wait_receive(int i, octetStream& o) const { receive_player(i, o); } + + NamedCommStats total_comm() const; }; /** @@ -500,6 +510,7 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer { Player& P; int other_player; + NamedCommStats& comm_stats; public: VirtualTwoPartyPlayer(Player& P, int other_player); diff --git a/Networking/Receiver.cpp b/Networking/Receiver.cpp index e93f47c44..7e8c93fe9 100644 --- a/Networking/Receiver.cpp +++ b/Networking/Receiver.cpp @@ -51,9 +51,17 @@ void Receiver::run() while (in.pop(os)) { os->reset_write_head(); +#ifdef VERBOSE_SSL timer.start(); + RunningTimer mytimer; +#endif os->Receive(socket); +#ifdef VERBOSE_SSL + cout << "receiving " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Sender.cpp b/Networking/Sender.cpp index 51d5f4711..4e4b98810 100644 --- a/Networking/Sender.cpp +++ b/Networking/Sender.cpp @@ -47,9 +47,17 @@ void Sender::run() const octetStream* os = 0; while (in.pop(os)) { -// timer.start(); +#ifdef VERBOSE_SSL + timer.start(); + RunningTimer mytimer; +#endif os->Send(socket); -// timer.stop(); +#ifdef VERBOSE_SSL + cout << "sending " << os->get_length() * 1e-6 << " MB on " << socket + << " took " << mytimer.elapsed() << ", total " + << timer.elapsed() << endl; + timer.stop(); +#endif out.push(os); } } diff --git a/Networking/Server.cpp b/Networking/Server.cpp index d9a056dd2..facda0a26 100644 --- a/Networking/Server.cpp +++ b/Networking/Server.cpp @@ -28,9 +28,7 @@ void Server::get_ip(int num) inet_ntop(AF_INET6, &s->sin6_addr, ipstr, sizeof ipstr); } - names[num]=new octet[512]; - memset(names[num], 0, 512); - strncpy((char*)names[num], ipstr, INET6_ADDRSTRLEN); + names[num] = ipstr; #ifdef DEBUG_NETWORKING cerr << "Client IP address: " << names[num] << endl; @@ -45,11 +43,11 @@ void Server::get_name(int num) #endif // Receive name sent by client (legacy) - not used here - octet my_name[512]; - receive(socket_num[num],my_name,512); + octetStream os; + os.Receive(socket_num[num]); receive(socket_num[num],(octet*)&ports[num],4); #ifdef DEBUG_NETWORKING - cerr << "Player " << num << " sent (IP for info only) " << my_name << ":" + cerr << "Player " << num << " sent (IP for info only) " << os.str() << ":" << ports[num] << endl; #endif @@ -66,7 +64,7 @@ void Server::send_names(int num) send(socket_num[num],nmachines,4); for (int i=0; i= 0); assert(my_num < nplayers); @@ -172,12 +175,19 @@ Server* Server::start_networking(Names& N, int my_num, int nplayers, { pthread_create(&thread, 0, Server::start_in_thread, server = new Server(nplayers, portnum)); - } - N.init(my_num, portnum, my_port, hostname.c_str()); - if (my_num == 0) - { + N.init(my_num, portnum, my_port, hostname.c_str(), false); pthread_join(thread, 0); + N.set_server(server->get_socket()); delete server; } + else + N.init(my_num, portnum, my_port, hostname.c_str()); return 0; } + +ServerSocket* Server::get_socket() +{ + auto res = server_socket; + server_socket = 0; + return res; +} diff --git a/Networking/Server.h b/Networking/Server.h index a5e833add..ad6d5fd5d 100644 --- a/Networking/Server.h +++ b/Networking/Server.h @@ -14,10 +14,11 @@ using namespace std; class Server { vector socket_num; - vector names; + vector names; vector ports; int nmachines; int PortnumBase; + ServerSocket* server_socket; void get_ip(int num); void get_name(int num); @@ -31,7 +32,11 @@ class Server Server(int argc, char** argv); Server(int nmachines, int PortnumBase); + ~Server(); + void start(); + + ServerSocket* get_socket(); }; #endif /* NETWORKING_SERVER_H_ */ diff --git a/Networking/ssl_sockets.h b/Networking/ssl_sockets.h index 8989a0a10..79cb35222 100644 --- a/Networking/ssl_sockets.h +++ b/Networking/ssl_sockets.h @@ -7,6 +7,7 @@ #define CRYPTO_SSL_SOCKETS_H_ #include "Tools/int.h" +#include "Tools/time-func.h" #include "sockets.h" #include "Math/Setup.h" @@ -46,6 +47,10 @@ class ssl_socket : public boost::asio::ssl::stream string me, bool client) : parent(io_service, ctx) { +#ifdef DEBUG_NETWORKING + cerr << me << " setting up SSL to " << other << " as " << + (client ? "client" : "server") << endl; +#endif lowest_layer().assign(boost::asio::ip::tcp::v4(), plaintext_socket); set_verify_mode(boost::asio::ssl::verify_peer); set_verify_callback(boost::asio::ssl::rfc2818_verification(other)); @@ -82,8 +87,16 @@ template<> inline void send(ssl_socket* socket, octet* data, size_t length) { size_t sent = 0; +#ifdef VERBOSE_SSL + RunningTimer timer; +#endif while (sent < length) + { sent += send_non_blocking(socket, data + sent, length - sent); +#ifdef VERBOSE_SSL + cout << "sent " << sent * 1e-6 << " MB at " << timer.elapsed() << endl; +#endif + } } template<> diff --git a/OT/BaseOT.cpp b/OT/BaseOT.cpp index 8847728e9..988565854 100644 --- a/OT/BaseOT.cpp +++ b/OT/BaseOT.cpp @@ -1,6 +1,7 @@ #include "OT/BaseOT.h" #include "Tools/random.h" #include "Tools/benchmarking.h" +#include "Tools/Bundle.h" #include #include @@ -78,6 +79,23 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector& os, OT_ROLE rol void BaseOT::exec_base(bool new_receiver_inputs) { + Bundle bundle(*P); +#ifdef NO_AVX_OT + bundle.mine = string("OT without AVX"); +#else + bundle.mine = string("OT with AVX"); +#endif + try + { + bundle.compare(*P); + } + catch (mismatch_among_parties&) + { + cerr << "Parties compiled with different base OT algorithms" << endl; + cerr << "Set \"AVX_OT\" to the same value on all parties" << endl; + exit(1); + } + #ifdef NO_AVX_OT #ifdef USE_RISTRETTO typedef CurveElement Element; diff --git a/OT/NPartyTripleGenerator.h b/OT/NPartyTripleGenerator.h index d5981e713..8a84ca0a3 100644 --- a/OT/NPartyTripleGenerator.h +++ b/OT/NPartyTripleGenerator.h @@ -116,7 +116,7 @@ class OTTripleGenerator : public GeneratorThread mac_key_type get_mac_key() const { return mac_key; } - NamedCommStats comm_stats(); + Player& get_player() { return globalPlayer; } }; template @@ -209,15 +209,4 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator void generateTriples(); }; -template -NamedCommStats OTTripleGenerator::comm_stats() -{ - NamedCommStats res; - if (parentPlayer != &globalPlayer) - res = globalPlayer.comm_stats; - for (auto& player : players) - res += player->comm_stats; - return res; -} - #endif diff --git a/Processor/BaseMachine.cpp b/Processor/BaseMachine.cpp index bc36a8606..019fc6f28 100644 --- a/Processor/BaseMachine.cpp +++ b/Processor/BaseMachine.cpp @@ -110,22 +110,31 @@ void BaseMachine::time() void BaseMachine::start(int n) { cout << "Starting timer " << n << " at " << timer[n].elapsed() + << " (" << timer[n].mb_sent() << " MB)" << " after " << timer[n].idle() << endl; - timer[n].start(); + timer[n].start(total_comm()); } void BaseMachine::stop(int n) { - timer[n].stop(); - cout << "Stopped timer " << n << " at " << timer[n].elapsed() << endl; + timer[n].stop(total_comm()); + cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" + << timer[n].mb_sent() << " MB)" << endl; } void BaseMachine::print_timers() { + cerr << "The following timing is "; + if (OnlineOptions::singleton.live_prep) + cerr << "in"; + else + cerr << "ex"; + cerr << "clusive preprocessing." << endl; cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; timer.erase(0); - for (map::iterator it = timer.begin(); it != timer.end(); it++) - cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds " << endl; + for (auto it = timer.begin(); it != timer.end(); it++) + cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" + << it->second.mb_sent() << " MB)" << endl; } string BaseMachine::memory_filename(const string& type_short, int my_number) @@ -170,3 +179,18 @@ bigint BaseMachine::prime_from_schedule(string progname) else return 0; } + +NamedCommStats BaseMachine::total_comm() +{ + NamedCommStats res; + for (auto& queue : queues) + res += queue->get_comm_stats(); + return res; +} + +void BaseMachine::set_thread_comm(const NamedCommStats& stats) +{ + auto queue = queues.at(BaseMachine::thread_num); + assert(queue); + queue->set_comm_stats(stats); +} diff --git a/Processor/BaseMachine.h b/Processor/BaseMachine.h index 0e08549e3..035a0cfef 100644 --- a/Processor/BaseMachine.h +++ b/Processor/BaseMachine.h @@ -7,6 +7,7 @@ #define PROCESSOR_BASEMACHINE_H_ #include "Tools/time-func.h" +#include "Tools/TimerWithComm.h" #include "OT/OTTripleSetup.h" #include "ThreadJob.h" #include "ThreadQueues.h" @@ -22,7 +23,7 @@ class BaseMachine protected: static BaseMachine* singleton; - std::map timer; + std::map timer; string compiler; string domain; @@ -66,12 +67,18 @@ class BaseMachine virtual void reqbl(int) {} - OTTripleSetup fresh_ot_setup(); + static OTTripleSetup fresh_ot_setup(Player& P); + + NamedCommStats total_comm(); + void set_thread_comm(const NamedCommStats& stats); }; -inline OTTripleSetup BaseMachine::fresh_ot_setup() +inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P) { - return ot_setups.at(thread_num).get_fresh(); + if (singleton and size_t(thread_num) < s().ot_setups.size()) + return s().ot_setups.at(thread_num).get_fresh(); + else + return OTTripleSetup(P, true); } #endif /* PROCESSOR_BASEMACHINE_H_ */ diff --git a/Processor/Binary_File_IO.hpp b/Processor/Binary_File_IO.hpp index be1fb8fdb..9878f4a6b 100644 --- a/Processor/Binary_File_IO.hpp +++ b/Processor/Binary_File_IO.hpp @@ -38,7 +38,7 @@ void Binary_File_IO::read_from_file(const string filename, vector< T >& buffer, int size_in_bytes = T::size() * buffer.size(); int n_read = 0; - char * read_buffer = new char[size_in_bytes]; + char read_buffer[size_in_bytes]; inf.seekg(start_posn * T::size()); do { diff --git a/Processor/Data_Files.h b/Processor/Data_Files.h index 8f44ed253..8d05747e0 100644 --- a/Processor/Data_Files.h +++ b/Processor/Data_Files.h @@ -89,6 +89,7 @@ template class Processor; template class Data_Files; template class Machine; template class SubProcessor; +template class NoFilePrep; /** * Abstract base class for preprocessing @@ -125,6 +126,7 @@ class Preprocessing : public PrepBase template static Preprocessing* get_new(Machine& machine, DataPositions& usage, SubProcessor* proc); + template static Preprocessing* get_new(bool live_prep, const Names& N, DataPositions& usage); static Preprocessing* get_live_prep(SubProcessor* proc, @@ -133,22 +135,21 @@ class Preprocessing : public PrepBase Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {} virtual ~Preprocessing() {} - virtual void set_protocol(typename T::Protocol& protocol) = 0; + virtual void set_protocol(typename T::Protocol&) {}; virtual void set_proc(SubProcessor* proc) { (void) proc; } virtual void seekg(DataPositions& pos) { (void) pos; } virtual void prune() {} virtual void purge() {} - virtual size_t data_sent() { return comm_stats().sent; } - virtual NamedCommStats comm_stats() { return {}; } - - virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0; - virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0; - virtual void get_one_no_count(Dtype dtype, T& a) = 0; - virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0; - virtual void get_no_count(vector& S, DataTag tag, const vector& regs, - int vector_size) = 0; + virtual void get_three_no_count(Dtype, T&, T&, T&) + { throw not_implemented(); } + virtual void get_two_no_count(Dtype, T&, T&) { throw not_implemented(); } + virtual void get_one_no_count(Dtype, T&) { throw not_implemented(); } + virtual void get_input_no_count(T&, typename T::open_type&, int) + { throw not_implemented() ; } + virtual void get_no_count(vector&, DataTag, const vector&, int) + { throw not_implemented(); } void get(Dtype dtype, T* a); void get_three(Dtype dtype, T& a, T& b, T& c); @@ -191,6 +192,9 @@ class Sub_Data_Files : public Preprocessing { template friend class Sub_Data_Files; + typedef typename conditional, NoFilePrep>::type part_type; + static int tuple_length(int dtype); BufferOwner buffers[N_DTYPE]; @@ -205,7 +209,7 @@ class Sub_Data_Files : public Preprocessing const string prep_data_dir; int thread_num; - Sub_Data_Files* part; + part_type* part; void buffer_edabits_with_queues(bool strict, int n_bits) { buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); } @@ -274,7 +278,7 @@ class Sub_Data_Files : public Preprocessing void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size); void get_dabit_no_count(T& a, typename T::bit_type& b); - Preprocessing& get_part(); + part_type& get_part(); }; template @@ -307,8 +311,6 @@ class Data_Files } void reset_usage() { usage.reset(); skipped.reset(); } - - NamedCommStats comm_stats(); }; template inline @@ -418,6 +420,7 @@ T Preprocessing::get_bit() template T Preprocessing::get_random() { + assert(not usage.inputs.empty()); return get_random_from_inputs(usage.inputs.size()); } @@ -429,10 +432,4 @@ inline void Data_Files::purge() DataFb.purge(); } -template -NamedCommStats Data_Files::comm_stats() -{ - return DataFp.comm_stats() + DataF2.comm_stats() + DataFb.comm_stats(); -} - #endif diff --git a/Processor/Data_Files.hpp b/Processor/Data_Files.hpp index 3635dc0ac..359ff6207 100644 --- a/Processor/Data_Files.hpp +++ b/Processor/Data_Files.hpp @@ -3,6 +3,7 @@ #include "Processor/Data_Files.h" #include "Processor/Processor.h" +#include "Processor/NoFilePrep.h" #include "Protocols/dabit.h" #include "Math/Setup.h" #include "GC/BitPrepFiles.h" @@ -30,6 +31,7 @@ Preprocessing* Preprocessing::get_new( } template +template Preprocessing* Preprocessing::get_new( bool live_prep, const Names& N, DataPositions& usage) @@ -156,17 +158,7 @@ Data_Files::Data_Files(const Names& N) : template Data_Files::~Data_Files() { -#ifdef VERBOSE - if (DataFp.data_sent()) - cerr << "Sent for " << sint::type_string() << " preprocessing threads: " << - DataFp.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataFp; -#ifdef VERBOSE - if (DataF2.data_sent()) - cerr << "Sent for " << sgf2n::type_string() << " preprocessing threads: " << - DataF2.data_sent() * 1e-6 << " MB" << endl; -#endif delete &DataF2; delete &DataFb; } @@ -264,6 +256,8 @@ void Sub_Data_Files::purge() for (auto it : extended) it.second.purge(); dabit_buffer.purge(); + if (part != 0) + part->purge(); } template @@ -329,10 +323,10 @@ void Sub_Data_Files::buffer_edabits_with_queues(bool strict, int n_bits, } template -Preprocessing& Sub_Data_Files::get_part() +typename Sub_Data_Files::part_type& Sub_Data_Files::get_part() { if (part == 0) - part = new Sub_Data_Files(my_num, num_players, + part = new part_type(my_num, num_players, get_prep_sub_dir(num_players), this->usage, thread_num); return *part; diff --git a/Processor/DummyProtocol.h b/Processor/DummyProtocol.h index 95bcd029a..b3ed5bc54 100644 --- a/Processor/DummyProtocol.h +++ b/Processor/DummyProtocol.h @@ -87,10 +87,10 @@ class DummyProtocol : public ProtocolBase { } - void init_mul(SubProcessor* = 0) + void init_mul() { } - typename T::clear prepare_mul(const T&, const T&, int = 0) + void prepare_mul(const T&, const T&, int = 0) { throw not_implemented(); } diff --git a/Processor/FieldMachine.h b/Processor/FieldMachine.h index c544fb96a..859c64a1f 100644 --- a/Processor/FieldMachine.h +++ b/Processor/FieldMachine.h @@ -9,6 +9,9 @@ #include "RingMachine.h" #include "HonestMajorityMachine.h" #include "Tools/ezOptionParser.h" +#include "Math/gfp.h" + +#include "OnlineOptions.hpp" template class U, class V = HonestMajorityMachine> class HonestMajorityFieldMachine @@ -36,7 +39,7 @@ class DishonestMajorityFieldMachine ez::ezOptionParser& opt, bool live_prep_default = true) { OnlineOptions& online_opts = OnlineOptions::singleton; - online_opts = {opt, argc, argv, 1000, live_prep_default, true}; + online_opts = {opt, argc, argv, T(), live_prep_default}; FieldMachine(argc, argv, opt, online_opts); } diff --git a/Processor/FieldMachine.hpp b/Processor/FieldMachine.hpp index f93517d98..89ec66e1c 100644 --- a/Processor/FieldMachine.hpp +++ b/Processor/FieldMachine.hpp @@ -10,6 +10,7 @@ #include "HonestMajorityMachine.h" #include "Math/gfp.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class T, class V> @@ -24,7 +25,7 @@ template class T, class V> HonestMajorityFieldMachine::HonestMajorityFieldMachine(int argc, const char **argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv, 0, true, true); + OnlineOptions online_opts(opt, argc, argv, T()); FieldMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/HonestMajorityMachine.cpp b/Processor/HonestMajorityMachine.cpp index 3a756bc8b..295ef5fa0 100644 --- a/Processor/HonestMajorityMachine.cpp +++ b/Processor/HonestMajorityMachine.cpp @@ -18,7 +18,6 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, ez::ezOptionParser& opt, OnlineOptions& online_opts, int nplayers) : OnlineMachine(argc, argv, opt, online_opts, nplayers) { - OnlineOptions::singleton = online_opts; opt.add( "", // Default. 0, // Required? @@ -29,6 +28,7 @@ HonestMajorityMachine::HonestMajorityMachine(int argc, const char** argv, "--unencrypted" // Flag token. ); online_opts.finalize(opt, argc, argv); + OnlineOptions::singleton = online_opts; use_encryption = not opt.get("-u")->isSet; diff --git a/Processor/Input.h b/Processor/Input.h index 9816c3578..98c6c83b0 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -14,6 +14,8 @@ using namespace std; #include "Tools/PointerVector.h" class ArithmeticProcessor; +template class SubProcessor; +template class Preprocessing; /** * Abstract base for input protocols @@ -25,6 +27,7 @@ class InputBase protected: Player* P; + int my_num; Buffer buffer; Timer timer; @@ -58,7 +61,7 @@ class InputBase /// Schedule input from other player virtual void add_other(int player, int n_bits = -1) = 0; /// Schedule input from all players - void add_from_all(const clear& input); + void add_from_all(const clear& input, int n_bits = -1); /// Send my inputs virtual void send_mine() = 0; diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 9272535bc..b9f7a77ab 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -19,6 +19,7 @@ template InputBase::InputBase(ArithmeticProcessor* proc) : P(0), values_input(0) { + my_num = -1; if (proc) buffer.setup(&proc->private_input, -1, proc->private_input_filename); } @@ -83,6 +84,7 @@ template void InputBase::reset_all(Player& P) { this->P = &P; + my_num = P.my_num(); os.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) reset(i); @@ -111,13 +113,13 @@ void Input::add_other(int player, int) } template -void InputBase::add_from_all(const clear& input) +void InputBase::add_from_all(const clear& input, int n_bits) { for (int i = 0; i < P->num_players(); i++) if (i == P->my_num()) - add_mine(input); + add_mine(input, n_bits); else - add_other(i); + add_other(i, n_bits); } template @@ -202,7 +204,7 @@ void Input::finalize_other(int player, T& target, template T InputBase::finalize(int player, int n_bits) { - if (player == P->my_num()) + if (player == my_num) return finalize_mine(); else { diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index e516fdf37..e45a85045 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -1091,9 +1091,11 @@ inline void Instruction::execute(Processor& Proc) const Proc.machine.time(); break; case START: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.start(n); break; case STOP: + Proc.machine.set_thread_comm(Proc.P.total_comm()); Proc.machine.stop(n); break; case RUN_TAPE: diff --git a/Processor/Machine.h b/Processor/Machine.h index 3f23dc9f9..331a9a22c 100644 --- a/Processor/Machine.h +++ b/Processor/Machine.h @@ -69,7 +69,6 @@ class Machine : public BaseMachine OnlineOptions opts; - NamedCommStats comm_stats; ExecutionStats stats; Machine(int my_number, Names& playerNames, const string& progname, diff --git a/Processor/Machine.hpp b/Processor/Machine.hpp index 804dc51aa..d7d1a3ec3 100644 --- a/Processor/Machine.hpp +++ b/Processor/Machine.hpp @@ -142,6 +142,8 @@ template Machine::~Machine() { delete P; + for (auto& queue : queues) + delete queue; } template @@ -324,7 +326,7 @@ void Machine::run() { Timer proc_timer(CLOCK_PROCESS_CPUTIME_ID); proc_timer.start(); - timer[0].start(); + timer[0].start({}); // run main tape run_tape(0, 0, 0, N.num_players()); @@ -352,7 +354,6 @@ void Machine::run() queues[i]->schedule({}); pos.increase(queues[i]->result().pos); pthread_join(threads[i],NULL); - delete queues[i]; } finish_timer.stop(); @@ -372,6 +373,8 @@ void Machine::run() cerr << "Finish timer: " << finish_timer.elapsed() << endl; #endif + NamedCommStats comm_stats = total_comm(); + if (opts.verbose) { cerr << "Communication details " @@ -457,9 +460,12 @@ void Machine::run() } #ifndef INSECURE - Data_Files df(*this); - df.seekg(pos); - df.prune(); + if (not opts.file_prep_per_thread) + { + Data_Files df(*this); + df.seekg(pos); + df.prune(); + } #endif sint::LivePrep::teardown(); diff --git a/Processor/Memory.h b/Processor/Memory.h index 2c4a3d2e3..9ec02d2b8 100644 --- a/Processor/Memory.h +++ b/Processor/Memory.h @@ -43,8 +43,11 @@ class Memory template static void check_index(const vector& M, size_t i) { + (void) M, (void) i; +#ifdef NO_CHECK_INDEX if (i >= M.size()) throw overflow("memory", i, M.size()); +#endif } const typename T::clear& read_C(size_t i) const diff --git a/Processor/NoFilePrep.h b/Processor/NoFilePrep.h new file mode 100644 index 000000000..fbb44912e --- /dev/null +++ b/Processor/NoFilePrep.h @@ -0,0 +1,22 @@ +/* + * NoFilePrep.h + * + */ + +#ifndef PROCESSOR_NOFILEPREP_H_ +#define PROCESSOR_NOFILEPREP_H_ + +#include "Data_Files.h" + +template +class NoFilePrep : public Preprocessing +{ +public: + NoFilePrep(int, int, const string&, DataPositions& usage, int = -1) : + Preprocessing(usage) + { + throw runtime_error("don't call this"); + } +}; + +#endif /* PROCESSOR_NOFILEPREP_H_ */ diff --git a/Processor/OfflineMachine.hpp b/Processor/OfflineMachine.hpp index cffaded40..dcfafe553 100644 --- a/Processor/OfflineMachine.hpp +++ b/Processor/OfflineMachine.hpp @@ -71,7 +71,7 @@ void OfflineMachine::generate() auto my_usage = domain_usage[i]; Dtype dtype = Dtype(i); string filename = Sub_Data_Files::get_filename(playerNames, dtype, - T::clear::field_type() == DATA_GF2 ? 0 : -1); + 0); if (my_usage > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -106,7 +106,7 @@ void OfflineMachine::generate() for (int i = 0; i < P.num_players(); i++) { auto n_inputs = usage.inputs[i][T::clear::field_type()]; - string filename = Sub_Data_Files::get_input_filename(playerNames, i); + string filename = Sub_Data_Files::get_input_filename(playerNames, i, 0); if (n_inputs > 0) { ofstream out(filename, iostream::out | iostream::binary); @@ -137,7 +137,7 @@ void OfflineMachine::generate() int total = usage.edabits[{false, n_bits}] + usage.edabits[{true, n_bits}]; string filename = Sub_Data_Files::get_edabit_filename(playerNames, - n_bits); + n_bits, 0); if (total > 0) { ofstream out(filename, ios::binary); diff --git a/Processor/Online-Thread.hpp b/Processor/Online-Thread.hpp index cb25b4261..e98f1a3a1 100644 --- a/Processor/Online-Thread.hpp +++ b/Processor/Online-Thread.hpp @@ -279,7 +279,7 @@ void thread_info::Sub_Main_Func() printf("\tSignalling I have finished\n"); #endif wait_timer.start(); - queues->finished(job); + queues->finished(job, P.total_comm()); wait_timer.stop(); } } @@ -287,6 +287,11 @@ void thread_info::Sub_Main_Func() // final check Proc.check(); +#ifndef INSECURE + if (machine.opts.file_prep_per_thread) + Proc.DataF.prune(); +#endif + wait_timer.start(); queues->next(); wait_timer.stop(); @@ -314,16 +319,10 @@ void thread_info::Sub_Main_Func() #endif // wind down thread by thread - auto prep_stats = Proc.DataF.comm_stats(); - prep_stats += Proc.share_thread.DataF.comm_stats(); - prep_stats += Proc.Procp.bit_prep.comm_stats(); - for (auto& x : Proc.Procp.personal_bit_preps) - prep_stats += x->comm_stats(); machine.stats += Proc.stats; delete processor; - machine.comm_stats += P.comm_stats + prep_stats; - queues->finished(actual_usage); + queues->finished(actual_usage, P.total_comm()); delete MC2; delete MCp; diff --git a/Processor/OnlineOptions.cpp b/Processor/OnlineOptions.cpp index 41308603b..2a5e090bd 100644 --- a/Processor/OnlineOptions.cpp +++ b/Processor/OnlineOptions.cpp @@ -29,6 +29,7 @@ OnlineOptions::OnlineOptions() : playerno(-1) cmd_private_input_file = "Player-Data/Input"; cmd_private_output_file = ""; file_prep_per_thread = false; + trunc_error = 40; #ifdef VERBOSE verbose = true; #else @@ -326,6 +327,19 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc, #endif lgp = max(lgp, gfp0::MAX_N_BITS); } + + set_trunc_error(opt); +} + +void OnlineOptions::set_trunc_error(ez::ezOptionParser& opt) +{ + if (opt.get("-E")) + { + opt.get("-E")->getInt(trunc_error); +#ifdef VERBOSE + cerr << "Truncation error probability 2^-" << trunc_error << endl; +#endif + } } int OnlineOptions::prime_length() diff --git a/Processor/OnlineOptions.h b/Processor/OnlineOptions.h index de8f1e722..4b2fe4f8c 100644 --- a/Processor/OnlineOptions.h +++ b/Processor/OnlineOptions.h @@ -30,6 +30,7 @@ class OnlineOptions std::string cmd_private_output_file; bool verbose; bool file_prep_per_thread; + int trunc_error; OnlineOptions(); OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, @@ -37,10 +38,15 @@ class OnlineOptions OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, int default_batch_size = 0, bool default_live_prep = true, bool variable_prime_length = false); + template + OnlineOptions(ez::ezOptionParser& opt, int argc, const char** argv, T, + bool default_live_prep = true); ~OnlineOptions() {} void finalize(ez::ezOptionParser& opt, int argc, const char** argv); + void set_trunc_error(ez::ezOptionParser& opt); + int prime_length(); int prime_limbs(); diff --git a/Processor/OnlineOptions.hpp b/Processor/OnlineOptions.hpp new file mode 100644 index 000000000..8961853e5 --- /dev/null +++ b/Processor/OnlineOptions.hpp @@ -0,0 +1,30 @@ +/* + * OnlineOptions.hpp + * + */ + +#ifndef PROCESSOR_ONLINEOPTIONS_HPP_ +#define PROCESSOR_ONLINEOPTIONS_HPP_ + +#include "OnlineOptions.h" + +template +OnlineOptions::OnlineOptions(ez::ezOptionParser& opt, int argc, + const char** argv, T, bool default_live_prep) : + OnlineOptions(opt, argc, argv, T::dishonest_majority ? 1000 : 0, + default_live_prep, T::clear::prime_field) +{ + if (T::has_trunc_pr) + opt.add( + to_string(trunc_error).c_str(), // Default. + 0, // Required? + 1, // Number of args expected. + 0, // Delimiter if expecting multiple args. + "Probabilistic truncation error " + "(2^-x, default: 40)", // Help description. + "-E", // Flag token. + "--trunc-error" // Flag token. + ); +} + +#endif /* PROCESSOR_ONLINEOPTIONS_HPP_ */ diff --git a/Processor/PrepBase.cpp b/Processor/PrepBase.cpp index 5c44b9087..4ca77daa1 100644 --- a/Processor/PrepBase.cpp +++ b/Processor/PrepBase.cpp @@ -40,21 +40,33 @@ string PrepBase::get_edabit_filename(const string& prep_data_dir, + to_string(my_num) + get_suffix(thread_num); } -void PrepBase::print_left(const char* name, size_t n, const string& type_string) +void PrepBase::print_left(const char* name, size_t n, const string& type_string, + size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) cerr << "\t" << n << " " << name << " of " << type_string << " left" << endl; + + if (n > used / 10) + cerr << "Significant amount of unused " << name << " of " << type_string + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b." << endl; } void PrepBase::print_left_edabits(size_t n, size_t n_batch, bool strict, - int n_bits) + int n_bits, size_t used) { - if (n > 0) + if (n > 0 and OnlineOptions::singleton.verbose) { cerr << "\t~" << n * n_batch; if (not strict) cerr << " loose"; cerr << " edaBits of size " << n_bits << " left" << endl; } + + if (n > used / 10) + cerr << "Significant amount of unused edaBits of size " << n_bits + << ". For more accurate benchmarks, " + << "consider reducing the batch size with -b " + << "or increasing the bucket size with -B." << endl; } diff --git a/Processor/PrepBase.h b/Processor/PrepBase.h index bedba6299..ccc2f4b40 100644 --- a/Processor/PrepBase.h +++ b/Processor/PrepBase.h @@ -24,8 +24,10 @@ class PrepBase static string get_edabit_filename(const string& prep_data_dir, int n_bits, int my_num, int thread_num = 0); - static void print_left(const char* name, size_t n, const string& type_string); - static void print_left_edabits(size_t n, size_t n_batch, bool strict, int n_bits); + static void print_left(const char* name, size_t n, + const string& type_string, size_t used); + static void print_left_edabits(size_t n, size_t n_batch, bool strict, + int n_bits, size_t used); }; #endif /* PROCESSOR_PREPBASE_H_ */ diff --git a/Processor/Processor.h b/Processor/Processor.h index d9141855c..a78058cd1 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -243,10 +243,6 @@ class Processor : public ArithmeticProcessor cint get_inverse2(unsigned m); - // Print the processor state - template - friend ostream& operator<<(ostream& s,const Processor& P); - private: template friend class SPDZ; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index 6206e27c2..caea1e678 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -28,8 +28,8 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, bit_prep(bit_usage) { DataF.set_proc(this); + protocol.init(DataF, MC); DataF.set_protocol(protocol); - protocol.init_mul(this); bit_usage.set_num_players(P.num_players()); personal_bit_preps.resize(P.num_players()); for (int i = 0; i < P.num_players(); i++) @@ -39,22 +39,12 @@ SubProcessor::SubProcessor(typename T::MAC_Check& MC, template SubProcessor::~SubProcessor() { - protocol.check(); - for (size_t i = 0; i < personal_bit_preps.size(); i++) { auto& x = personal_bit_preps[i]; -#ifdef VERBOSE - if (x->data_sent()) - cerr << "Sent for personal bit preprocessing threads of player " << i << ": " << - x->data_sent() * 1e-6 << " MB" << endl; -#endif delete x; } #ifdef VERBOSE - if (bit_prep.data_sent()) - cerr << "Sent for global bit preprocessing threads: " << - bit_prep.data_sent() * 1e-6 << " MB" << endl; if (not bit_usage.empty()) { cerr << "Mixed-circuit preprocessing cost:" << endl; @@ -423,7 +413,7 @@ void SubProcessor::muls(const vector& reg, int size) int n = reg.size() / 3; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < size; j++) { @@ -448,7 +438,7 @@ void SubProcessor::mulrs(const vector& reg) int n = reg.size() / 4; SubProcessor& proc = *this; - protocol.init_mul(&proc); + protocol.init_mul(); for (int i = 0; i < n; i++) for (int j = 0; j < reg[4 * i]; j++) { @@ -470,7 +460,7 @@ void SubProcessor::mulrs(const vector& reg) template void SubProcessor::dotprods(const vector& reg, int size) { - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < size; i++) { auto it = reg.begin(); @@ -512,7 +502,7 @@ void SubProcessor::matmuls(const vector& source, assert(B + dim[1] * dim[2] <= source.end()); assert(C + dim[0] * dim[2] <= S.end()); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) for (int j = 0; j < dim[2]; j++) { @@ -536,7 +526,7 @@ void SubProcessor::matmulsm(const CheckVector& source, assert(C + dim[0] * dim[2] <= S.end()); assert(Proc); - protocol.init_dotprod(this); + protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) { auto ii = Proc->get_Ci().at(dim[3] + i); @@ -562,7 +552,7 @@ void SubProcessor::matmulsm(const CheckVector& source, template void SubProcessor::conv2ds(const Instruction& instruction) { - protocol.init_dotprod(this); + protocol.init_dotprod(); auto& args = instruction.get_start(); int output_h = args[0], output_w = args[1]; int inputs_h = args[2], inputs_w = args[3]; @@ -670,30 +660,4 @@ typename sint::clear Processor::get_inverse2(unsigned m) return inverses2m[m]; } -template -ostream& operator<<(ostream& s,const Processor& P) -{ - s << "Processor State" << endl; - s << "Char 2 Registers" << endl; - s << "Val\tClearReg\tSharedReg" << endl; - for (int i=0; i(), live_prep_default}; RingMachine(argc, argv, opt, online_opts); } diff --git a/Processor/RingMachine.hpp b/Processor/RingMachine.hpp index add3f43cc..e422e0aa5 100644 --- a/Processor/RingMachine.hpp +++ b/Processor/RingMachine.hpp @@ -12,6 +12,7 @@ #include "Tools/ezOptionParser.h" #include "Math/gf2n.h" #include "OnlineMachine.hpp" +#include "OnlineOptions.hpp" template class U, template class V> @@ -25,7 +26,7 @@ template class U, template class V> HonestMajorityRingMachine::HonestMajorityRingMachine(int argc, const char** argv, ez::ezOptionParser& opt, int nplayers) { - OnlineOptions online_opts(opt, argc, argv); + OnlineOptions online_opts(opt, argc, argv, U<64>()); RingMachine(argc, argv, opt, online_opts, nplayers); } diff --git a/Processor/ThreadQueue.cpp b/Processor/ThreadQueue.cpp index 3f5b1c76d..6358e4a4a 100644 --- a/Processor/ThreadQueue.cpp +++ b/Processor/ThreadQueue.cpp @@ -27,6 +27,19 @@ void ThreadQueue::finished(const ThreadJob& job) out.push(job); } +void ThreadQueue::finished(const ThreadJob& job, const NamedCommStats& new_comm_stats) +{ + finished(job); + set_comm_stats(new_comm_stats); +} + +void ThreadQueue::set_comm_stats(const NamedCommStats& new_comm_stats) +{ + lock.lock(); + comm_stats = new_comm_stats; + lock.unlock(); +} + ThreadJob ThreadQueue::result() { auto res = out.pop(); @@ -38,3 +51,11 @@ ThreadJob ThreadQueue::result() lock.unlock(); return res; } + +NamedCommStats ThreadQueue::get_comm_stats() +{ + lock.lock(); + auto res = comm_stats; + lock.unlock(); + return res; +} diff --git a/Processor/ThreadQueue.h b/Processor/ThreadQueue.h index 2e994b3ad..f49722abb 100644 --- a/Processor/ThreadQueue.h +++ b/Processor/ThreadQueue.h @@ -13,6 +13,7 @@ class ThreadQueue WaitQueue in, out; Lock lock; int left; + NamedCommStats comm_stats; public: ThreadQueue() : @@ -28,7 +29,11 @@ class ThreadQueue void schedule(const ThreadJob& job); ThreadJob next(); void finished(const ThreadJob& job); + void finished(const ThreadJob& job, const NamedCommStats& comm_stats); ThreadJob result(); + + void set_comm_stats(const NamedCommStats& new_comm_stats); + NamedCommStats get_comm_stats(); }; #endif /* PROCESSOR_THREADQUEUE_H_ */ diff --git a/Processor/TruncPrTuple.h b/Processor/TruncPrTuple.h index 06a96845f..267acae48 100644 --- a/Processor/TruncPrTuple.h +++ b/Processor/TruncPrTuple.h @@ -10,26 +10,35 @@ #include using namespace std; +#include "OnlineOptions.h" + template class TruncPrTuple { public: + const static int n = 4; + int dest_base; int source_base; int k; int m; int n_shift; - TruncPrTuple(const vector& regs, size_t base) + TruncPrTuple(const vector& regs, size_t base) : + TruncPrTuple(regs.begin() + base) + { + } + + TruncPrTuple(vector::const_iterator it) { - dest_base = regs[base]; - source_base = regs[base + 1]; - k = regs[base + 2]; - m = regs[base + 3]; + dest_base = *it++; + source_base = *it++; + k = *it++; + m = *it++; n_shift = T::N_BITS - 1 - k; assert(m < k); assert(0 < k); - assert(m < T::N_BITS); + assert(m < T::n_bits()); } T upper(T mask) @@ -49,10 +58,17 @@ class TruncPrTupleWithGap : public TruncPrTuple { public: TruncPrTupleWithGap(const vector& regs, size_t base) : - TruncPrTuple(regs, base) + TruncPrTupleWithGap(regs.begin() + base) { } + TruncPrTupleWithGap(vector::const_iterator it) : + TruncPrTuple(it) + { + if (T::prime_field and small_gap()) + throw runtime_error("domain too small for chosen truncation error"); + } + T upper(T mask) { if (big_gap()) @@ -69,7 +85,12 @@ class TruncPrTupleWithGap : public TruncPrTuple bool big_gap() { - return this->k <= T::N_BITS - 40; + return this->k <= T::n_bits() - OnlineOptions::singleton.trunc_error; + } + + bool small_gap() + { + return not big_gap(); } }; diff --git a/Programs/Source/keras_mnist_lenet_predict.mpc b/Programs/Source/keras_mnist_lenet_predict.mpc new file mode 100644 index 000000000..8b55de560 --- /dev/null +++ b/Programs/Source/keras_mnist_lenet_predict.mpc @@ -0,0 +1,44 @@ +# this trains LeNet on MNIST with a dropout layer +# see https://github.com/csiro-mlai/mnist-mpc for data preparation + +program.options_from_args() + +# training_samples = MultiArray([60000, 28, 28], sfix) +# training_labels = MultiArray([60000, 10], sint) + +test_samples = MultiArray([1, 28, 28], sfix) +test_labels = MultiArray([1, 10], sint) + +# training_labels.input_from(0) +# training_samples.input_from(0) + +# test_labels.input_from(0) +# test_samples.input_from(0) + +from Compiler import ml +tf = ml + +layers = [ + tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'), + tf.keras.layers.MaxPooling2D(2), + tf.keras.layers.Flatten(), + tf.keras.layers.Dropout(0.5), + tf.keras.layers.Dense(500, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') +] + +model = tf.keras.models.Sequential(layers) + +model.build(test_samples.sizes) + +start = 0 +for var in model.trainable_variables: + var.assign_all(0) +# start = var.read_from_file(start) + +guesses = model.predict(test_samples, batch_size=1) + +print_ln('guess %s', guesses.reveal_nested()[:3]) +print_ln('truth %s', test_labels.reveal_nested()[:3]) diff --git a/Protocols/Atlas.h b/Protocols/Atlas.h index 3dd34d173..c99d911a9 100644 --- a/Protocols/Atlas.h +++ b/Protocols/Atlas.h @@ -53,18 +53,13 @@ class Atlas : public ProtocolBase return shamir.get_n_relevant_players(); } - void init_mul(Preprocessing&, typename T::MAC_Check&) - { - init_mul(); - } - - void init_mul(SubProcessor* proc = 0); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void prepare(const typename T::open_type& product); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Atlas.hpp b/Protocols/Atlas.hpp index bb6f18bfb..c3a919b3d 100644 --- a/Protocols/Atlas.hpp +++ b/Protocols/Atlas.hpp @@ -38,7 +38,7 @@ array Atlas::get_double_sharing() } template -void Atlas::init_mul(SubProcessor*) +void Atlas::init_mul() { oss.reset(); oss2.reset(); @@ -47,10 +47,9 @@ void Atlas::init_mul(SubProcessor*) } template -typename T::clear Atlas::prepare_mul(const T& x, const T& y, int) +void Atlas::prepare_mul(const T& x, const T& y, int) { prepare(x * y); - return {}; } template @@ -98,9 +97,9 @@ T Atlas::finalize_mul(int) } template -void Atlas::init_dotprod(SubProcessor* proc) +void Atlas::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/Beaver.h b/Protocols/Beaver.h index e0c24e49e..2d28127c7 100644 --- a/Protocols/Beaver.h +++ b/Protocols/Beaver.h @@ -38,14 +38,17 @@ class Beaver : public ProtocolBase Beaver(Player& P) : prep(0), MC(0), P(P) {} - Player& branch(); + typename T::Protocol branch(); - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init(Preprocessing& prep, typename T::MAC_Check& MC); + + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); + void check(); + void start_exchange(); void stop_exchange(); diff --git a/Protocols/Beaver.hpp b/Protocols/Beaver.hpp index 639930059..dc9814870 100644 --- a/Protocols/Beaver.hpp +++ b/Protocols/Beaver.hpp @@ -13,30 +13,34 @@ #include template -Player& Beaver::branch() +typename T::Protocol Beaver::branch() { - return P; + typename T::Protocol res(P); + res.prep = prep; + res.MC = MC; + res.init_mul(); + return res; } template -void Beaver::init_mul(SubProcessor* proc) +void Beaver::init(Preprocessing& prep, typename T::MAC_Check& MC) { - assert(proc != 0); - init_mul(proc->DataF, proc->MC); + this->prep = &prep; + this->MC = &MC; } template -void Beaver::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) +void Beaver::init_mul() { - this->prep = &prep; - this->MC = &MC; + assert(this->prep); + assert(this->MC); shares.clear(); opened.clear(); triples.clear(); } template -typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) +void Beaver::prepare_mul(const T& x, const T& y, int n) { (void) n; triples.push_back({{}}); @@ -44,7 +48,6 @@ typename T::clear Beaver::prepare_mul(const T& x, const T& y, int n) triple = prep->get_triple(n); shares.push_back(x - triple[0]); shares.push_back(y - triple[1]); - return 0; } template @@ -86,4 +89,11 @@ T Beaver::finalize_mul(int n) return tmp; } +template +void Beaver::check() +{ + assert(MC); + MC->Check(P); +} + #endif diff --git a/Protocols/BrainShare.h b/Protocols/BrainShare.h index 301ed9b0b..77f2e35f6 100644 --- a/Protocols/BrainShare.h +++ b/Protocols/BrainShare.h @@ -38,6 +38,8 @@ class BrainShare : public Rep3Share> const static int N_MASK_BITS = clear::N_BITS + S; const static int Z_BITS = 2 * (N_MASK_BITS) + 5 + S; + static const bool has_trunc_pr = false; + BrainShare() { } diff --git a/Protocols/FakeProtocol.h b/Protocols/FakeProtocol.h index 853782459..fb55f0cf4 100644 --- a/Protocols/FakeProtocol.h +++ b/Protocols/FakeProtocol.h @@ -9,6 +9,7 @@ #include "Replicated.h" #include "Math/Z2k.h" #include "Processor/Instruction.h" +#include "Processor/TruncPrTuple.h" #include @@ -75,15 +76,14 @@ class FakeProtocol : public ProtocolBase return P; } - void init_mul(SubProcessor*) + void init_mul() { results.clear(); } - typename T::clear prepare_mul(const T& x, const T& y, int = -1) + void prepare_mul(const T& x, const T& y, int = -1) { results.push_back(x * y); - return {}; } void exchange() @@ -95,9 +95,9 @@ class FakeProtocol : public ProtocolBase return results.next(); } - void init_dotprod(SubProcessor* proc) + void init_dotprod() { - init_mul(proc); + init_mul(); dot_prod = {}; } @@ -177,19 +177,22 @@ class FakeProtocol : public ProtocolBase res += overflow; } #else -#ifdef RISKY_TRUNCATION_IN_EMULATION - T r; - r.randomize(G); + if (TruncPrTupleWithGap(regs, i).big_gap()) + { + T r; + r.randomize(G); - if (source.negative()) - res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + if (source.negative()) + res = -T(((-source + r) >> n_shift) - (r >> n_shift)); + else + res = ((source + r) >> n_shift) - (r >> n_shift); + } else - res = ((source + r) >> n_shift) - (r >> n_shift); -#else - T r; - r.randomize_part(G, n_shift); - res = (source + r) >> n_shift; -#endif + { + T r; + r.randomize_part(G, n_shift); + res = (source + r) >> n_shift; + } #endif } } diff --git a/Protocols/FakeShare.h b/Protocols/FakeShare.h index f36a7b754..569c136e6 100644 --- a/Protocols/FakeShare.h +++ b/Protocols/FakeShare.h @@ -32,6 +32,9 @@ class FakeShare : public T, public ShareInterface typedef GC::FakeSecret bit_type; + static const bool has_trunc_pr = true; + static const bool dishonest_majority = false; + static string type_short() { return "emul"; diff --git a/Protocols/Hemi.h b/Protocols/Hemi.h index 1e8021467..8a00c793c 100644 --- a/Protocols/Hemi.h +++ b/Protocols/Hemi.h @@ -6,14 +6,14 @@ #ifndef PROTOCOLS_HEMI_H_ #define PROTOCOLS_HEMI_H_ -#include "SPDZ.h" +#include "Semi.h" #include "HemiMatrixPrep.h" /** * Matrix multiplication optimized with semi-homomorphic encryption */ template -class Hemi : public SPDZ +class Hemi : public Semi { map, HemiMatrixPrep*> matrix_preps; @@ -22,7 +22,7 @@ class Hemi : public SPDZ public: Hemi(Player& P) : - SPDZ(P) + Semi(P) { } ~Hemi(); diff --git a/Protocols/Hemi.hpp b/Protocols/Hemi.hpp index dc285c14c..e67b28a97 100644 --- a/Protocols/Hemi.hpp +++ b/Protocols/Hemi.hpp @@ -51,19 +51,20 @@ void Hemi::matmulsm(SubProcessor& processor, CheckVector& source, ShareMatrix A(dim[0], dim[1]), B(dim[1], dim[2]); - for (int i = 0; i < dim[0]; i++) + for (int k = 0; k < dim[1]; k++) { - auto ii = Proc->get_Ci().at(dim[3] + i); + for (int i = 0; i < dim[0]; i++) + { + auto kk = Proc->get_Ci().at(dim[4] + k); + auto ii = Proc->get_Ci().at(dim[3] + i); + A[{i, k}] = source.at(a + ii * dim[7] + kk); + } + for (int j = 0; j < dim[2]; j++) { auto jj = Proc->get_Ci().at(dim[6] + j); - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ll = Proc->get_Ci().at(dim[5] + k); - A[{i, k}] = source.at(a + ii * dim[7] + kk); - B[{k, j}] = source.at(b + ll * dim[8] + jj); - } + auto ll = Proc->get_Ci().at(dim[5] + k); + B[{k, j}] = source.at(b + ll * dim[8] + jj); } } @@ -93,7 +94,8 @@ ShareMatrix Hemi::matrix_multiply(const ShareMatrix& A, subdim[2] = min(max_cols, B.n_cols - j); auto& prep = get_matrix_prep(subdim, processor); MatrixMC mc; - beaver.init_mul(prep, mc); + beaver.init(prep, mc); + beaver.init_mul(); beaver.prepare_mul(A.from(0, i, subdim.data()), B.from(i, j, subdim.data() + 1)); beaver.exchange(); diff --git a/Protocols/HighGearKeyGen.cpp b/Protocols/HighGearKeyGen.cpp index 2618feba6..1c8f9f74d 100644 --- a/Protocols/HighGearKeyGen.cpp +++ b/Protocols/HighGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PartSetup::key_and_mac_generation(Player& P, MachineBase& machine, int, false_type) { - HighGearKeyGen<2, 2>(P, params).run(*this, machine); + HighGearKeyGen<0, 0>(P, params).run(*this, machine); } diff --git a/Protocols/LowGearKeyGen.cpp b/Protocols/LowGearKeyGen.cpp index 2b149bc0f..61829b368 100644 --- a/Protocols/LowGearKeyGen.cpp +++ b/Protocols/LowGearKeyGen.cpp @@ -19,5 +19,5 @@ template<> void PairwiseSetup::key_and_mac_generation(Player& P, PairwiseMachine& machine, int, false_type) { - LowGearKeyGen<2>(P, machine, params).run(*this); + LowGearKeyGen<0>(P, machine, params).run(*this); } diff --git a/Protocols/LowGearKeyGen.hpp b/Protocols/LowGearKeyGen.hpp index a59820404..9ff92fb0e 100644 --- a/Protocols/LowGearKeyGen.hpp +++ b/Protocols/LowGearKeyGen.hpp @@ -126,7 +126,7 @@ typename KeyGenProtocol::vector_type KeyGenProtocol::schur_product( vector_type res; assert(x.size() == y.size()); auto& protocol = proc->protocol; - protocol.init_mul(proc); + protocol.init_mul(); for (size_t i = 0; i < x.size(); i++) protocol.prepare_mul(x[i], y[i]); protocol.exchange(); diff --git a/Protocols/MAC_Check.hpp b/Protocols/MAC_Check.hpp index db3f8dc71..85e9c84a1 100644 --- a/Protocols/MAC_Check.hpp +++ b/Protocols/MAC_Check.hpp @@ -50,11 +50,13 @@ Tree_MAC_Check::Tree_MAC_Check(const typename U::mac_key_type::Scalar& ai, in template Tree_MAC_Check::~Tree_MAC_Check() { +#ifndef NO_SECURITY_CHECK if (WaitingForCheck() > 0) { cerr << endl << "SECURITY BUG: insufficient checking" << endl; terminate(); } +#endif } template diff --git a/Protocols/MalRepRingPrep.hpp b/Protocols/MalRepRingPrep.hpp index ce34b64e4..96f2c8138 100644 --- a/Protocols/MalRepRingPrep.hpp +++ b/Protocols/MalRepRingPrep.hpp @@ -121,21 +121,6 @@ void shuffle_triple_generation(vector>& triples, Player& P, #endif } -template -void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) -{ - int buffer_size = check_triples.size(); - - // shuffle - GlobalPRNG G(P); - for (int i = 0; i < buffer_size; i++) - { - int remaining = buffer_size - i; - int pos = G.get_uint(remaining); - swap(check_triples[i], check_triples[i + pos]); - } -} - template TripleShuffleSacrifice::TripleShuffleSacrifice() { @@ -251,32 +236,6 @@ void RingOnlyBitsFromSquaresPrep::buffer_bits() bits_from_square_in_ring(this->bits, this->buffer_size, &prep); } -template -void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, - ThreadQueues* queues) -{ - RunningTimer timer; -#ifndef NONPERSONAL_EDA - this->buffer_edabits_from_personal(strict, n_bits, queues); -#else - assert(this->proc != 0); - ShuffleSacrifice shuffle_sacrifice; - typedef typename T::bit_type::part_type bit_type; - vector> bits; - vector sums; - this->buffer_edabits_without_check(n_bits, sums, bits, - shuffle_sacrifice.minimum_n_inputs(), queues); - vector>& checked = this->edabits[{strict, n_bits}]; - shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, - n_bits, *this->proc, strict, -1, queues); - if (strict) - this->sanitize(checked, n_bits, -1, queues); -#endif -#ifdef VERBOSE_EDA - cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; -#endif -} - template void MalRepRingPrep::buffer_inputs(int player) { diff --git a/Protocols/MaliciousRep3Share.h b/Protocols/MaliciousRep3Share.h index 1967994d0..f98e9797f 100644 --- a/Protocols/MaliciousRep3Share.h +++ b/Protocols/MaliciousRep3Share.h @@ -42,6 +42,7 @@ class MaliciousRep3Share : public Rep3Share typedef GC::MaliciousRepSecret bit_type; const static bool expensive = true; + static const bool has_trunc_pr = false; static string type_short() { diff --git a/Protocols/MaliciousRepPO.h b/Protocols/MaliciousRepPO.h index 62d4b1783..7b58970f7 100644 --- a/Protocols/MaliciousRepPO.h +++ b/Protocols/MaliciousRepPO.h @@ -11,17 +11,21 @@ template class MaliciousRepPO { +protected: Player& P; octetStream to_send; octetStream to_receive[2]; + PointerVector secrets; public: MaliciousRepPO(Player& P); + virtual ~MaliciousRepPO() {} void prepare_sending(const T& secret, int player); - void send(int player); - void receive(); + virtual void send(int player); + virtual void receive(); typename T::clear finalize(const T& secret); + typename T::clear finalize(); }; #endif /* PROTOCOLS_MALICIOUSREPPO_H_ */ diff --git a/Protocols/MaliciousRepPO.hpp b/Protocols/MaliciousRepPO.hpp index 38a3a274a..bae235647 100644 --- a/Protocols/MaliciousRepPO.hpp +++ b/Protocols/MaliciousRepPO.hpp @@ -3,6 +3,9 @@ * */ +#ifndef PROTOCOLS_MALICIOUSREPPO_HPP_ +#define PROTOCOLS_MALICIOUSREPPO_HPP_ + #include "MaliciousRepPO.h" #include @@ -16,7 +19,10 @@ MaliciousRepPO::MaliciousRepPO(Player& P) : P(P) template void MaliciousRepPO::prepare_sending(const T& secret, int player) { - secret[2 - P.get_offset(player)].pack(to_send); + if (player == P.my_num()) + secrets.push_back(secret); + else + secret[2 - P.get_offset(player)].pack(to_send); } template @@ -24,7 +30,7 @@ void MaliciousRepPO::send(int player) { if (P.get_offset(player) == 2) P.send_to(player, to_send); - else + else if (P.my_num() != player) P.send_to(player, to_send.hash()); } @@ -42,3 +48,11 @@ typename T::clear MaliciousRepPO::finalize(const T& secret) { return secret.sum() + to_receive[0].template get(); } + +template +typename T::clear MaliciousRepPO::finalize() +{ + return finalize(secrets.next()); +} + +#endif diff --git a/Protocols/MaliciousRepPrep.hpp b/Protocols/MaliciousRepPrep.hpp index 4be3fc63a..8ffbff7bd 100644 --- a/Protocols/MaliciousRepPrep.hpp +++ b/Protocols/MaliciousRepPrep.hpp @@ -61,8 +61,9 @@ void MaliciousBitOnlyRepPrep::set_protocol(typename T::Protocol& protocol) template void MaliciousBitOnlyRepPrep::init_honest(Player& P) { - honest_proc = new SubProcessor(honest_mc, honest_prep, - P); + if (not honest_proc) + honest_proc = new SubProcessor(honest_mc, + honest_prep, P); } template diff --git a/Protocols/MamaPrep.hpp b/Protocols/MamaPrep.hpp index ef61ec7b9..c9eb63cf6 100644 --- a/Protocols/MamaPrep.hpp +++ b/Protocols/MamaPrep.hpp @@ -6,6 +6,7 @@ #include "MamaPrep.h" #include "SemiMC.hpp" +#include "MalRepRingPrep.hpp" template MamaPrep::MamaPrep(SubProcessor* proc, DataPositions& usage) : diff --git a/Protocols/MascotPrep.h b/Protocols/MascotPrep.h index 734453d31..5cfa82b84 100644 --- a/Protocols/MascotPrep.h +++ b/Protocols/MascotPrep.h @@ -21,8 +21,6 @@ class OTPrep : public virtual BitPrep ~OTPrep(); void set_protocol(typename T::Protocol& protocol); - - NamedCommStats comm_stats(); }; /** diff --git a/Protocols/MascotPrep.hpp b/Protocols/MascotPrep.hpp index cef603a25..1393bb464 100644 --- a/Protocols/MascotPrep.hpp +++ b/Protocols/MascotPrep.hpp @@ -40,8 +40,9 @@ void OTPrep::set_protocol(typename T::Protocol& protocol) // make sure not to use Montgomery multiplication T::open_type::next::template init(false); + assert(not triple_generator); triple_generator = new typename T::TripleGenerator( - BaseMachine::s().fresh_ot_setup(), + BaseMachine::fresh_ot_setup(proc->P), proc->P.N, -1, OnlineOptions::singleton.batch_size, 1, params, proc->MC.get_alphai(), &proc->P); @@ -121,13 +122,4 @@ T Preprocessing::get_random_from_inputs(int nplayers) return res; } -template -NamedCommStats OTPrep::comm_stats() -{ - auto res = BitPrep::comm_stats(); - if (triple_generator) - res += triple_generator->comm_stats(); - return res; -} - #endif diff --git a/Protocols/NoProtocol.h b/Protocols/NoProtocol.h index b99ce4e3e..d8259eb0f 100644 --- a/Protocols/NoProtocol.h +++ b/Protocols/NoProtocol.h @@ -45,12 +45,12 @@ class NoProtocol : public ProtocolBase } // prepare next round of multiplications - void init_mul(SubProcessor*) + void init_mul() { } // schedule multiplication - typename T::clear prepare_mul(const T&, const T&, int = -1) + void prepare_mul(const T&, const T&, int = -1) { throw runtime_error("no multiplication preparation"); } diff --git a/Protocols/PostSacriRepRingShare.h b/Protocols/PostSacriRepRingShare.h index 70371744d..d4f2ab0fd 100644 --- a/Protocols/PostSacriRepRingShare.h +++ b/Protocols/PostSacriRepRingShare.h @@ -22,6 +22,8 @@ class PostSacriRepRingShare : public Rep3Share2 static const int BIT_LENGTH = K; static const int SECURITY = S; + static const bool has_trunc_pr = false; + typedef SignedZ2 clear; typedef MaliciousRep3Share> prep_type; typedef Z2 random_type; diff --git a/Protocols/PostSacrifice.h b/Protocols/PostSacrifice.h index 73ec766e4..54b178a74 100644 --- a/Protocols/PostSacrifice.h +++ b/Protocols/PostSacrifice.h @@ -30,8 +30,8 @@ class PostSacrifice : public ProtocolBase Player& branch(); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange() { internal.exchange(); } T finalize_mul(int n = -1); diff --git a/Protocols/PostSacrifice.hpp b/Protocols/PostSacrifice.hpp index 4db3b73b4..0f72f4e81 100644 --- a/Protocols/PostSacrifice.hpp +++ b/Protocols/PostSacrifice.hpp @@ -25,9 +25,8 @@ Player& PostSacrifice::branch() } template -void PostSacrifice::init_mul(SubProcessor* proc) +void PostSacrifice::init_mul() { - (void) proc; // throw away unused operands operands.resize(results.size()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) @@ -36,11 +35,11 @@ void PostSacrifice::init_mul(SubProcessor* proc) } template -typename T::clear PostSacrifice::prepare_mul(const T& x, const T& y, int n) +void PostSacrifice::prepare_mul(const T& x, const T& y, int n) { (void) n; operands.push_back({{x, y}}); - return internal.prepare_mul(x, y); + internal.prepare_mul(x, y); } template diff --git a/Protocols/ProtocolSet.h b/Protocols/ProtocolSet.h new file mode 100644 index 000000000..e6a8eb525 --- /dev/null +++ b/Protocols/ProtocolSet.h @@ -0,0 +1,107 @@ +/* + * ProtocolSet.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSET_H_ +#define PROTOCOLS_PROTOCOLSET_H_ + +#include "Processor/Processor.h" +#include "GC/ShareThread.h" +#include "ProtocolSetup.h" + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type + */ +template +class ProtocolSet +{ + DataPositions usage; + +public: + typename T::MAC_Check output; + typename T::LivePrep preprocessing; + SubProcessor processor; + typename T::Protocol& protocol; + typename T::Input& input; + + ProtocolSet(Player& P, typename T::mac_key_type mac_key) : + usage(P.num_players()), output(mac_key), preprocessing(0, usage), processor( + output, preprocessing, P), protocol(processor.protocol), input( + processor.input) + { + } + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + ProtocolSet(Player& P, const ProtocolSetup& setup) : + ProtocolSet(P, setup.get_mac_key()) + { + } + + ~ProtocolSet() + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for a binary share type + */ +template +class BinaryProtocolSet +{ + DataPositions usage; + typename T::LivePrep prep; + GC::ShareThread thread; + +public: + typename T::MAC_Check& output; + typename T::Protocol& protocol; + typename T::Input input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + BinaryProtocolSet(Player& P, const BinaryProtocolSetup& setup) : + usage(P.num_players()), prep(usage), thread(prep, P, + setup.get_mac_key()), output(*thread.MC), protocol( + *thread.protocol), input(output, prep, P) + { + } +}; + +/** + * Input, multiplication, and output protocol instance + * for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSet +{ + ProtocolSet arithmetic; + +public: + BinaryProtocolSet binary; + + typename T::MAC_Check& output; + typename T::LivePrep& preprocessing; + typename T::Protocol& protocol; + typename T::Input& input; + + /** + * @param P communication instance + * @param setup one-time setup instance + */ + MixedProtocolSet(Player& P, const MixedProtocolSetup& setup) : + arithmetic(P, setup), binary(P, setup.binary), output( + arithmetic.output), preprocessing(arithmetic.preprocessing), protocol( + arithmetic.protocol), input(arithmetic.input) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSET_H_ */ diff --git a/Protocols/ProtocolSetup.h b/Protocols/ProtocolSetup.h new file mode 100644 index 000000000..b6d91b2bc --- /dev/null +++ b/Protocols/ProtocolSetup.h @@ -0,0 +1,95 @@ +/* + * ProtocolSetup.h + * + */ + +#ifndef PROTOCOLS_PROTOCOLSETUP_H_ +#define PROTOCOLS_PROTOCOLSETUP_H_ + +#include "Networking/Player.h" + +/** + * Global setup for an arithmetic share type + */ +template +class ProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + ProtocolSetup(Player& P, int prime_length = 0, string directory = "") + { + // initialize fields + if (prime_length == 0) + prime_length = T::clear::MAX_N_BITS; + + T::clear::init_default(prime_length); + T::clear::next::init_default(prime_length, false); + + // must initialize MAC key for security of some protocols + T::read_or_generate_mac_key(directory, P, mac_key); + } + + ~ProtocolSetup() + { + T::LivePrep::teardown(); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for a binary share type + */ +template +class BinaryProtocolSetup +{ + typename T::mac_key_type mac_key; + +public: + /** + * @param P communication instance (used for MAC generation if needed) + * @param directory location to read MAC if needed + */ + BinaryProtocolSetup(Player& P, string directory = "") + { + T::part_type::open_type::init_field(); + T::mac_key_type::init_field(); + T::part_type::read_or_generate_mac_key(directory, P, mac_key); + } + + typename T::mac_key_type get_mac_key() const + { + return mac_key; + } +}; + +/** + * Global setup for an arithmetic share type and the corresponding binary one + */ +template +class MixedProtocolSetup : public ProtocolSetup +{ +public: + BinaryProtocolSetup binary; + + /** + * @param P communication instance (used for MAC generation if needed) + * @param prime_length length of prime if computing modulo a prime + * @param directory location to read MAC if needed + */ + MixedProtocolSetup(Player& P, int prime_length = 0, string directory = "") : + ProtocolSetup(P, prime_length, directory), binary(P, directory) + { + } +}; + +#endif /* PROTOCOLS_PROTOCOLSETUP_H_ */ diff --git a/Protocols/Rep3Share.h b/Protocols/Rep3Share.h index d115b4c5c..e85065ac0 100644 --- a/Protocols/Rep3Share.h +++ b/Protocols/Rep3Share.h @@ -11,6 +11,7 @@ #include "Protocols/Replicated.h" #include "GC/ShareSecret.h" #include "ShareInterface.h" +#include "Processor/Instruction.h" template class ReplicatedPrep; template class ReplicatedRingPrep; @@ -67,6 +68,31 @@ class RepShare : public FixedVec, public ShareInterface assert(full); FixedVec::unpack(os); } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst) + { + shrsi(proc, inst, T::invertible); + } + + template + static void shrsi(SubProcessor&, const Instruction&, + true_type) + { + throw runtime_error("shrsi not implemented"); + } + + template + static void shrsi(SubProcessor& proc, const Instruction& inst, + false_type) + { + for (int i = 0; i < inst.get_size(); i++) + { + auto& dest = proc.get_S_ref(inst.get_r(0) + i); + auto& source = proc.get_S_ref(inst.get_r(1) + i); + dest = source >> inst.get_n(); + } + } }; template @@ -94,6 +120,7 @@ class Rep3Share : public RepShare const static bool dishonest_majority = false; const static bool expensive = false; const static bool variable_players = false; + static const bool has_trunc_pr = true; static string type_short() { diff --git a/Protocols/Rep3Share2k.h b/Protocols/Rep3Share2k.h index c7a494525..23f28cf9b 100644 --- a/Protocols/Rep3Share2k.h +++ b/Protocols/Rep3Share2k.h @@ -31,7 +31,6 @@ class Rep3Share2 : public Rep3Share> typedef GC::SemiHonestRepSecret bit_type; - static const bool has_trunc_pr = true; static const bool has_split = true; Rep3Share2() @@ -132,17 +131,6 @@ class Rep3Share2 : public Rep3Share> } } } - - template - static void shrsi(SubProcessor& proc, const Instruction& inst) - { - for (int i = 0; i < inst.get_size(); i++) - { - auto& dest = proc.get_S_ref(inst.get_r(0) + i); - auto& source = proc.get_S_ref(inst.get_r(1) + i); - dest = source >> inst.get_n(); - } - } }; #endif /* PROTOCOLS_REP3SHARE2K_H_ */ diff --git a/Protocols/Rep4.h b/Protocols/Rep4.h index aa0fc7bce..6acfae421 100644 --- a/Protocols/Rep4.h +++ b/Protocols/Rep4.h @@ -60,6 +60,11 @@ class Rep4 : public ProtocolBase void trunc_pr(const vector& regs, int size, SubProcessor& proc, false_type); + template + T finalize_mul(int n_bits, true_type); + template + T finalize_mul(int n_bits, false_type); + public: prngs_type rep_prngs; Player& P; @@ -70,14 +75,13 @@ class Rep4 : public ProtocolBase Rep4 branch(); - void init_mul(SubProcessor* proc = 0); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void check(); - void init_dotprod(SubProcessor* proc); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Rep4.hpp b/Protocols/Rep4.hpp index e77b4e6f5..a2deab2be 100644 --- a/Protocols/Rep4.hpp +++ b/Protocols/Rep4.hpp @@ -59,7 +59,7 @@ Rep4 Rep4::branch() } template -void Rep4::init_mul(SubProcessor*) +void Rep4::init_mul() { for (auto& x : add_shares) x.clear(); @@ -70,12 +70,6 @@ void Rep4::init_mul(SubProcessor*) channels.resize(P.num_players(), vector(P.num_players(), false)); } -template -void Rep4::init_mul(Preprocessing&, typename T::MAC_Check&) -{ - init_mul(); -} - template void Rep4::reset_joint_input(int n_inputs) { @@ -194,13 +188,12 @@ int Rep4::get_player(int offset) } template -typename T::clear Rep4::prepare_mul(const T& x, const T& y, int n_bits) +void Rep4::prepare_mul(const T& x, const T& y, int n_bits) { auto a = get_addshares(x, y); for (int i = 0; i < 5; i++) add_shares[i].push_back(a[i]); bit_lengths.push_back(n_bits); - return {}; } template @@ -215,7 +208,7 @@ array Rep4::get_addshares(const T& x, const T& y) } template -void Rep4::init_dotprod(SubProcessor*) +void Rep4::init_dotprod() { init_mul(); dotprod_shares = {}; @@ -260,10 +253,27 @@ void Rep4::exchange() } template -T Rep4::finalize_mul(int) +T Rep4::finalize_mul(int n_bits) { this->counter++; - return results.next().res; + if (n_bits == -1) + return results.next().res; + else + return finalize_mul(n_bits, T::clear::binary); +} + +template +template +T Rep4::finalize_mul(int n_bits, true_type) +{ + return results.next().res.mask(n_bits); +} + +template +template +T Rep4::finalize_mul(int, false_type) +{ + throw runtime_error("bit-wise multiplication not supported"); } template diff --git a/Protocols/Rep4Prep.hpp b/Protocols/Rep4Prep.hpp index 17915e43d..e871e82c9 100644 --- a/Protocols/Rep4Prep.hpp +++ b/Protocols/Rep4Prep.hpp @@ -54,7 +54,7 @@ template void Rep4RingPrep::buffer_squares() { generate_squares(this->squares, OnlineOptions::singleton.batch_size, - this->protocol, this->proc); + this->protocol); } template diff --git a/Protocols/Replicated.h b/Protocols/Replicated.h index 3de9bfabc..67527a208 100644 --- a/Protocols/Replicated.h +++ b/Protocols/Replicated.h @@ -76,10 +76,13 @@ class ProtocolBase /// Single multiplication T mul(const T& x, const T& y); + /// Initialize protocol if needed (repeated call possible) + virtual void init(Preprocessing&, typename T::MAC_Check&) {} + /// Initialize multiplication round - virtual void init_mul(SubProcessor* proc) = 0; + virtual void init_mul() = 0; /// Schedule multiplication of operand pair - virtual typename T::clear prepare_mul(const T& x, const T& y, int n = -1) = 0; + virtual void prepare_mul(const T& x, const T& y, int n = -1) = 0; /// Run multiplication protocol virtual void exchange() = 0; /// Get next multiplication result @@ -88,7 +91,7 @@ class ProtocolBase virtual void finalize_mult(T& res, int n = -1); /// Initialize dot product round - void init_dotprod(SubProcessor* proc) { init_mul(proc); } + void init_dotprod() { init_mul(); } /// Add operand pair to current dot product void prepare_dotprod(const T& x, const T& y) { prepare_mul(x, y); } /// Finish dot product @@ -132,6 +135,11 @@ class Replicated : public ReplicatedBase, public ProtocolBase PointerVector add_shares; typename T::clear dotprod_share; + template + void trunc_pr(const vector& regs, int size, U& proc, true_type); + template + void trunc_pr(const vector& regs, int size, U& proc, false_type); + public: typedef ReplicatedMC MAC_Check; typedef ReplicatedInput Input; @@ -149,17 +157,13 @@ class Replicated : public ReplicatedBase, public ProtocolBase share[my_num] = value; } - void init_mul(SubProcessor* proc); - void init_mul(Preprocessing& prep, typename T::MAC_Check& MC); - void init_mul(); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); void prepare_reshare(const typename T::clear& share, int n = -1); - void init_dotprod(SubProcessor*) { init_mul(); } void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); diff --git a/Protocols/Replicated.hpp b/Protocols/Replicated.hpp index 75dc785be..374ed89b1 100644 --- a/Protocols/Replicated.hpp +++ b/Protocols/Replicated.hpp @@ -11,12 +11,10 @@ #include "Processor/TruncPrTuple.h" #include "Tools/benchmarking.h" -#include "SemiShare.h" -#include "SemiMC.h" #include "ReplicatedInput.h" #include "Rep3Share2k.h" -#include "SemiMC.hpp" +#include "ReplicatedPO.hpp" #include "Math/Z2k.hpp" template @@ -99,7 +97,8 @@ void ProtocolBase::multiply(vector& products, BaseMachine::thread_num); #endif - init_mul(&proc); + init(proc.DataF, proc.MC); + init_mul(); for (int i = begin; i < end; i++) prepare_mul(multiplicands[i].first, multiplicands[i].second); exchange(); @@ -110,7 +109,7 @@ void ProtocolBase::multiply(vector& products, template T ProtocolBase::mul(const T& x, const T& y) { - init_mul(0); + init_mul(); prepare_mul(x, y); exchange(); return finalize_mul(); @@ -146,20 +145,6 @@ T ProtocolBase::get_random() return res; } -template -void Replicated::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - -template -void Replicated::init_mul(Preprocessing& prep, typename T::MAC_Check& MC) -{ - (void) prep, (void) MC; - init_mul(); -} - template void Replicated::init_mul() { @@ -169,12 +154,11 @@ void Replicated::init_mul() } template -inline typename T::clear Replicated::prepare_mul(const T& x, +void Replicated::prepare_mul(const T& x, const T& y, int n) { typename T::value_type add_share = x.local_mul(y); prepare_reshare(add_share, n); - return add_share; } template @@ -276,109 +260,89 @@ void Replicated::randoms(T& res, int n_bits) res[i].randomize_part(shared_prngs[i], n_bits); } -template -void trunc_pr(const vector& regs, int size, - SubProcessor>& proc) +template +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + false_type) { assert(regs.size() % 4 == 0); assert(proc.P.num_players() == 3); assert(proc.Proc != 0); - typedef SignedZ2 value_type; - typedef Rep3Share T; - bool generate = proc.P.my_num() == 2; + typedef typename T::clear value_type; + int gen_player = 2; + int comp_player = 1; + bool generate = P.my_num() == gen_player; + bool compute = P.my_num() == comp_player; + ArgList> infos(regs); + auto& S = proc.get_S(); + + octetStream cs; + ReplicatedInput input(P); + if (generate) { - octetStream os[2]; - for (size_t i = 0; i < regs.size(); i += 4) - { - TruncPrTuple info(regs, i); - for (int l = 0; l < size; l++) + SeededPRNG G; + for (auto info : infos) + for (int i = 0; i < size; i++) { - auto& res = proc.get_S_ref(regs[i] + l); - auto& G = proc.Proc->secure_prng; - auto mask = G.template get(); - auto unmask = info.upper(mask); - T shares[4]; - shares[0].randomize_to_sum(mask, G); - shares[1].randomize_to_sum(unmask, G); - shares[2].randomize_to_sum(info.msb(mask), G); - res.randomize(G); - shares[3] = res; - for (int i = 0; i < 2; i++) - { - for (int j = 0; j < 4; j++) - shares[j][i].pack(os[i]); - } + auto r = G.get(); + input.add_mine(info.upper(r)); + if (info.small_gap()) + input.add_mine(info.msb(r)); + (r + S[info.source_base + i][0]).pack(cs); } - } - for (int i = 0; i < 2; i++) - proc.P.send_to(i, os[i]); + P.send_to(comp_player, cs); } else + input.add_other(gen_player); + + if (compute) { - octetStream os; - proc.P.receive_player(2, os); - OffsetPlayer player(proc.P, 1 - 2 * proc.P.my_num()); - typedef SemiShare semi_type; - vector> to_open; - PointerVector> mask_shares[3]; - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) + P.receive_player(gen_player, cs); + for (auto info : infos) + for (int i = 0; i < size; i++) { - SemiShare share; - auto& x = proc.get_S_ref(regs[i + 1] + l); - if (proc.P.my_num() == 0) - share = x.sum(); - else - share = x[0]; - for (auto& mask_share : mask_shares) - mask_share.push_back(os.get()); - to_open.push_back(share + mask_shares[0].next()); - auto& res = proc.get_S_ref(regs[i] + l); - auto& a = res[1 - proc.P.my_num()]; - a.unpack(os); + auto c = cs.get() + S[info.source_base + i].sum(); + input.add_mine(info.upper(c)); + if (info.small_gap()) + input.add_mine(info.msb(c)); } - PointerVector opened; - DirectSemiMC> MC; - MC.POpen_(opened, to_open, player); - os.reset_write_head(); - for (size_t i = 0; i < regs.size(); i += 4) + } + + input.add_other(comp_player); + input.exchange(); + init_mul(); + + for (auto info : infos) + for (int i = 0; i < size; i++) { - int k = regs[i + 2]; - int m = regs[i + 3]; - int n_shift = value_type::N_BITS - 1 - k; - assert(m < k); - assert(0 < k); - assert(m < value_type::N_BITS); - for (int l = 0; l < size; l++) + auto c_prime = input.finalize(comp_player); + auto r_prime = input.finalize(gen_player); + S[info.dest_base + i] = c_prime - r_prime; + + if (info.small_gap()) { - auto& res = proc.get_S_ref(regs[i] + l); - auto masked = opened.next() << n_shift; - auto shifted = (masked << 1) >> (n_shift + m + 1); - auto diff = SemiShare::constant(shifted, - player.my_num()) - mask_shares[1].next(); - auto msb = masked >> (value_type::N_BITS - 1); - auto bit_mask = mask_shares[2].next(); - auto overflow = (bit_mask - + SemiShare::constant(msb, player.my_num()) - - bit_mask * msb * 2); - auto res_share = diff + (overflow << (k - m)); - auto& a = res[1 - proc.P.my_num()]; - auto& b = res[proc.P.my_num()]; - b = res_share - a; - b.pack(os); + auto c_dprime = input.finalize(comp_player); + auto r_msb = input.finalize(gen_player); + S[info.dest_base + i] += ((r_msb + c_dprime) + << (info.k - info.m)); + prepare_mul(r_msb, c_dprime); } } - player.exchange(os); - for (size_t i = 0; i < regs.size(); i += 4) - for (int l = 0; l < size; l++) - proc.get_S_ref(regs[i] + l)[proc.P.my_num()] += - os.get(); - } + + exchange(); + + for (auto info : infos) + for (int i = 0; i < size; i++) + if (info.small_gap()) + S[info.dest_base + i] -= finalize_mul() + << (info.k - info.m + 1); } template -void trunc_pr(const vector& regs, int size, SubProcessor& proc) +template +void Replicated::trunc_pr(const vector& regs, int size, U& proc, + true_type) { (void) regs, (void) size, (void) proc; throw runtime_error("trunc_pr not implemented"); @@ -390,7 +354,7 @@ void Replicated::trunc_pr(const vector& regs, int size, U& proc) { this->trunc_rounds++; - ::trunc_pr(regs, size, proc); + trunc_pr(regs, size, proc, T::clear::characteristic_two); } #endif diff --git a/Protocols/ReplicatedInput.h b/Protocols/ReplicatedInput.h index 7d62838a3..9bb3c30a3 100644 --- a/Protocols/ReplicatedInput.h +++ b/Protocols/ReplicatedInput.h @@ -72,9 +72,8 @@ class ReplicatedInput : public PrepLessInput PrepLessInput(proc), proc(proc), P(P), protocol(P) { assert(T::length == 2); - InputBase::P = &P; - InputBase::os.resize(P.num_players()); expect.resize(P.num_players()); + this->reset_all(P); } void reset(int player); diff --git a/Protocols/ReplicatedInput.hpp b/Protocols/ReplicatedInput.hpp index 741d2c490..1cfac4a16 100644 --- a/Protocols/ReplicatedInput.hpp +++ b/Protocols/ReplicatedInput.hpp @@ -71,7 +71,7 @@ template inline void ReplicatedInput::finalize_other(int player, T& target, octetStream& o, int n_bits) { - int offset = player - P.my_num(); + int offset = player - this->my_num; if (offset == 1 or offset == -2) { typename T::value_type t; diff --git a/Protocols/ReplicatedPO.h b/Protocols/ReplicatedPO.h new file mode 100644 index 000000000..a533a5b1a --- /dev/null +++ b/Protocols/ReplicatedPO.h @@ -0,0 +1,24 @@ +/* + * ReplicatedPO.h + * + */ + +#ifndef PROTOCOLS_REPLICATEDPO_H_ +#define PROTOCOLS_REPLICATEDPO_H_ + +#include "MaliciousRepPO.h" + +template +class ReplicatedPO : public MaliciousRepPO +{ +public: + ReplicatedPO(Player& P) : + MaliciousRepPO(P) + { + } + + void send(int player); + void receive(); +}; + +#endif /* PROTOCOLS_REPLICATEDPO_H_ */ diff --git a/Protocols/ReplicatedPO.hpp b/Protocols/ReplicatedPO.hpp new file mode 100644 index 000000000..aecd85b3f --- /dev/null +++ b/Protocols/ReplicatedPO.hpp @@ -0,0 +1,21 @@ +/* + * ReplicatedPO.cpp + * + */ + +#include "ReplicatedPO.h" + +#include "MaliciousRepPO.hpp" + +template +void ReplicatedPO::send(int player) +{ + if (this->P.get_offset(player) == 2) + this->P.send_to(player, this->to_send); +} + +template +void ReplicatedPO::receive() +{ + this->P.receive_relative(1, this->to_receive[0]); +} diff --git a/Protocols/ReplicatedPrep.h b/Protocols/ReplicatedPrep.h index 8c3ed3f13..8a30749c3 100644 --- a/Protocols/ReplicatedPrep.h +++ b/Protocols/ReplicatedPrep.h @@ -184,6 +184,15 @@ class RingPrep : public virtual BitPrep template void sanitize(vector>& edabits, int n_bits); + template + void buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size); + template + void buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end); + public: RingPrep(SubProcessor* proc, DataPositions& usage); virtual ~RingPrep(); @@ -224,6 +233,13 @@ class RingPrep : public virtual BitPrep template class SemiHonestRingPrep : public virtual RingPrep { + template + void buffer_bits(false_type, false_type); + template + void buffer_bits(true_type, false_type); + template + void buffer_bits(false_type, true_type); + public: SemiHonestRingPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -232,7 +248,7 @@ class SemiHonestRingPrep : public virtual RingPrep } virtual ~SemiHonestRingPrep() {} - virtual void buffer_bits() { this->buffer_bits_without_check(); } + virtual void buffer_bits(); virtual void buffer_inputs(int player) { this->buffer_inputs_as_usual(player, this->proc); } @@ -358,11 +374,6 @@ template class ReplicatedPrep : public virtual ReplicatedRingPrep, public virtual SemiHonestRingPrep { - template - void buffer_bits(false_type); - template - void buffer_bits(true_type); - public: ReplicatedPrep(SubProcessor* proc, DataPositions& usage) : BufferPrep(usage), BitPrep(proc, usage), @@ -384,7 +395,7 @@ class ReplicatedPrep : public virtual ReplicatedRingPrep, } void buffer_squares() { ReplicatedRingPrep::buffer_squares(); } - void buffer_bits(); + void buffer_bits() { SemiHonestRingPrep::buffer_bits(); } }; #endif /* PROTOCOLS_REPLICATEDPREP_H_ */ diff --git a/Protocols/ReplicatedPrep.hpp b/Protocols/ReplicatedPrep.hpp index 2b8aa1604..916ee6b8f 100644 --- a/Protocols/ReplicatedPrep.hpp +++ b/Protocols/ReplicatedPrep.hpp @@ -56,24 +56,23 @@ BufferPrep::~BufferPrep() << " bit generation" << endl; #endif - if (OnlineOptions::singleton.verbose) - { - this->print_left("triples", triples.size() * T::default_length, - type_string); - -#define X(KIND) \ - this->print_left(#KIND, KIND.size(), type_string); - X(squares) - X(inverses) - X(bits) - X(dabits) + this->print_left("triples", triples.size() * T::default_length, type_string, + this->usage.files.at(T::clear::field_type()).at(DATA_TRIPLE) + * T::default_length); + +#define X(KIND, TYPE) \ + this->print_left(#KIND, KIND.size(), type_string, \ + this->usage.files.at(T::clear::field_type()).at(TYPE)); + X(squares, DATA_SQUARE) + X(inverses, DATA_INVERSE) + X(bits, DATA_BIT) + X(dabits, DATA_DABIT) #undef X - for (auto& x : this->edabits) - { - this->print_left_edabits(x.second.size(), x.second[0].size(), - x.first.first, x.first.second); - } + for (auto& x : this->edabits) + { + this->print_left_edabits(x.second.size(), x.second[0].size(), + x.first.first, x.first.second, this->usage.edabits[x.first]); } } @@ -100,7 +99,9 @@ RingPrep::~RingPrep() template void BitPrep::set_protocol(typename T::Protocol& protocol) { - this->protocol = new typename T::Protocol(protocol.branch()); + if (not this->protocol) + this->protocol = new typename T::Protocol(protocol.branch()); + this->protocol->init_mul(); auto proc = this->proc; if (proc and proc->Proc) this->base_player = proc->Proc->thread_num; @@ -202,16 +203,16 @@ template void ReplicatedRingPrep::buffer_squares() { generate_squares(this->squares, this->buffer_size, - this->protocol, this->proc); + this->protocol); } template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc) + U* protocol) { assert(protocol != 0); squares.resize(n_squares); - protocol->init_mul(proc); + protocol->init_mul(); for (size_t i = 0; i < squares.size(); i++) { auto& square = squares[i]; @@ -289,7 +290,7 @@ void BufferPrep::get_two_no_count(Dtype dtype, T& a, T& b) template void XOR(vector& res, vector& x, vector& y, - typename T::Protocol& prot, SubProcessor* proc) + typename T::Protocol& prot) { assert(x.size() == y.size()); int buffer_size = x.size(); @@ -302,7 +303,7 @@ void XOR(vector& res, vector& x, vector& y, return; } - prot.init_mul(proc); + prot.init_mul(); for (int i = 0; i < buffer_size; i++) prot.prepare_mul(x[i], y[i]); prot.exchange(); @@ -337,13 +338,14 @@ void buffer_bits_from_squares(RingPrep& prep) template template -void ReplicatedPrep::buffer_bits(true_type) +void SemiHonestRingPrep::buffer_bits(true_type, false_type) { if (this->protocol->get_n_relevant_players() > 10 - or OnlineOptions::singleton.bits_from_squares) + or OnlineOptions::singleton.bits_from_squares + or T::dishonest_majority) buffer_bits_from_squares(*this); else - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template @@ -409,10 +411,9 @@ void MaliciousRingPrep::buffer_personal_dabits_without_check( auto& P = this->proc->P; auto &party = GC::ShareThread::s(); typedef typename T::bit_type::part_type BT; - SubProcessor bit_proc(party.MC->get_part_MC(), + typename BT::Input bit_input(party.MC->get_part_MC(), this->proc->bit_prep, this->proc->P); typename T::Input input(*this->proc, this->proc->MC); - typename BT::Input bit_input(bit_proc, bit_proc.MC); input.reset_all(P); bit_input.reset_all(P); SeededPRNG G; @@ -454,10 +455,24 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, typename BT::Input bit_input(proc, proc.MC); input.reset_all(P); bit_input.reset_all(P); - SeededPRNG G; assert(begin % BT::default_length == 0); int buffer_size = end - begin; + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + input_player, buffer_size); + input.exchange(); + bit_input.exchange(); + buffer_personal_edabits_without_check_post(n_bits, sums, bits, input, + bit_input, input_player, begin, end); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_pre(int n_bits, + Player& P, typename T::Input& input, typename BT::Input& bit_input, + int input_player, int buffer_size) +{ int n_chunks = DIV_CEIL(buffer_size, BT::default_length); + SeededPRNG G; if (input_player == P.my_num()) { for (int i = 0; i < n_chunks; i++) @@ -482,8 +497,16 @@ void RingPrep::buffer_personal_edabits_without_check(int n_bits, for (int i = 0; i < BT::default_length; i++) input.add_other(input_player); } - input.exchange(); - bit_input.exchange(); +} + +template +template +void RingPrep::buffer_personal_edabits_without_check_post(int n_bits, + vector& sums, vector >& bits, typename T::Input& input, + typename BT::Input& bit_input, int input_player, int begin, int end) +{ + int buffer_size = end - begin; + int n_chunks = DIV_CEIL(buffer_size, BT::default_length); for (int i = 0; i < buffer_size; i++) sums[begin + i] = input.finalize(input_player); assert(bits.size() == size_t(n_bits)); @@ -600,18 +623,18 @@ void BitPrep::buffer_ring_bits_without_check(vector& bits, PRNG& G, assert(proc != 0); int n_relevant_players = protocol->get_n_relevant_players(); vector> player_bits; - auto stat = proc->P.comm_stats; + auto stat = proc->P.total_comm(); buffer_bits_from_players(player_bits, G, *proc, this->base_player, buffer_size, 1); auto& prot = *protocol; - XOR(bits, player_bits[0], player_bits[1], prot, proc); + XOR(bits, player_bits[0], player_bits[1], prot); for (int i = 2; i < n_relevant_players; i++) - XOR(bits, bits, player_bits[i], prot, proc); + XOR(bits, bits, player_bits[i], prot); this->base_player++; (void) stat; #ifdef VERBOSE_PREP cerr << "bit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -730,9 +753,22 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector& sums, vector> player_ints(n_relevant, vector(buffer_size)); vector>> parts(n_relevant, vector>(n_bits, vector(buffer_size / dl))); + InScope in_scope(this->do_count, false); + assert(this->proc != 0); + auto& P = proc->P; + typename T::Input input(*this->proc, this->proc->MC); + typename BT::Input bit_input(bit_proc, bit_proc.MC); + input.reset_all(P); + bit_input.reset_all(P); + assert(begin % BT::default_length == 0); + for (int i = 0; i < n_relevant; i++) + buffer_personal_edabits_without_check_pre(n_bits, P, input, bit_input, + i, buffer_size); + input.exchange(); + bit_input.exchange(); for (int i = 0; i < n_relevant; i++) - buffer_personal_edabits_without_check<0>(n_bits, player_ints[i], parts[i], - bit_proc, i, 0, buffer_size); + buffer_personal_edabits_without_check_post(n_bits, player_ints[i], + parts[i], input, bit_input, i, 0, buffer_size); vector>> player_bits(n_bits, vector>(n_relevant)); for (int i = 0; i < n_bits; i++) @@ -754,7 +790,7 @@ template void RingPrep::buffer_edabits_without_check(int n_bits, vector>& edabits, int buffer_size) { - auto stat = this->proc->P.comm_stats; + auto stat = this->proc->P.total_comm(); typedef typename T::bit_type::part_type bit_type; vector> bits; vector sums; @@ -763,7 +799,7 @@ void RingPrep::buffer_edabits_without_check(int n_bits, vector>& (void) stat; #ifdef VERBOSE_PREP cerr << "edaBit generation" << endl; - (proc->P.comm_stats - stat).print(true); + (proc->P.total_comm() - stat).print(true); #endif } @@ -920,40 +956,38 @@ void RingPrep::sanitize(vector>& edabits, int n_bits) delete &MCB; } -template<> -inline -void SemiHonestRingPrep>::buffer_bits() -{ - assert(protocol != 0); - bits_from_random(bits, *protocol); -} - template -void bits_from_random(vector& bits, typename T::Protocol& protocol) +template +void SemiHonestRingPrep::buffer_bits(false_type, true_type) { - while (bits.size() < (size_t)OnlineOptions::singleton.batch_size) - { - Rep3Share share = protocol.get_random(); - for (int j = 0; j < gf2n::degree(); j++) + assert(this->protocol != 0); + if (not T::dishonest_majority and T::variable_players) + // Shamir + this->buffer_bits_without_check(); + else + while (this->bits.size() < (size_t) OnlineOptions::singleton.batch_size) { - bits.push_back(share & 1); - share >>= 1; + auto share = this->get_random(); + for (int j = 0; j < T::open_type::degree(); j++) + { + this->bits.push_back(share & 1); + share >>= 1; + } } - } } template template -void ReplicatedPrep::buffer_bits(false_type) +void SemiHonestRingPrep::buffer_bits(false_type, false_type) { - ReplicatedRingPrep::buffer_bits(); + this->buffer_bits_without_check(); } template -void ReplicatedPrep::buffer_bits() +void SemiHonestRingPrep::buffer_bits() { assert(this->protocol != 0); - buffer_bits<0>(T::clear::prime_field); + buffer_bits(T::clear::prime_field, T::clear::characteristic_two); } template diff --git a/Protocols/Semi2k.h b/Protocols/Semi.h similarity index 75% rename from Protocols/Semi2k.h rename to Protocols/Semi.h index 69cf63aad..e290ca0eb 100644 --- a/Protocols/Semi2k.h +++ b/Protocols/Semi.h @@ -3,8 +3,8 @@ * */ -#ifndef PROTOCOLS_SEMI2K_H_ -#define PROTOCOLS_SEMI2K_H_ +#ifndef PROTOCOLS_SEMI_H_ +#define PROTOCOLS_SEMI_H_ #include "SPDZ.h" #include "Processor/TruncPrTuple.h" @@ -13,12 +13,12 @@ * Dishonest-majority protocol for computation modulo a power of two */ template -class Semi2k : public SPDZ +class Semi : public SPDZ { SeededPRNG G; public: - Semi2k(Player& P) : + Semi(Player& P) : SPDZ(P) { } @@ -30,6 +30,19 @@ class Semi2k : public SPDZ void trunc_pr(const vector& regs, int size, SubProcessor& proc) + { + trunc_pr(regs, size, proc, T::clear::characteristic_two); + } + + template + void trunc_pr(const vector&, int, SubProcessor&, true_type) + { + throw not_implemented(); + } + + template + void trunc_pr(const vector& regs, int size, + SubProcessor& proc, false_type) { if (this->P.num_players() > 2) throw runtime_error("probabilistic truncation " @@ -60,4 +73,4 @@ class Semi2k : public SPDZ } }; -#endif /* PROTOCOLS_SEMI2K_H_ */ +#endif /* PROTOCOLS_SEMI_H_ */ diff --git a/Protocols/Semi2kShare.h b/Protocols/Semi2kShare.h index a9df48b4a..ee5e83202 100644 --- a/Protocols/Semi2kShare.h +++ b/Protocols/Semi2kShare.h @@ -7,7 +7,7 @@ #define PROTOCOLS_SEMI2KSHARE_H_ #include "SemiShare.h" -#include "Semi2k.h" +#include "Semi.h" #include "OT/Rectangle.h" #include "GC/SemiSecret.h" #include "GC/square64.h" @@ -27,7 +27,7 @@ class Semi2kShare : public SemiShare> typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef Semi2k Protocol; + typedef Semi Protocol; typedef SemiPrep2k LivePrep; typedef Semi2kShare prep_type; @@ -35,8 +35,6 @@ class Semi2kShare : public SemiShare> typedef OTTripleGenerator TripleGenerator; typedef Z2kSquare Rectangle; - typedef GC::SemiSecret bit_type; - static const bool has_split = true; Semi2kShare() diff --git a/Protocols/SemiShare.h b/Protocols/SemiShare.h index ed044c461..c2dd90858 100644 --- a/Protocols/SemiShare.h +++ b/Protocols/SemiShare.h @@ -7,6 +7,7 @@ #define PROTOCOLS_SEMISHARE_H_ #include "Protocols/Beaver.h" +#include "Protocols/Semi.h" #include "Processor/DummyProtocol.h" #include "ShareInterface.h" @@ -16,7 +17,7 @@ using namespace std; template class Input; template class SemiMC; template class DirectSemiMC; -template class SPDZ; +template class Semi; template class SemiPrep; template class SemiInput; template class PrivateOutput; @@ -59,7 +60,7 @@ class SemiShare : public T, public ShareInterface typedef DirectSemiMC Direct_MC; typedef SemiInput Input; typedef ::PrivateOutput PrivateOutput; - typedef SPDZ Protocol; + typedef Semi Protocol; typedef SemiPrep LivePrep; typedef LivePrep TriplePrep; @@ -69,12 +70,15 @@ class SemiShare : public T, public ShareInterface typedef T sacri_type; typedef typename T::Square Rectangle; +#ifndef NO_MIXED_CIRCUITS typedef GC::SemiSecret bit_type; +#endif const static bool needs_ot = true; const static bool dishonest_majority = true; const static bool variable_players = true; const static bool expensive = false; + static const bool has_trunc_pr = true; static string type_short() { return "D" + string(1, T::type_char()); } diff --git a/Protocols/Shamir.h b/Protocols/Shamir.h index 3d2bf469b..f722886eb 100644 --- a/Protocols/Shamir.h +++ b/Protocols/Shamir.h @@ -62,20 +62,8 @@ class Shamir : public ProtocolBase void reset(); void init_mul(); - void init_mul(SubProcessor* proc); - template - void init_mul(V*) - { - init_mul(); - } - template - void init_mul(const V&, const W&) - { - init_mul(); - } - - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); void start_exchange(); @@ -85,7 +73,7 @@ class Shamir : public ProtocolBase T finalize(int n_input_players); - void init_dotprod(SubProcessor* proc = 0); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); diff --git a/Protocols/Shamir.hpp b/Protocols/Shamir.hpp index d387f3b47..9fe10bdea 100644 --- a/Protocols/Shamir.hpp +++ b/Protocols/Shamir.hpp @@ -80,13 +80,6 @@ void Shamir::reset() resharing->reset(i); } -template -void Shamir::init_mul(SubProcessor* proc) -{ - (void) proc; - init_mul(); -} - template void Shamir::init_mul() { @@ -96,13 +89,12 @@ void Shamir::init_mul() } template -typename T::clear Shamir::prepare_mul(const T& x, const T& y, int n) +void Shamir::prepare_mul(const T& x, const T& y, int n) { (void) n; auto add_share = x * y * rec_factor; if (P.my_num() < n_mul_players) resharing->add_mine(add_share); - return {}; } template @@ -157,9 +149,9 @@ T Shamir::finalize(int n_relevant_players) } template -void Shamir::init_dotprod(SubProcessor* proc) +void Shamir::init_dotprod() { - init_mul(proc); + init_mul(); dotprod_share = 0; } diff --git a/Protocols/ShuffleSacrifice.hpp b/Protocols/ShuffleSacrifice.hpp index fe509321c..81e859319 100644 --- a/Protocols/ShuffleSacrifice.hpp +++ b/Protocols/ShuffleSacrifice.hpp @@ -10,7 +10,6 @@ #include "Tools/PointerVector.h" #include "GC/BitAdder.h" -#include "MalRepRingPrep.hpp" #include "LimitedPrep.hpp" inline @@ -25,6 +24,21 @@ ShuffleSacrifice::ShuffleSacrifice(int B, int C) : { } +template +void ShuffleSacrifice::shuffle(vector& check_triples, Player& P) +{ + int buffer_size = check_triples.size(); + + // shuffle + GlobalPRNG G(P); + for (int i = 0; i < buffer_size; i++) + { + int remaining = buffer_size - i; + int pos = G.get_uint(remaining); + swap(check_triples[i], check_triples[i + pos]); + } +} + template void TripleShuffleSacrifice::triple_combine(vector >& triples, vector >& to_combine, Player& P, diff --git a/Protocols/Spdz2kPrep.h b/Protocols/Spdz2kPrep.h index 33883c66f..03a91ff25 100644 --- a/Protocols/Spdz2kPrep.h +++ b/Protocols/Spdz2kPrep.h @@ -26,7 +26,6 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, MascotTriplePrep* bit_prep; SubProcessor* bit_proc; typename BitShare::MAC_Check* bit_MC; - typename BitShare::Protocol* bit_protocol; public: Spdz2kPrep(SubProcessor* proc, DataPositions& usage); @@ -41,8 +40,6 @@ class Spdz2kPrep : public virtual MaliciousRingPrep, #ifdef SPDZ2K_BIT void get_dabit(T& a, GC::TinySecret& b); #endif - - NamedCommStats comm_stats(); }; #endif /* PROTOCOLS_SPDZ2KPREP_H_ */ diff --git a/Protocols/Spdz2kPrep.hpp b/Protocols/Spdz2kPrep.hpp index f5c9cdce6..815277614 100644 --- a/Protocols/Spdz2kPrep.hpp +++ b/Protocols/Spdz2kPrep.hpp @@ -25,7 +25,6 @@ Spdz2kPrep::Spdz2kPrep(SubProcessor* proc, DataPositions& usage) : bit_MC = 0; bit_proc = 0; bit_prep = 0; - bit_protocol = 0; } template @@ -36,7 +35,6 @@ Spdz2kPrep::~Spdz2kPrep() delete bit_prep; delete bit_proc; delete bit_MC; - delete bit_protocol; } } @@ -50,10 +48,8 @@ void Spdz2kPrep::set_protocol(typename T::Protocol& protocol) // just dummies bit_pos = DataPositions(proc->P.num_players()); bit_prep = new MascotTriplePrep(bit_proc, bit_pos); - bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_prep->params.amplify = false; - bit_protocol = new typename BitShare::Protocol(proc->P); - bit_prep->set_protocol(*bit_protocol); + bit_proc = new SubProcessor(*bit_MC, *bit_prep, proc->P); bit_MC->set_prep(*bit_prep); this->proc->MC.set_prep(*this); } @@ -65,7 +61,7 @@ void MaliciousRingPrep::buffer_bits() RingPrep::buffer_bits_without_check(); assert(this->protocol != 0); auto& protocol = *this->protocol; - protocol.init_dotprod(this->proc); + protocol.init_dotprod(); auto one = T::constant(1, protocol.P.my_num(), this->proc->MC.get_alphai()); GlobalPRNG G(protocol.P); for (auto& bit : this->bits) @@ -238,12 +234,29 @@ void MaliciousRingPrep::buffer_edabits_from_personal(bool strict, int n_bits, } template -NamedCommStats Spdz2kPrep::comm_stats() +void MaliciousRingPrep::buffer_edabits(bool strict, int n_bits, + ThreadQueues* queues) { - auto res = OTPrep::comm_stats(); - if (bit_prep) - res += bit_prep->comm_stats(); - return res; + RunningTimer timer; +#ifndef NONPERSONAL_EDA + this->buffer_edabits_from_personal(strict, n_bits, queues); +#else + assert(this->proc != 0); + ShuffleSacrifice shuffle_sacrifice; + typedef typename T::bit_type::part_type bit_type; + vector> bits; + vector sums; + this->buffer_edabits_without_check(n_bits, sums, bits, + shuffle_sacrifice.minimum_n_inputs(), queues); + vector>& checked = this->edabits[{strict, n_bits}]; + shuffle_sacrifice.edabit_sacrifice(checked, sums, bits, + n_bits, *this->proc, strict, -1, queues); + if (strict) + this->sanitize(checked, n_bits, -1, queues); +#endif +#ifdef VERBOSE_EDA + cerr << "Total edaBit generation took " << timer.elapsed() << " seconds" << endl; +#endif } #endif diff --git a/Protocols/SpdzWise.h b/Protocols/SpdzWise.h index afbf2c850..c12b4f5fe 100644 --- a/Protocols/SpdzWise.h +++ b/Protocols/SpdzWise.h @@ -38,22 +38,23 @@ class SpdzWise : public ProtocolBase SpdzWise(Player& P); virtual ~SpdzWise(); - Player& branch(); + typename T::Protocol branch(); - void init(SubProcessor* proc); + void init(Preprocessing&, typename T::MAC_Check& MC); - void init_mul(SubProcessor* proc); - typename T::clear prepare_mul(const T& x, const T& y, int n = -1); + void init_mul(); + void prepare_mul(const T& x, const T& y, int n = -1); void exchange(); T finalize_mul(int n = -1); - void init_dotprod(SubProcessor*); + void init_dotprod(); void prepare_dotprod(const T& x, const T& y); void next_dotprod(); T finalize_dotprod(int length); void add_to_check(const T& x); void check(); + void maybe_check(); int get_n_relevant_players() { return internal.get_n_relevant_players(); } diff --git a/Protocols/SpdzWise.hpp b/Protocols/SpdzWise.hpp index 40f3cee71..2ea08ba46 100644 --- a/Protocols/SpdzWise.hpp +++ b/Protocols/SpdzWise.hpp @@ -19,34 +19,40 @@ SpdzWise::~SpdzWise() } template -Player& SpdzWise::branch() +typename T::Protocol SpdzWise::branch() { - return P; + typename T::Protocol res(P); + res.mac_key = mac_key; + return res; +} + +template +void SpdzWise::init(Preprocessing&, typename T::MAC_Check& MC) +{ + mac_key = MC.get_alphai(); } template -void SpdzWise::init(SubProcessor* proc) +void SpdzWise::maybe_check() { - assert(proc != 0); - mac_key = proc->MC.get_alphai(); + assert(not mac_key.is_zero()); if ((int) results.size() >= OnlineOptions::singleton.batch_size) check(); } template -void SpdzWise::init_mul(SubProcessor* proc) +void SpdzWise::init_mul() { - init(proc); + maybe_check(); internal.init_mul(); internal2.init_mul(); } template -typename T::clear SpdzWise::prepare_mul(const T& x, const T& y, int) +void SpdzWise::prepare_mul(const T& x, const T& y, int) { internal.prepare_mul(x.get_share(), y.get_share()); internal.prepare_mul(x.get_mac(), y.get_share()); - return {}; } template @@ -67,9 +73,9 @@ void SpdzWise::exchange() } template -void SpdzWise::init_dotprod(SubProcessor* proc) +void SpdzWise::init_dotprod() { - init(proc); + maybe_check(); internal.init_dotprod(); internal2.init_dotprod(); } diff --git a/Protocols/SpdzWiseInput.hpp b/Protocols/SpdzWiseInput.hpp index ef7f549bf..e0d508e51 100644 --- a/Protocols/SpdzWiseInput.hpp +++ b/Protocols/SpdzWiseInput.hpp @@ -12,6 +12,7 @@ SpdzWiseInput::SpdzWiseInput(SubProcessor* proc, Player& P) : { assert(proc != 0); mac_key = proc->MC.get_alphai(); + checker.init(proc->DataF, proc->MC); } template @@ -76,7 +77,7 @@ void SpdzWiseInput::exchange() shares[i][j].set_mac(honest_mult.finalize_mul()); checker.results.push_back(shares[i][j]); } - checker.init(proc); + checker.maybe_check(); } template diff --git a/Protocols/SpdzWisePrep.hpp b/Protocols/SpdzWisePrep.hpp index f88e97d64..9cb86017a 100644 --- a/Protocols/SpdzWisePrep.hpp +++ b/Protocols/SpdzWisePrep.hpp @@ -9,19 +9,21 @@ #include "MaliciousShamirShare.h" #include "SquarePrep.h" #include "Math/gfp.h" +#include "ProtocolSet.h" #include "ReplicatedPrep.hpp" #include "Spdz2kPrep.hpp" #include "ShamirMC.hpp" #include "MaliciousRepPO.hpp" #include "MaliciousShamirPO.hpp" +#include "GC/RepPrep.hpp" template void SpdzWisePrep::buffer_triples() { assert(this->protocol != 0); assert(this->proc != 0); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); generate_triples_initialized(this->triples, OnlineOptions::singleton.batch_size, this->protocol); } @@ -38,8 +40,11 @@ void SpdzWisePrep>>::buffer_bits() { typedef MaliciousRep3Share part_type; vector bits; - typename part_type::Honest::Protocol protocol(this->protocol->P); - bits_from_random(bits, protocol); + ProtocolSet set(this->proc->P, {}); + auto& protocol = set.protocol; + auto& prep = set.preprocessing; + for (int i = 0; i < buffer_size; i++) + bits.push_back(prep.get_bit()); protocol.init_mul(); for (auto& bit : bits) protocol.prepare_mul(bit, this->proc->MC.get_alphai()); @@ -99,7 +104,7 @@ void SpdzWisePrep::buffer_inputs(int player) vector rs(OnlineOptions::singleton.batch_size); auto& P = this->proc->P; this->inputs.resize(P.num_players()); - this->protocol->init_mul(this->proc); + this->protocol->init_mul(); for (auto& r : rs) { r = this->protocol->get_random(); diff --git a/Protocols/SpdzWiseRing.hpp b/Protocols/SpdzWiseRing.hpp index 30904c386..36e638d14 100644 --- a/Protocols/SpdzWiseRing.hpp +++ b/Protocols/SpdzWiseRing.hpp @@ -36,7 +36,7 @@ void SpdzWiseRing::zero_check(check_type t) while(bits.size() > 1) { auto& protocol = zero_proc.protocol; - protocol.init_mul(&zero_proc); + protocol.init_mul(); for (int i = bits.size() - 2; i >= 0; i -= 2) protocol.prepare_mul(bits[i], bits[i + 1]); protocol.exchange(); diff --git a/Protocols/SquarePrep.h b/Protocols/SquarePrep.h index fcdc2c239..be0913b37 100644 --- a/Protocols/SquarePrep.h +++ b/Protocols/SquarePrep.h @@ -10,7 +10,7 @@ template void generate_squares(vector>& squares, int n_squares, - U* protocol, SubProcessor* proc); + U* protocol); template class SquarePrep : public BufferPrep @@ -22,8 +22,8 @@ class SquarePrep : public BufferPrep void buffer_squares() { - generate_squares(this->squares, this->buffer_size, &this->proc->protocol, - this->proc); + generate_squares(this->squares, this->buffer_size, + &this->proc->protocol); } public: diff --git a/README.md b/README.md index daa658a5f..bd1075121 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ The following table lists all protocols that are fully supported. | Semi-honest, honest majority | [Shamir / ATLAS / Rep3](#honest-majority) | [Rep3](#honest-majority) | [Rep3 / CCD](#honest-majority) | [BMR](#bmr) | See [this paper](https://eprint.iacr.org/2020/300) for an explanation -of the various security models and high-level introduction to +of the various security models and a high-level introduction to multi-party computation. ##### Finding the most efficient protocol @@ -131,8 +131,8 @@ there are a few things to consider: dot products. - Fixed-point multiplication: Three- and four-party replicated secret - sharing modulo a power of two allow a special probabilistic - truncation protocol (see [Dalskov et + sharing as well semi-honest full-threshold protocols allow a special + probabilistic truncation protocol (see [Dalskov et al.](https://eprint.iacr.org/2019/131) and [Dalskov et al.](https://eprint.iacr.org/2020/1330)). You can activate it by adding `program.use_trunc_pr = True` at the beginning of your diff --git a/Scripts/decompile.py b/Scripts/decompile.py new file mode 100755 index 000000000..0142ba69a --- /dev/null +++ b/Scripts/decompile.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +import sys, os + +sys.path.append('.') + +from Compiler.instructions_base import Instruction +from Compiler.program import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +for tapename in Program.read_tapes(sys.argv[1]): + with open('Programs/Bytecode/%s.asm' % tapename, 'w') as out: + for i, inst in enumerate(Tape.read_instructions(tapename)): + print(inst, '#', i, file=out) diff --git a/Scripts/memory-usage.py b/Scripts/memory-usage.py new file mode 100755 index 000000000..15959ee68 --- /dev/null +++ b/Scripts/memory-usage.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +import sys, os +import collections + +sys.path.append('.') + +from Compiler.program import * +from Compiler.instructions_base import * + +if len(sys.argv) <= 1: + print('Usage: %s ' % sys.argv[0]) + +res = collections.defaultdict(lambda: 0) +m = 0 + +for tapename in Program.read_tapes(sys.argv[1]): + for inst in Tape.read_instructions(tapename): + t = inst.type + if issubclass(t, DirectMemoryInstruction): + res[t.arg_format[0]] = max(inst.args[1].i + inst.size, + res[t.arg_format[0]]) + for arg in inst.args: + if isinstance(arg, RegisterArgFormat): + m = max(m, arg.i + inst.size) + +print (res) +print (m) + diff --git a/Scripts/run-common.sh b/Scripts/run-common.sh index 3c0891e61..7e5e6d449 100644 --- a/Scripts/run-common.sh +++ b/Scripts/run-common.sh @@ -34,39 +34,26 @@ run_player() { if ! test -e $SPDZROOT/logs; then mkdir $SPDZROOT/logs fi - if [[ $bin = Player-Online.x || $bin =~ 'party.x' ]]; then - params="$prog $* -pn $port -h localhost" - if [[ ! ($bin =~ 'rep' || $bin =~ 'brain' || $bin =~ 'yao') ]]; then - params="$params -N $players" - fi - else - params="$port localhost $prog $*" + params="$prog $* -pn $port -h localhost" + if $SPDZROOT/$bin 2>&1 | grep -q '^-N,'; then + params="$params -N $players" fi - rem=$(($players - 2)) if test "$prog"; then log_prefix=$prog- fi - for i in $(seq 0 $rem); do + set -o pipefail + for i in $(seq 0 $[players-1]); do >&2 echo Running $prefix $SPDZROOT/$bin $i $params log=$SPDZROOT/logs/$log_prefix$i $prefix $SPDZROOT/$bin $i $params 2>&1 | { if test $i = 0; then tee $log; else cat > $log; fi; } & + codes[$i]=$! + done + for i in $(seq 0 $[players-1]); do + wait ${codes[$i]} || return 1 done - last_player=$(($players - 1)) - i=$last_player - >&2 echo Running $prefix $SPDZROOT/$bin $last_player $params - $prefix $SPDZROOT/$bin $last_player $params > $SPDZROOT/logs/$log_prefix$last_player 2>&1 || return 1 - wait } -sleep 0.5 - -#mkdir /dev/shm/Player-Data - players=${PLAYERS:-2} SPDZROOT=${SPDZROOT:-.} - -#. Scripts/setup.sh - -mkdir logs 2> /dev/null diff --git a/Scripts/test_streaming.sh b/Scripts/test_streaming.sh index 0ff2fb336..62a493084 100755 --- a/Scripts/test_streaming.sh +++ b/Scripts/test_streaming.sh @@ -15,3 +15,7 @@ done ./stream-fake-mascot-triples.x & Scripts/mascot.sh test_thread_mul -f || exit 1 + +./stream-fake-mascot-triples.x & + +Scripts/mascot.sh test_thread_mul -f || exit 1 diff --git a/Scripts/tldr.sh b/Scripts/tldr.sh index 5dd4f45db..ed6c01441 100755 --- a/Scripts/tldr.sh +++ b/Scripts/tldr.sh @@ -27,7 +27,8 @@ if test "$flags"; then cpu=amd64 fi - cp -av bin/`uname`-$cpu/* . + cp -av bin/`uname`-$cpu/* . || { echo This only works with a release downloaded from https://github.com/data61/MP-SPDZ/releases 1>&2; exit 1; } fi mkdir Player-Data 2> /dev/null +exit 0 diff --git a/Tools/BitVector.cpp b/Tools/BitVector.cpp index 4ef3406f0..567e57885 100644 --- a/Tools/BitVector.cpp +++ b/Tools/BitVector.cpp @@ -9,6 +9,15 @@ #include #include +void BitVector::assign(const BitVector& K) +{ + if (nbits != K.nbits) + { + resize(K.nbits); + } + memcpy(bytes, K.bytes, nbytes); +} + void BitVector::resize_zero(size_t new_nbits) { size_t old_nbytes = nbytes; diff --git a/Tools/BitVector.h b/Tools/BitVector.h index 055610519..54d9ed109 100644 --- a/Tools/BitVector.h +++ b/Tools/BitVector.h @@ -33,14 +33,7 @@ class BitVector public: - void assign(const BitVector& K) - { - if (nbits != K.nbits) - { - resize(K.nbits); - } - memcpy(bytes, K.bytes, nbytes); - } + void assign(const BitVector& K); void assign_bytes(char* new_bytes, int len) { resize(len*8); diff --git a/Tools/Buffer.cpp b/Tools/Buffer.cpp index c669081f8..9dd15804c 100644 --- a/Tools/Buffer.cpp +++ b/Tools/Buffer.cpp @@ -26,7 +26,7 @@ void BufferBase::setup(ifstream* f, int length, const string& filename, bool BufferBase::is_pipe() { struct stat buf; - if (stat(filename.c_str(), &buf)) + if (stat(filename.c_str(), &buf) == 0) return S_ISFIFO(buf.st_mode); else return false; @@ -113,6 +113,17 @@ void BufferBase::prune() rename(tmp_name.c_str(), filename.c_str()); file->open(filename.c_str(), ios::in | ios::binary); } +#ifdef VERBOSE + else + { + cerr << "Not pruning " << filename << " because it's "; + if (file) + cerr << "closed"; + else + cerr << "unused"; + cerr << endl; + } +#endif } void BufferBase::purge() diff --git a/Tools/Bundle.h b/Tools/Bundle.h index ed4b982e3..7859e3e4c 100644 --- a/Tools/Bundle.h +++ b/Tools/Bundle.h @@ -31,7 +31,7 @@ class Bundle : public vector { } - void compare(Player& P) + void compare(PlayerBase& P) { P.unchecked_broadcast(*this); for (auto& os : *this) diff --git a/Tools/TimerWithComm.cpp b/Tools/TimerWithComm.cpp new file mode 100644 index 000000000..2a5e8e12a --- /dev/null +++ b/Tools/TimerWithComm.cpp @@ -0,0 +1,23 @@ +/* + * TimerWithComm.cpp + * + */ + +#include "TimerWithComm.h" + +void TimerWithComm::start(const NamedCommStats& stats) +{ + Timer::start(); + last_stats = stats; +} + +void TimerWithComm::stop(const NamedCommStats& stats) +{ + Timer::stop(); + total_stats += stats - last_stats; +} + +double TimerWithComm::mb_sent() +{ + return total_stats.sent * 1e-6; +} diff --git a/Tools/TimerWithComm.h b/Tools/TimerWithComm.h new file mode 100644 index 000000000..2f3976a20 --- /dev/null +++ b/Tools/TimerWithComm.h @@ -0,0 +1,23 @@ +/* + * TimerWithComm.h + * + */ + +#ifndef TOOLS_TIMERWITHCOMM_H_ +#define TOOLS_TIMERWITHCOMM_H_ + +#include "time-func.h" +#include "Networking/Player.h" + +class TimerWithComm : public Timer +{ + NamedCommStats total_stats, last_stats; + +public: + void start(const NamedCommStats& stats = {}); + void stop(const NamedCommStats& stats = {}); + + double mb_sent(); +}; + +#endif /* TOOLS_TIMERWITHCOMM_H_ */ diff --git a/Tools/benchmarking.cpp b/Tools/benchmarking.cpp new file mode 100644 index 000000000..e956f15ec --- /dev/null +++ b/Tools/benchmarking.cpp @@ -0,0 +1,15 @@ +/* + * benchmarking.cpp + * + */ + +#include "benchmarking.h" + +void insecure_fake() +{ +#if defined(INSECURE) or defined(INSECURE_FAKE) + cerr << "WARNING: insecure preprocessing" << endl; +#else + insecure("preprocessing"); +#endif +} diff --git a/Tools/benchmarking.h b/Tools/benchmarking.h index 0ca65b761..13fa9c365 100644 --- a/Tools/benchmarking.h +++ b/Tools/benchmarking.h @@ -8,6 +8,7 @@ #include #include +#include using namespace std; // call before insecure benchmarking functionality @@ -26,4 +27,6 @@ inline void insecure(string message, bool warning = true) #endif } +void insecure_fake(); + #endif /* TOOLS_BENCHMARKING_H_ */ diff --git a/Tools/octetStream.h b/Tools/octetStream.h index df920a302..cd90b0e94 100644 --- a/Tools/octetStream.h +++ b/Tools/octetStream.h @@ -35,7 +35,9 @@ class bigint; class FlexBuffer; /** - * Buffer for networking communication with a pointer for sequential reading + * Buffer for network communication with a pointer for sequential reading. + * When sent over the network or stored in a file, the length is prefixed + * as eight bytes in little-endian order. */ class octetStream { diff --git a/Tools/random.cpp b/Tools/random.cpp index 7a0cd1dab..7cf1924f3 100644 --- a/Tools/random.cpp +++ b/Tools/random.cpp @@ -13,7 +13,7 @@ using namespace std; PRNG::PRNG() : - cnt(0), n_cached_bits(0), cached_bits(0) + cnt(0), n_cached_bits(0), cached_bits(0), initialized(false) { #if defined(__AES__) || !defined(__x86_64__) #ifdef USE_AES @@ -83,6 +83,7 @@ void PRNG::SecureSeed(Player& player) void PRNG::InitSeed() { + initialized = true; #ifdef USE_AES if (useC) { aes_schedule(KeyScheduleC,seed); } @@ -122,6 +123,7 @@ void PRNG::print_state() const void PRNG::hash() { + assert(initialized); #ifndef USE_AES unsigned char tmp[RAND_SIZE + SEED_SIZE]; randombytes_buf_deterministic(tmp, sizeof tmp, seed); diff --git a/Tools/random.h b/Tools/random.h index d22be6e88..5e65d8350 100644 --- a/Tools/random.h +++ b/Tools/random.h @@ -61,6 +61,8 @@ class PRNG int n_cached_bits; word cached_bits; + bool initialized; + void hash(); // Hashes state to random and sets cnt=0 void next(); diff --git a/Utils/Fake-Offline.cpp b/Utils/Fake-Offline.cpp index e8026a952..f1158cfa6 100644 --- a/Utils/Fake-Offline.cpp +++ b/Utils/Fake-Offline.cpp @@ -387,7 +387,7 @@ int generate(ez::ezOptionParser& opt); int main(int argc, const char** argv) { - insecure("preprocessing"); + insecure_fake(); bigint::init_thread(); FakeParams params; diff --git a/Utils/binary-example.cpp b/Utils/binary-example.cpp new file mode 100644 index 000000000..45e5f3371 --- /dev/null +++ b/Utils/binary-example.cpp @@ -0,0 +1,140 @@ +/* + * binary-example.cpp + * + */ + +#include "GC/TinierSecret.h" +#include "GC/PostSacriSecret.h" +#include "GC/CcdSecret.h" +#include "GC/MaliciousCcdSecret.h" +#include "GC/AtlasSecret.h" +#include "GC/TinyMC.h" +#include "GC/VectorInput.h" +#include "GC/PostSacriBin.h" +#include "Protocols/ProtocolSet.h" + +#include "GC/ShareSecret.hpp" +#include "GC/CcdPrep.hpp" +#include "GC/TinierSharePrep.hpp" +#include "GC/RepPrep.hpp" +#include "GC/Secret.hpp" +#include "GC/TinyPrep.hpp" +#include "GC/ThreadMaster.hpp" +#include "Protocols/Atlas.hpp" +#include "Protocols/MaliciousRepPrep.hpp" +#include "Protocols/Share.hpp" +#include "Protocols/MaliciousRepMC.hpp" +#include "Protocols/Shamir.hpp" +#include "Protocols/fake-stuff.hpp" +#include "Machines/ShamirMachine.hpp" +#include "Machines/Rep4.hpp" + +template +void run(int argc, char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol [bit length [threshold]]]" + << endl; + exit(1); + } + + string protocol = "Tinier"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "Tinier") + run>(argc, argv); + else if (protocol == "Rep3") + run(argc, argv); + else if (protocol == "Rep4") + run(argc, argv); + else if (protocol == "PS") + run(argc, argv); + else if (protocol == "Semi") + run(argc, argv); + else if (protocol == "CCD" or protocol == "MalCCD" or protocol == "Atlas") + { + int nparties = (atoi(argv[2])); + int threshold = (nparties - 1) / 2; + if (argc > 5) + threshold = atoi(argv[5]); + assert(2 * threshold < nparties); + ShamirOptions::s().threshold = threshold; + ShamirOptions::s().nparties = nparties; + + if (protocol == "CCD") + run>>(argc, argv); + else if (protocol == "MalCCD") + run>(argc, argv); + else + run(argc, argv); + } + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(int argc, char** argv) +{ + // run 16-bit computation by default + int n_bits = 16; + if (argc > 4) + n_bits = atoi(argv[4]); + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + BinaryProtocolSetup setup(P); + + // set of protocols (input, multiplication, output) + BinaryProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; + + int n = 10; + vector a(n), b(n); + + input.reset_all(P); + for (int i = 0; i < n; i++) + input.add_from_all(i + P.my_num(), n_bits); + input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = input.finalize(0, n_bits); + b[i] = input.finalize(1, n_bits); + } + + protocol.init_mul(); + for (int i = 0; i < n; i++) + protocol.prepare_mul(a[i], b[i], n_bits); + protocol.exchange(); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + auto c = protocol.finalize_mul(n_bits); + output.prepare_open(c); + } + output.exchange(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + protocol.check(); + output.Check(P); +} diff --git a/Utils/mixed-example.cpp b/Utils/mixed-example.cpp new file mode 100644 index 000000000..532d705e4 --- /dev/null +++ b/Utils/mixed-example.cpp @@ -0,0 +1,137 @@ +/* + * mixed-example.cpp + * + */ + +#include "Protocols/ProtocolSet.h" + +#include "Machines/SPDZ.hpp" +#include "Machines/Semi2k.hpp" +#include "Machines/Rep.hpp" +#include "Machines/Rep4.hpp" +#include "Machines/Atlas.hpp" + +template +void run(char** argv); + +int main(int argc, char** argv) +{ + // need player number and number of players + if (argc < 3) + { + cerr << "Usage: " << argv[0] + << " [protocol]" + << endl; + exit(1); + } + + string protocol = "SPDZ2k"; + if (argc > 3) + protocol = argv[3]; + + if (protocol == "SPDZ2k") + run>(argv); + else if (protocol == "Semi2k") + run>(argv); + else if (protocol == "Rep3") + run>(argv); + else if (protocol == "Rep4") + run>(argv); + else if (protocol == "Atlas") + run>>(argv); + else + { + cerr << "Unknown protocol: " << protocol << endl; + exit(1); + } +} + +template +void run(char** argv) +{ + // reduce batch size + OnlineOptions::singleton.bucket_size = 5; + OnlineOptions::singleton.batch_size = 100; + + // set up networking on localhost + int my_number = atoi(argv[1]); + int n_parties = atoi(argv[2]); + int port_base = 9999; + Names N(my_number, n_parties, "localhost", port_base); + CryptoPlayer P(N); + + // protocol setup (domain, MAC key if needed etc) + MixedProtocolSetup setup(P); + + // set of protocols (bit_input, multiplication, output) + MixedProtocolSet set(P, setup); + auto& output = set.output; + auto& bit_input = set.binary.input; + auto& bit_protocol = set.binary.protocol; + auto& bit_output = set.binary.output; + auto& prep = set.preprocessing; + + int n = 10; + int n_bits = 16; + vector a(n), b(n); + + // inputs in binary domain + bit_input.reset_all(P); + for (int i = 0; i < n; i++) + bit_input.add_from_all(i + P.my_num(), n_bits); + bit_input.exchange(); + for (int i = 0; i < n; i++) + { + a[i] = bit_input.finalize(0, n_bits); + b[i] = bit_input.finalize(1, n_bits); + } + + // compute AND in binary domain + bit_protocol.init_mul(); + for (int i = 0; i < n; i++) + bit_protocol.prepare_mul(a[i], b[i], n_bits); + bit_protocol.exchange(); + bit_protocol.check(); + bit_output.init_open(P, n * n_bits); + PointerVector> dabits; + for (int i = 0; i < n; i++) + { + auto c = bit_protocol.finalize_mul(n_bits); + + // mask result with dabits and open + for (int j = 0; j < n_bits; j++) + { + dabits.push_back({}); + auto& dabit = dabits.back(); + prep.get_dabit(dabit.first, dabit.second); + bit_output.prepare_open( + typename T::bit_type::part_type( + dabit.second.get_bit(0) + c.get_bit(j))); + } + } + bit_output.exchange(P); + output.init_open(P, n); + for (int i = 0; i < n; i++) + { + T res; + // unmask via XOR and recombine + for (int j = 0; j < n_bits; j++) + { + typename T::clear masked = bit_output.finalize_open().get_bit(0); + auto mask = dabits.next().first; + res += (mask - mask * masked * 2 + + T::constant(masked, P.my_num(), setup.get_mac_key())) + << j; + } + output.prepare_open(res); + } + output.exchange(P); + bit_output.Check(P); + + cout << "result: "; + for (int i = 0; i < n; i++) + cout << output.finalize_open() << " "; + cout << endl; + + output.Check(P); +} diff --git a/Utils/paper-example.cpp b/Utils/paper-example.cpp index 87247fee8..9cae6953f 100644 --- a/Utils/paper-example.cpp +++ b/Utils/paper-example.cpp @@ -11,8 +11,10 @@ #include "Machines/SPDZ.hpp" #include "Machines/MalRep.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/Semi2k.hpp" #include "Protocols/CowGearShare.h" #include "Protocols/CowGearPrep.hpp" +#include "Protocols/ProtocolSet.h" template void run(char** argv, int prime_length); @@ -42,6 +44,8 @@ int main(int argc, char** argv) run>>(argv, prime_length); else if (protocol == "SPDZ2k") run>(argv, 0); + else if (protocol == "Semi2k") + run>(argv, 0); else if (protocol == "Shamir" or protocol == "MalShamir") { int nparties = (atoi(argv[2])); @@ -74,35 +78,14 @@ void run(char** argv, int prime_length) Names N(my_number, n_parties, "localhost", port_base); CryptoPlayer P(N); - // initialize fields - T::clear::init_default(prime_length); - T::clear::next::init_default(prime_length, false); + // protocol setup (domain, MAC key if needed etc) + ProtocolSetup setup(P, prime_length); - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - - // output protocol - typename T::MAC_Check output(mac_key); - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - - // input protocol - typename T::Input input(processor, output); - - // multiplication protocol - typename T::Protocol protocol(P); + // set of protocols (input, multiplication, output) + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; int n = 1000; vector a(n), b(n); @@ -119,19 +102,23 @@ void run(char** argv, int prime_length) b[i] = input.finalize(1); } - protocol.init_dotprod(&processor); + protocol.init_dotprod(); for (int i = 0; i < n; i++) protocol.prepare_dotprod(a[i], b[i]); protocol.next_dotprod(); protocol.exchange(); c = protocol.finalize_dotprod(n); + + // protocol check before revealing results + protocol.check(); + output.init_open(P); output.prepare_open(c); output.exchange(P); result = output.finalize_open(); cout << "result: " << result << endl; - output.Check(P); - T::LivePrep::teardown(); + // result check after opening + output.Check(P); } diff --git a/Utils/stream-fake-mascot-triples.cpp b/Utils/stream-fake-mascot-triples.cpp index 5aa85a054..517056e72 100644 --- a/Utils/stream-fake-mascot-triples.cpp +++ b/Utils/stream-fake-mascot-triples.cpp @@ -27,13 +27,18 @@ void* run(void* arg) int count = 0; while (true) { - gfpvar triple[3]; - for (int i = 0; i < 2; i++) - triple[i].randomize(G); - triple[2] = triple[0] * triple[1]; - for (int i = 0; i < 3; i++) - files.output_shares(triple[i]); - count++; + for (int i = 0; i < 100000; i++) + { + gfpvar triple[3]; + for (int i = 0; i < 2; i++) + triple[i].randomize(G); + triple[2] = triple[0] * triple[1]; + for (int i = 0; i < 3; i++) + files.output_shares(triple[i]); + count++; + } + // take a break to make them wait + sleep(1); } cerr << "failed after " << count << endl; return 0; @@ -41,7 +46,7 @@ void* run(void* arg) int main() { - insecure("preprocessing"); + insecure_fake(); typedef Share T; int nplayers = 2; int lgp = 128; diff --git a/Yao/YaoEvaluator.h b/Yao/YaoEvaluator.h index 074fb3400..749ba2878 100644 --- a/Yao/YaoEvaluator.h +++ b/Yao/YaoEvaluator.h @@ -58,9 +58,6 @@ class YaoEvaluator: public GC::Thread>, int get_n_worker_threads() { return max(1u, thread::hardware_concurrency() / master.machine.nthreads); } - - NamedCommStats comm_stats() - { return super::comm_stats() + player.comm_stats; } }; inline void YaoEvaluator::load_gate(YaoGate& gate) diff --git a/Yao/YaoGarbler.cpp b/Yao/YaoGarbler.cpp index e6ae6cda1..647369a15 100644 --- a/Yao/YaoGarbler.cpp +++ b/Yao/YaoGarbler.cpp @@ -120,8 +120,3 @@ void YaoGarbler::process_receiver_inputs() receiver_input_keys.pop_front(); } } - -NamedCommStats YaoGarbler::comm_stats() -{ - return super::comm_stats() + player.comm_stats; -} diff --git a/Yao/YaoGarbler.h b/Yao/YaoGarbler.h index 038fe432f..0608336c8 100644 --- a/Yao/YaoGarbler.h +++ b/Yao/YaoGarbler.h @@ -71,8 +71,6 @@ class YaoGarbler: public GC::Thread>, int get_threshold() { return master.threshold; } long get_gate_id() { return gate_id(thread_num); } - - NamedCommStats comm_stats(); }; inline YaoGarbler& YaoGarbler::s() diff --git a/Yao/YaoWire.h b/Yao/YaoWire.h index ddaf3b9c2..92f3ec614 100644 --- a/Yao/YaoWire.h +++ b/Yao/YaoWire.h @@ -23,6 +23,10 @@ class YaoWire : public Phase static void xors(GC::Processor& processor, const vector& args, size_t start, size_t end); + template + static void andm(GC::Processor& processor, + const BaseInstruction& instruction); + void XOR(const YaoWire& left, const YaoWire& right) { key_ = left.key_ ^ right.key_; diff --git a/Yao/YaoWire.hpp b/Yao/YaoWire.hpp index bb3b14068..aa04fe357 100644 --- a/Yao/YaoWire.hpp +++ b/Yao/YaoWire.hpp @@ -46,4 +46,24 @@ void YaoWire::xors(GC::Processor& processor, const vector& args, processor.xors(args, start, end); } +template +void YaoWire::andm(GC::Processor& processor, + const BaseInstruction& instruction) +{ + + int unit = GC::Clear::N_BITS; + for (int i = 0; i < DIV_CEIL(instruction.get_n(), unit); i++) + { + auto &dest = processor.S[instruction.get_r(0) + i]; + int n = min(unsigned(unit), instruction.get_n() - i * unit); + dest.resize_regs(n); + for (int j = 0; j < n; j++) + if (processor.C[instruction.get_r(2) + i].get_bit(j)) + dest.get_reg(j) = + processor.S[instruction.get_r(1) + i].get_reg(j); + else + dest.get_reg(j).public_input(0); + } +} + #endif /* YAO_YAOWIRE_HPP_ */ diff --git a/doc/Doxyfile b/doc/Doxyfile index 771f8cf13..3dd299405 100644 --- a/doc/Doxyfile +++ b/doc/Doxyfile @@ -829,7 +829,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h +INPUT = ../Networking ../Tools/octetStream.h ../Processor/Data_Files.h ../Protocols/Replicated.h ../Protocols/ReplicatedPrep.h ../Protocols/MAC_Check_Base.h ../Processor/Input.h ../ExternalIO/Client.h ../Protocols/ProtocolSet.h ../Protocols/ProtocolSetup.h # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/doc/conf.py b/doc/conf.py index 57f730add..86bb12d46 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -21,7 +21,7 @@ # -- Project information ----------------------------------------------------- project = u'MP-SPDZ' -copyright = u'2021, CSIRO\'s Data61' +copyright = u'2022, CSIRO\'s Data61' author = u'Marcel Keller' # The short X.Y version @@ -185,7 +185,8 @@ breathe_projects = {'mp-spdz': 'xml'} breathe_default_project = 'mp-spdz' import subprocess -subprocess.call('doxygen', shell=True) +if (subprocess.call('doxygen', shell=True)): + raise Exception('doxygen failed') def setup(app): app.add_css_file('custom.css') diff --git a/doc/index.rst b/doc/index.rst index d7a13e941..d2a2c4dcd 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,10 +1,16 @@ Welcome to MP-SPDZ's documentation! =================================== -This documentation provides a reference to the most important -high-level functionality provided by the MP-SPDZ compiler. For a -tutorial and documentation on how to run programs, the -implemented protocols etc. see https://github.com/data61/MP-SPDZ. +If you're new to MP-SPDZ, consider the following: + +1. `Quickstart tutorial `_ +2. `Implemented protocols `_ +3. :ref:`troubleshooting` + +Unlike the `Readme +`_, this +documentation provides a reference for more detailed aspects of the +software. Compilation process ------------------- diff --git a/doc/io.rst b/doc/io.rst index 5184ab338..a4d00cee8 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -83,6 +83,8 @@ covering both client code and server-side high-level code. :py:func:`Compiler.types.MultiArray.reveal_to_clients`. The same functions are available for :py:class:`~Compiler.types.sfix` and :py:class:`~Compiler.types.Array`, respectively. +See also :ref:`client ref` below. + Secret Shares ~~~~~~~~~~~~~ @@ -114,3 +116,11 @@ etc. Note also that all types based on :py:class:`~Compiler.types.sfix`) share the same memory, and that the address is only a base address. This means that vectors will be written to the memory starting at the given address. + +.. _client ref: + +Reference +~~~~~~~~~ + +.. doxygenclass:: Client + :members: diff --git a/doc/low-level.rst b/doc/low-level.rst index 0aaf3708e..c70bf5b65 100644 --- a/doc/low-level.rst +++ b/doc/low-level.rst @@ -83,109 +83,24 @@ number of parties. .. code-block:: cpp - // initialize fields - T::clear::init_default(prime_length); + ProtocolSetup setup(P, prime_length); We have to use a specific prime for computation modulo a prime. This deterministically generates one of the desired length if necessary. For computation modulo a power of two, this does not do -anything. +anything. Some protocols use an information-theoretic tag that is +constant throughout the protocol. This code reads it from storage if +available or generates a fresh one otherwise. .. code-block:: cpp - T::clear::next::init_default(prime_length, false); + ProtocolSet set(P, setup); + auto& input = set.input; + auto& protocol = set.protocol; + auto& output = set.output; -For computation modulo a prime, it is more efficient to use Montgomery -representation, which is not compatible with the MASCOT offline phase -however. This line initializes another field instance for MASCOT -without using Montgomery representation. - -.. code-block:: cpp - - // must initialize MAC key for security of some protocols - typename T::mac_key_type mac_key; - T::read_or_generate_mac_key("", P, mac_key); - -Some protocols use an information-theoretic tag that is constant -throughout the protocol. This codes reads it from storage if available -or generates a fresh one otherwise. - -.. code-block:: cpp - - // global OT setup - BaseMachine machine; - if (T::needs_ot) - machine.ot_setups.push_back({P}); - -Many protocols for a dishonest majority use oblivious transfer. This -block runs a few instances to seed the oblivious transfer -extension. The resulting setup only works for one thread. For several -threads, you need to add sufficiently many instances to -:member:`ot_setups` and set :member:`BaseMachine::thread_num` -(thread-local) to a different consecutive number in every thread. - -.. code-block:: cpp - - // keeps tracks of preprocessing usage (triples etc) - DataPositions usage; - usage.set_num_players(P.num_players()); - -To help keeping track of the required preprocessing, it is necessary -to initialize preprocessing instances with a :class:`DataPositions` -variable that will store the usage. - -.. code-block:: cpp - - // initialize binary computation - T::bit_type::mac_key_type::init_field(); - typename T::bit_type::mac_key_type binary_mac_key; - T::bit_type::part_type::read_or_generate_mac_key("", P, binary_mac_key); - GC::ShareThread thread(N, - OnlineOptions::singleton, P, binary_mac_key, usage); - -While this example only uses arithmetic computation, you need to -initialize binary computation as well unless you use the compile-time -option ``NO_MIXED_CIRCUITS``. - -.. code-block:: cpp - - // output protocol - typename T::MAC_Check output(mac_key); - -Some output protocols use the MAC key to check the correctness. - -.. code-block:: cpp - - // various preprocessing - typename T::LivePrep preprocessing(0, usage); - SubProcessor processor(output, preprocessing, P); - -In this example we use live preprocessing, but it is also possible to -read preprocessing data from disk by using :class:`Sub_Data_Files` -instead. You can use a live preprocessing instances to generate -preprocessing data independently, but many protocols require that a -:class:`SubProcessor` instance has been created as well. The latter -essentially glues an instance of the output and the preprocessing -protocol together, which is necessary for Beaver-based multiplication -protocols. - -.. code-block:: cpp - - // input protocol - typename T::Input input(processor, output); - -Some input protocols depend on preprocessing and an output protocol, -which is reflect in the standard constructor. Other constructors are -available depending on the protocol. - -.. code-block:: cpp - - // multiplication protocol - typename T::Protocol protocol(P); - -This instantiates a multiplication protocol. :var:`P` is required -because some protocols start by exchanging keys for pseudo-random -secret sharing. +The :class:`ProtocolSet` contains one instance for every essential +protocol step. .. code-block:: cpp @@ -235,6 +150,14 @@ The initialization of the multiplication sets the preprocessing and output instances to use in Beaver multiplication. :func:`next_dotprod` separates dot products in the data preparation phase. +.. code-block:: cpp + + protocol.check(); + +Some protocols require a check of all multiplications up to a certain +point. To guarantee that outputs do not reveal secret information, it +has to be run before using the output protocol. + .. code-block:: cpp output.init_open(P); @@ -245,8 +168,8 @@ separates dot products in the data preparation phase. cout << "result: " << result << endl; output.Check(P); -The output protocol follows the same blueprint except that it is -necessary to call the checking in order to verify the outputs. +The output protocol follows the same blueprint as the multiplication +protocol. .. code-block:: cpp @@ -281,6 +204,9 @@ Domain Types the time of writing, 4, 8, 28, 40, 63, and 128 are supported if the storage type is large enough. + +.. _share-type-reference: + Share Types ------------ @@ -385,6 +311,28 @@ Share Types ``MaliciousShamirShare`` or ``MaliciousRep3Share``. +Protocol Setup +-------------- + +.. doxygenclass:: ProtocolSetup + :members: + +.. doxygenclass:: ProtocolSet + :members: + +.. doxygenclass:: BinaryProtocolSetup + :members: + +.. doxygenclass:: BinaryProtocolSet + :members: + +.. doxygenclass:: MixedProtocolSetup + :members: + +.. doxygenclass:: MixedProtocolSet + :members: + + Protocol Interfaces ------------------- diff --git a/doc/networking.rst b/doc/networking.rst index 16908681a..a1c61b98d 100644 --- a/doc/networking.rst +++ b/doc/networking.rst @@ -18,7 +18,7 @@ individually setting ports: coordination server being run as a thread of party 0. The hostname of the coordination server has to be given with the command-line parameter ``--hostname``, and the coordination server runs on the - base port number minus one, thus defaulting to 4999. Furthermore, you + base port number, thus defaulting to 5000. Furthermore, you can specify a party's listening port using ``--my-port``. 2. The parties read the information from a local file, which needs to @@ -40,7 +40,9 @@ change this by either using ``--encrypted/-e`` or If using encryption, the certificates (``Player-Data/*.pem``) must be the same on all hosts, and you have to run ``c_rehash Player-Data`` on -all of them. +all of them. ``Scripts/setup-ssl.sh`` can be used to generate the +necessary certificates. The common name has to be ``P`` +for computing parties and ``C`` for clients. .. _network-reference: diff --git a/doc/non-linear.rst b/doc/non-linear.rst index 5fe8df1f6..bcdbbd3ae 100644 --- a/doc/non-linear.rst +++ b/doc/non-linear.rst @@ -7,8 +7,8 @@ domains (modulus other than two) only comes in three flavors throughout MP-SPDZ: Unknown prime modulus - This approach goes back to `Catrina and Saxena - `_. It crucially relies on + This approach goes back to `Catrina and de Hoogh + `_. It crucially relies on the use of secret random bits in the arithmetic domain. Enough such bits allow to mask a secret value so that it is secure to reveal the masked value. This can then be split in bits as it is diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 3dadcfae6..1441e3524 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -16,7 +16,7 @@ is a thread created by control flow instructions such as The exceptions to the general rule are edaBit generation with malicious security and AND triples with malicious security and honest -majority, both when use bucket size three. Bucket size three implies +majority, both when using bucket size three. Bucket size three implies batches of over a million to achieve 40-bit statistical security, and in honest-majority binary computation the item size is 64, which makes the actual batch size 64 million triples. In multithreaded programs, @@ -27,3 +27,65 @@ jump whenever another batch is generated. Note that, while some protocols are flexible with the batch size and can thus be controlled using ``-b``, others mandate a batch size, which can be as large as a million. + + +Separate preprocessing +====================== + +It is possible to separate out the preprocessing from the +input-dependent ("online") phase. This is done by either option ``-F`` +or ``-f`` on the virtual machines. In both cases, the preprocessing +data is read from files, either all data per type from a single file +(``-F``) or one file per thread (``-f``). The latter allows to use +named pipes. + +The file name depends on the protocol and the computation domain. It +is generally ``/--/--P[-T]``. For example, the +triples for party 1 in SPDZ modulo a 128-bit prime can be found in +``Player-Data/2-p-128/Triples-p-P1``. The protocol shorthand can be +found by calling ``::type_short()``. See +:ref:`share-type-reference` for a description of the share types. + +Preprocessing files start with a header describing the protocol and +computation domain to avoid errors due to mismatches. The header is as +follows: + +- Length to follow (little-endian 8-byte number) +- Protocol descriptor +- Domain descriptor + +The protocol descriptor is defined by ``::type_string()``. For SPDZ modulo a prime it is ``SPDZ gfp``. + +The domain descriptor depends on the kind of domain: + +Modulo a prime + Serialization of the prime + + - Sign bit (0 as 1 byte) + - Length to follow (little-endian 4-byte number) + - Prime (big-endian) + +Modulo a power of two: + Exponent (little-endian 4-byte number) + +:math:`GF(2^n)` + - Storage size in bytes (little-endian 8-byte number). Default is 16. + - :math:`n` (little-endian 4-byte number) + +As an example, the following output of ``hexdump -C`` describes SPDZ +modulo the default 128-bit prime +(170141183460469231731687303715885907969):: + + 00000000 1d 00 00 00 00 00 00 00 53 50 44 5a 20 67 66 70 |........SPDZ gfp| + 00000010 00 10 00 00 00 80 00 00 00 00 00 00 00 00 00 00 |................| + 00000020 00 00 1b 80 01 |.....| + 00000025 + + +``Fake-Offline.x`` generates preprocessing data insecurely for a range +of protocols, and ``{mascot,cowgear,mal-shamir}-offline.x`` generate +sufficient preprocessing data for a specific high-level program with +MASCOT, CowGear, and malicious Shamir secret sharing, respectively. diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 1c096d985..6a79ea198 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -1,3 +1,5 @@ +.. _troubleshooting: + Troubleshooting --------------- @@ -57,10 +59,23 @@ second batch is necessary the cost shoots up. Other preprocessing methods allow for a variable batch size, which can be changed using ``-b``. Smaller batch sizes generally reduce the communication cost while potentially increasing the number of communication rounds. Try -adding ``-b 10`` to the virtal machine (or script) arguments for very +adding ``-b 10`` to the virtual machine (or script) arguments for very short computations. +Disparities in round figures +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The number of virtual machine rounds given by the compiler are not an +exact prediction of network rounds but the number of relevant protocol +calls (such as multiplication, input, output etc) in the program. The +actual number of network rounds is determined by the choice of +protocol, which might use several rounds per protocol +call. Furthermore, communication at the beginning and the end of a +computation such as random key distribution and MAC checks further +increase the number of network rounds. + + Handshake failures ~~~~~~~~~~~~~~~~~~ @@ -82,8 +97,8 @@ use the client facility. Connection failures ~~~~~~~~~~~~~~~~~~~ -MP-SPDZ requires at least one TCP port per party to be open to other -parties. In the default setting, it's 4999 and 5000 on party 0, and +MP-SPDZ requires one TCP port per party to be open to other +parties. In the default setting, it's 5000 on party 0, and 5001 on party 1 etc. You change change the base port (5000) using ``--portnumbase`` and individual ports for parties using ``--my-port``. The scripts in use a random base port number, which you