Skip to content

Commit

Permalink
Basic refresh tests work again!
Browse files Browse the repository at this point in the history
Repurposed `PrivateKey.recover_share_from_updated_private_shares` code as the test function `combine_private_shares_at`, since it doesn't make sense to combine private key shares in production code.
  • Loading branch information
cygnusv committed Apr 8, 2024
1 parent 57026aa commit f422903
Showing 1 changed file with 58 additions and 87 deletions.
145 changes: 58 additions & 87 deletions ferveo/src/refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use ark_poly::{
};
use ferveo_common::{serialization, Keypair};
use ferveo_tdec::{
lagrange_basis_at, prepare_combine_simple, BlindedKeyShare, CiphertextHeader,
prepare_combine_simple, BlindedKeyShare, CiphertextHeader,
DecryptionSharePrecomputed, DecryptionShareSimple,
};
use itertools::{zip_eq, Itertools};
Expand Down Expand Up @@ -90,46 +90,6 @@ impl<E: Pairing> UpdatableBlindedKeyShare<E> {
// };
}

// TODO: Input should be named somthing different than UpdatedPrivateKeyShare
// Perhaps RecoveryShare, or something
/// From the PSS paper, section 4.2.4, (https://link.springer.com/content/pdf/10.1007/3-540-44750-4_27.pdf)
/// `x_r` is the point at which the share is to be recovered
pub fn recover_share_from_updated_private_shares(
// TODO: Consider hiding x_r from the public API
x_r: &DomainPoint<E>,
domain_points: &HashMap<u32, DomainPoint<E>>,
// TODO: recovery_shares?
updated_shares: &HashMap<u32, UpdatedPrivateKeyShare<E>>,
) -> Result<PrivateKeyShare<E>> {
// Pick the domain points and updated shares according to share index
let mut domain_points_ = vec![];
let mut updated_shares_ = vec![];
for share_index in updated_shares.keys().sorted() {
domain_points_.push(
*domain_points
.get(share_index)
.ok_or(Error::InvalidShareIndex(*share_index))?,
);
updated_shares_.push(
updated_shares
.get(share_index)
.ok_or(Error::InvalidShareIndex(*share_index))?
.0
.clone(),
);
}

// Interpolate new shares to recover y_r
// TODO: check if this logic is repeated a bunch of times in other places
let lagrange = lagrange_basis_at::<E>(&domain_points_, x_r);
let prods =
zip_eq(updated_shares_, lagrange).map(|(y_j, l)| y_j.0.mul(l));
let y_r = prods.fold(E::G2::zero(), |acc, y_j| acc + y_j);
Ok(PrivateKeyShare(ferveo_tdec::PrivateKeyShare(
y_r.into_affine(),
)))
}

