Skip to content

Commit

Permalink
Maintenance.
Browse files Browse the repository at this point in the history
  • Loading branch information
mkskeller committed Jan 11, 2022
1 parent cdb0c0f commit e07d9bf
Show file tree
Hide file tree
Showing 216 changed files with 2,406 additions and 1,113 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = mpir
url = git://github.com/wbhart/mpir.git
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion
Expand Down
2 changes: 1 addition & 1 deletion BMR/Party.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ ProgramParty::~ProgramParty()
reset();
if (P)
{
cerr << "Data sent: " << 1e-6 * P->comm_stats.total_data() << " MB" << endl;
cerr << "Data sent: " << 1e-6 * P->total_comm().total_data() << " MB" << endl;
delete P;
}
delete[] eval_threads;
Expand Down
2 changes: 1 addition & 1 deletion BMR/RealGarbleWire.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void GarbleInputter<T>::exchange()
assert(party.P != 0);
assert(party.MC != 0);
auto& protocol = party.shared_proc->protocol;
protocol.init_mul(party.shared_proc);
protocol.init_mul();
for (auto& tuple : tuples)
protocol.prepare_mul(tuple.first->mask,
T::constant(1, party.P->my_num(), party.mac_key)
Expand Down
8 changes: 5 additions & 3 deletions BMR/RealProgramParty.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
while (next != GC::DONE_BREAK);

MC->Check(*P);
data_sent = P->comm_stats.total_data() + prep->data_sent();
data_sent = P->total_comm().sent;

this->machine.write_memory(this->N.my_num());
}
Expand All @@ -173,15 +173,17 @@ void RealProgramParty<T>::garble()
garble_jobs.clear();
garble_inputter->reset_all(*P);
auto& protocol = *garble_protocol;
protocol.init_mul(shared_proc);
protocol.init(*prep, shared_proc->MC);
protocol.init_mul();

next = this->first_phase(program, garble_processor, this->garble_machine);

garble_inputter->exchange();
protocol.exchange();

typename T::Protocol second_protocol(*P);
second_protocol.init_mul(shared_proc);
second_protocol.init(*prep, shared_proc->MC);
second_protocol.init_mul();
for (auto& job : garble_jobs)
job.middle_round(*this, second_protocol);

Expand Down
3 changes: 3 additions & 0 deletions BMR/Register.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ class ProgramRegister : public Phase, public Register
template<class U>
static void convcbit2s(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("convcbit2s not implemented"); }
template<class U>
static void andm(GC::Processor<U>&, const BaseInstruction&)
{ throw runtime_error("andm not implemented"); }

// most BMR phases don't need actual input
template<class T>
Expand Down
6 changes: 6 additions & 0 deletions BMR/TrustedParty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ BaseTrustedParty::BaseTrustedParty()
_received_gc_received = 0;
n_received = 0;
randomfd = open("/dev/urandom", O_RDONLY);
done_filling = false;
}

BaseTrustedParty::~BaseTrustedParty()
{
close(randomfd);
}