pub fn unblind_private_key_share(
&self,
validator_keypair: &Keypair<E>,
Expand Down Expand Up @@ -433,23 +393,26 @@ fn make_random_polynomial_with_root<E: Pairing>(
#[cfg(test)]
mod tests_refresh {
use std::collections::HashMap;
use std::ops::Mul;

use ark_bls12_381::Fr;
use ark_ec::CurveGroup;
use ark_poly::EvaluationDomain;
use ark_std::{test_rng, UniformRand, Zero};
use ferveo_tdec::{
test_common::setup_simple, BlindedKeyShare,
PrivateDecryptionContextSimple,
lagrange_basis_at, test_common::setup_simple
};
use itertools::{zip_eq, Itertools};
use rand_core::RngCore;
use test_case::{test_case, test_matrix};

use crate::{
test_common::*, DomainPoint, UpdateTranscript,
test_common::*, DomainPoint, UpdatableBlindedKeyShare, UpdateTranscript
};

type ScalarField =
<ark_bls12_381::Bls12_381 as ark_ec::pairing::Pairing>::ScalarField;
type G2 = <ark_bls12_381::Bls12_381 as ark_ec::pairing::Pairing>::G2;

// /// Using tdec test utilities here instead of PVSS to test the internals of the shared key recovery
// fn create_updated_private_key_shares<R: RngCore>(
Expand Down Expand Up @@ -502,6 +465,27 @@ mod tests_refresh {
// updated_private_key_shares
// }

/// `x_r` is the point at which the share is to be recovered
fn combine_private_shares_at(
x_r: &DomainPoint<E>,
domain_points: &HashMap<u32, DomainPoint<E>>,
shares: &HashMap<u32, ferveo_tdec::PrivateKeyShare<E>>,
) -> ferveo_tdec::PrivateKeyShare<E> {
let mut domain_points_ = vec![];
let mut updated_shares_ = vec![];
for share_index in shares.keys().sorted() {
domain_points_.push(*domain_points.get(share_index).unwrap());
updated_shares_.push(shares.get(share_index).unwrap().0.clone());
}

// Interpolate new shares to recover y_r
let lagrange = lagrange_basis_at::<E>(&domain_points_, &x_r);
let prods =
zip_eq(updated_shares_, lagrange).map(|(y_j, l)| y_j.mul(l));
let y_r = prods.fold(G2::zero(), |acc, y_j| acc + y_j);
ferveo_tdec::PrivateKeyShare(y_r.into_affine())
}

/// Ñ parties (where t <= Ñ <= N) jointly execute a "share recovery" algorithm, and the output is 1 new share.
/// The new share is intended to restore a previously existing share, e.g., due to loss or corruption.
#[test_case(4, 4; "number of shares (validators) is a power of 2")]
Expand Down Expand Up @@ -659,20 +643,20 @@ mod tests_refresh {

// This is a workaround for a type mismatch - We need to convert the private shares to updated private shares
// This is just to test that we are able to recover the shared private key from the updated private shares
let updated_private_key_shares = private_shares
.into_iter()
.map(|(share_index, share)| {
(share_index, UpdatedPrivateKeyShare(share))
})
.collect::<HashMap<u32, _>>();
let new_shared_private_key =
PrivateKeyShare::recover_share_from_updated_private_shares(
&ScalarField::zero(),
domain_points,
&updated_private_key_shares,
)
.unwrap();
assert_eq!(shared_private_key, new_shared_private_key.0);
// let updated_private_key_shares = private_shares
// .into_iter()
// .map(|(share_index, share)| {
// (share_index, UpdatedPrivateKeyShare(share))
// })
// .collect::<HashMap<u32, _>>();
// let new_shared_private_key =
// PrivateKeyShare::recover_share_from_updated_private_shares(
// &ScalarField::zero(),
// domain_points,
// &updated_private_key_shares,
// )
// .unwrap();
assert_ne!(shared_private_key, shared_private_key);
}

/// Ñ parties (where t <= Ñ <= N) jointly execute a "share refresh" algorithm.
Expand Down Expand Up @@ -735,9 +719,6 @@ mod tests_refresh {
let blinded_key_share =
p.public_decryption_contexts[p.index].blinded_key_share;

let participant_public_key =
blinded_key_share.validator_public_key;

// Current participant receives update transcripts from other participants
let updates_for_participant: Vec<_> =
update_transcripts_by_producer
Expand All @@ -760,27 +741,19 @@ mod tests_refresh {
.collect();

// And creates a new, refreshed share
let updated_blinded_key_share = UpdatableBlindedKeyShare(blinded_key_share)
.apply_share_updates(&updates_for_participant);

// TODO: Encapsulate this somewhere, originally from PrivateKeyShare.create_updated_key_share
let updated_blinded_key_share: BlindedKeyShare<E> =
BlindedKeyShare {
validator_public_key: participant_public_key,
blinded_key_share: updates_for_participant.iter().fold(
blinded_key_share.blinded_key_share,
|acc, delta| (acc + delta.update).into(),
),
};

let unblinding_factor = p.setup_params.b_inv;
let updated_share = UpdatedPrivateKeyShare(
updated_blinded_key_share.unblind(unblinding_factor),
);

(p.index as u32, updated_share)
let validator_keypair = ferveo_common::Keypair{
decryption_key: p.setup_params.b
};
let updated_private_share = updated_blinded_key_share.unblind_private_key_share(&validator_keypair).unwrap();

(p.index as u32, updated_private_share)
})
// We only need `threshold` refreshed shares to recover the original share
.take(security_threshold)
.collect::<HashMap<u32, UpdatedPrivateKeyShare<E>>>();
.collect::<HashMap<u32, ferveo_tdec::PrivateKeyShare<E>>>();

let domain_points = domain_points_and_keys
.iter()
Expand All @@ -789,14 +762,12 @@ mod tests_refresh {
})
.collect::<HashMap<u32, DomainPoint<E>>>();

// Finally, let's recreate the shared private key from the refreshed shares
let new_shared_private_key =
PrivateKeyShare::recover_share_from_updated_private_shares(
&ScalarField::zero(),
&domain_points,
&refreshed_shares,
)
.unwrap();
assert_eq!(shared_private_key, new_shared_private_key.0);
let x_r = ScalarField::zero();
let new_shared_private_key = combine_private_shares_at(
&x_r,
&domain_points,
&refreshed_shares
);
assert_eq!(shared_private_key, new_shared_private_key);
}
}

0 comments on commit f422903

Please sign in to comment.