TrustedProgramParty::TrustedProgramParty(int argc, char** argv) :
Expand Down
3 changes: 1 addition & 2 deletions BMR/TrustedParty.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseTrustedParty : virtual public CommonFakeParty {
vector<SendBuffer> msg_input_masks;

BaseTrustedParty();
virtual ~BaseTrustedParty() {}
virtual ~BaseTrustedParty();

/* From NodeUpdatable class */
virtual void NodeReady();
Expand Down Expand Up @@ -104,7 +104,6 @@ class TrustedProgramParty : public BaseTrustedParty {
void add_all_keys(const Register& reg, bool external);
};


inline void BaseTrustedParty::add_keys(const Register& reg)
{
for(int p = 0; p < get_n_parties(); p++)
Expand Down
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.

## 0.2.9 (Jan 11, 2021)

- Disassembler
- Run-time parameter for probabilistic truncation error
- Probabilistic truncation for some protocols computing modulo a prime
- Simplified C++ interface
- Comparison as in [ACCO](https://dl.acm.org/doi/10.1145/3474123.3486757)
- More general scalar-vector multiplication
- Complete memory support for clear bits
- Extended clear bit functionality with Yao's garbled circuits
- Allow preprocessing information to be supplied via named pipes
- In-place operations for containers

## 0.2.8 (Nov 4, 2021)

- Tested on Apple laptop with ARM chip
Expand Down
10 changes: 8 additions & 2 deletions Compiler/GC/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,16 @@ def load_mem(cls, address, mem_type=None, size=None):
return cls.load_dynamic_mem(address)
else:
for i in range(res.size):
cls.load_inst[util.is_constant(address)](res[i], address + i)
cls.mem_op(cls.load_inst, res[i], address + i)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, int)](self, address)
self.mem_op(self.store_inst, self, address)
@staticmethod
def mem_op(inst, reg, address):
direct = isinstance(address, int)
if not direct:
address = regint.conv(address)
inst[direct](reg, address)
@classmethod
def new(cls, value=None, n=None):
if util.is_constant(value):
Expand Down
11 changes: 5 additions & 6 deletions Compiler/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,16 @@ def LTZ(s, a, k, kappa):
k: bit length of a
"""
movs(s, program.non_linear.ltz(a, k, kappa))

def LtzRing(a, k):
from .types import sint, _bitint
from .GC.types import sbitvec
if program.use_split():
summands = a.split_to_two_summands(k)
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
movs(s, sint.conv(msb))
return sint.conv(msb)
return
elif program.options.ring:
from . import floatingpoint
Expand All @@ -96,11 +99,7 @@ def LTZ(s, a, k, kappa):
a = r_bin[0].bit_decompose_clear(c_prime, m)
b = r_bin[:m]
u = CarryOutRaw(a[::-1], b[::-1])
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
return
t = sint()
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
return sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u))

def LessThanZero(a, k, kappa):
from . import types
Expand Down
2 changes: 1 addition & 1 deletion Compiler/compilerLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def run(args, options):
prog.finalize()

if prog.req_num:
print('Program requires:')
print('Program requires at most:')
for x in prog.req_num.pretty():
print(x)

Expand Down
5 changes: 4 additions & 1 deletion Compiler/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,7 @@ class ArgumentError(CompilerError):
""" Exception raised for errors in instruction argument parsing. """
def __init__(self, arg, msg):
self.arg = arg
self.msg = msg
self.msg = msg

class VectorMismatch(CompilerError):
pass
14 changes: 8 additions & 6 deletions Compiler/floatingpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
lts(d, c_dprime, r_prime, l, kappa)
d = program.Program.prog.non_linear.ltz(c_dprime - r_prime, l, kappa)
if compute_modulo:
b = c_dprime - r_prime + pow2m * d
return b, pow2m
Expand Down Expand Up @@ -629,12 +629,14 @@ def BITLT(a, b, bit_length):
# - From the paper
# Multiparty Computation for Interval, Equality, and Comparison without
# Bit-Decomposition Protocol
def BitDecFull(a, maybe_mixed=False):
def BitDecFull(a, n_bits=None, maybe_mixed=False):
from .library import get_program, do_while, if_, break_point
from .types import sint, regint, longint, cint
p = get_program().prime
assert p
bit_length = p.bit_length()
n_bits = n_bits or bit_length
assert n_bits <= bit_length
logp = int(round(math.log(p, 2)))
if abs(p - 2 ** logp) / p < 2 ** -get_program().security:
# inspired by Rabbit (https://eprint.iacr.org/2021/119)
Expand Down Expand Up @@ -677,12 +679,12 @@ def _():
czero = (c==0)
q = bbits[0].long_one() - comparison.BitLTL_raw(bbits, t)
fbar = [bbits[0].clear_type.conv(cint(x))
for x in ((1<<bit_length)+c-p).bit_decompose(bit_length)]
fbard = bbits[0].bit_decompose_clear(cmodp, bit_length)
g = [q.if_else(fbar[i], fbard[i]) for i in range(bit_length)]
for x in ((1<<bit_length)+c-p).bit_decompose(n_bits)]
fbard = bbits[0].bit_decompose_clear(cmodp, n_bits)
g = [q.if_else(fbar[i], fbard[i]) for i in range(n_bits)]
h = bbits[0].bit_adder(bbits, g)
abits = [bbits[0].clear_type(cint(czero)).if_else(bbits[i], h[i])
for i in range(bit_length)]
for i in range(n_bits)]
if maybe_mixed:
return abits
else:
Expand Down
17 changes: 2 additions & 15 deletions Compiler/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ class join_tape(base.Instruction):
arg_format = ['int']

class crash(base.IOInstruction):
""" Crash runtime if the register's value is > 0.
""" Crash runtime if the value in the register is not zero.
:param: Crash condition (regint)"""
code = base.opcodes['CRASH']
Expand Down Expand Up @@ -1275,7 +1275,7 @@ class prep(base.Instruction):
field_type = 'modp'

def add_usage(self, req_node):
req_node.increment((self.field_type, self.args[0]), 1)
req_node.increment((self.field_type, self.args[0]), self.get_size())

def has_var_args(self):
return True
Expand Down Expand Up @@ -2407,19 +2407,6 @@ def expand(self):
subml(self.args[0], s[5], c[1])


@base.gf2n
@base.vectorize
class lts(base.CISC):
""" Secret comparison $s_i = (s_j < s_k)$. """
__slots__ = []
arg_format = ['sw', 's', 's', 'int', 'int']

def expand(self):
from .types import sint
a = sint()
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])

# placeholder for documentation
class cisc:
""" Meta instruction for emulation. This instruction is only generated
Expand Down
83 changes: 78 additions & 5 deletions Compiler/instructions_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import inspect
import functools
import copy
import sys
import struct
from Compiler.exceptions import *
from Compiler.config import *
from Compiler import util
Expand Down Expand Up @@ -299,11 +301,12 @@ def maybe_vectorized_instruction(*args, **kwargs):
vectorized_name = 'v' + instruction.__name__
Vectorized_Instruction.__name__ = vectorized_name
global_dict[vectorized_name] = Vectorized_Instruction

if 'sphinx.extension' in sys.modules:
return instruction

global_dict[instruction.__name__ + '_class'] = instruction
instruction.__doc__ = ''
# exclude GF(2^n) instructions from documentation
if instruction.code and instruction.code >> 8 == 1:
maybe_vectorized_instruction.__doc__ = ''
maybe_vectorized_instruction.arg_format = instruction.arg_format
return maybe_vectorized_instruction


Expand Down Expand Up @@ -389,8 +392,11 @@ def maybe_gf2n_instruction(*args, **kwargs):
else:
global_dict[GF2N_Instruction.__name__] = GF2N_Instruction

if 'sphinx.extension' in sys.modules:
return instruction

global_dict[instruction.__name__ + '_class'] = instruction_cls
instruction_cls.__doc__ = ''
maybe_gf2n_instruction.arg_format = instruction.arg_format
return maybe_gf2n_instruction
#return instruction

Expand Down Expand Up @@ -661,6 +667,12 @@ def encode(cls, arg):
assert arg.i >= 0
return int_to_bytes(arg.i)

def __init__(self, f):
self.i = struct.unpack('>I', f.read(4))[0]

def __str__(self):
return self.reg_type + str(self.i)

class ClearModpAF(RegisterArgFormat):
reg_type = RegType.ClearModp

Expand All @@ -686,6 +698,12 @@ def check(cls, arg):
def encode(cls, arg):
return int_to_bytes(arg)

def __init__(self, f):
self.i = struct.unpack('>i', f.read(4))[0]

def __str__(self):
return str(self.i)

class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
Expand Down Expand Up @@ -722,6 +740,13 @@ def check(cls, arg):
def encode(cls, arg):
return bytearray(arg, 'ascii') + b'\0' * (cls.length - len(arg))

def __init__(self, f):
tmp = f.read(16)
self.str = str(tmp[0:tmp.find(b'\0')], 'ascii')

def __str__(self):
return self.str

ArgFormats = {
'c': ClearModpAF,
's': SecretModpAF,
Expand Down Expand Up @@ -890,6 +915,54 @@ def __str__(self):
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'

class ParsedInstruction:
reverse_opcodes = {}

def __init__(self, f):
cls = type(self)
from Compiler import instructions
from Compiler.GC import instructions as gc_inst
if not cls.reverse_opcodes:
for module in instructions, gc_inst:
for x, y in inspect.getmodule(module).__dict__.items():
if inspect.isclass(y) and y.__name__[0] != 'v':
try:
cls.reverse_opcodes[y.code] = y
except AttributeError:
pass
read = lambda: struct.unpack('>I', f.read(4))[0]
full_code = read()
code = full_code % (1 << Instruction.code_length)
self.size = full_code >> Instruction.code_length
self.type = cls.reverse_opcodes[code]
t = self.type
name = t.__name__
try:
n_args = len(t.arg_format)
self.var_args = False
except:
n_args = read()
self.var_args = True
try:
arg_format = iter(t.arg_format)
except:
if name == 'cisc':
arg_format = itertools.chain(['str'], itertools.repeat('int'))
else:
arg_format = itertools.repeat('int')
self.args = [ArgFormats[next(arg_format)](f)
for i in range(n_args)]

def __str__(self):
name = self.type.__name__
res = name + ' '
if self.size > 1:
res = 'v' + res + str(self.size) + ', '
if self.var_args:
res += str(len(self.args)) + ', '
res += ', '.join(str(arg) for arg in self.args)
return res

class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
Expand Down
Loading

0 comments on commit e07d9bf

Please sign in to comment.