Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add refactor from past #113

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions integration_test/init_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,62 +131,62 @@ func TestThreshold(t *testing.T) {
owner := newEthAddress(t)
t.Run("test 13 operators threshold", func(t *testing.T) {
id := spec.NewID()
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), []uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 111, 122, 133}, "mainnet", owner, 0)
ids := []uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 111, 122, 133}
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), ids, "mainnet", owner, 0)
require.NoError(t, err)
sharesDataSigned, err := hex.DecodeString(ks.Shares[0].Payload.SharesData[2:])
require.NoError(t, err)
pubkeyraw, err := hex.DecodeString(ks.Shares[0].Payload.PublicKey[2:])
require.NoError(t, err)
threshold, err := utils.GetThreshold([]uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100, 111, 122, 133})
require.NoError(t, err)
priviteKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey, servers[7].PrivKey}
require.Less(t, len(priviteKeys), threshold)
err = testSharesData(ops, 13, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
threshold := utils.GetThreshold(ids)
privateKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey, servers[7].PrivKey}
require.Less(t, len(privateKeys), threshold)
err = testSharesData(ops, 13, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.ErrorContains(t, err, "could not reconstruct a valid signature")
// test valid minimum threshold
priviteKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey, servers[7].PrivKey, servers[8].PrivKey}
require.Equal(t, len(priviteKeys), threshold)
err = testSharesData(ops, 13, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
privateKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey, servers[7].PrivKey, servers[8].PrivKey}
require.Equal(t, len(privateKeys), threshold)
err = testSharesData(ops, 13, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.NoError(t, err)
})
t.Run("test 10 operators threshold", func(t *testing.T) {
id := spec.NewID()
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), []uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100}, "mainnet", owner, 0)
ids := []uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100}
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), ids, "mainnet", owner, 0)
require.NoError(t, err)
sharesDataSigned, err := hex.DecodeString(ks.Shares[0].Payload.SharesData[2:])
require.NoError(t, err)
pubkeyraw, err := hex.DecodeString(ks.Shares[0].Payload.PublicKey[2:])
require.NoError(t, err)
threshold, err := utils.GetThreshold([]uint64{11, 22, 33, 44, 55, 66, 77, 88, 99, 100})
require.NoError(t, err)
priviteKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey}
require.Less(t, len(priviteKeys), threshold)
err = testSharesData(ops, 10, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
threshold := utils.GetThreshold(ids)
privateKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey}
require.Less(t, len(privateKeys), threshold)
err = testSharesData(ops, 10, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.ErrorContains(t, err, "could not reconstruct a valid signature")
// test valid minimum threshold
priviteKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey}
require.Equal(t, len(priviteKeys), threshold)
err = testSharesData(ops, 10, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
privateKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey, servers[5].PrivKey, servers[6].PrivKey}
require.Equal(t, len(privateKeys), threshold)
err = testSharesData(ops, 10, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.NoError(t, err)
})
t.Run("test 7 operators threshold", func(t *testing.T) {
id := spec.NewID()
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), []uint64{11, 22, 33, 44, 55, 66, 77}, "mainnet", owner, 0)
ids := []uint64{11, 22, 33, 44, 55, 66, 77}
_, ks, _, err := clnt.StartDKG(id, withdraw.Bytes(), ids, "mainnet", owner, 0)
require.NoError(t, err)
sharesDataSigned, err := hex.DecodeString(ks.Shares[0].Payload.SharesData[2:])
require.NoError(t, err)
pubkeyraw, err := hex.DecodeString(ks.Shares[0].Payload.PublicKey[2:])
require.NoError(t, err)
threshold, err := utils.GetThreshold([]uint64{11, 22, 33, 44, 55, 66, 77})
require.NoError(t, err)
priviteKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey}
require.Less(t, len(priviteKeys), threshold)
err = testSharesData(ops, 7, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
threshold := utils.GetThreshold(ids)
privateKeys := []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey}
require.Less(t, len(privateKeys), threshold)
err = testSharesData(ops, 7, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.ErrorContains(t, err, "could not reconstruct a valid signature")
// test valid minimum threshold
priviteKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey}
require.Equal(t, len(priviteKeys), threshold)
err = testSharesData(ops, 7, priviteKeys, sharesDataSigned, pubkeyraw, owner, 0)
privateKeys = []*rsa.PrivateKey{servers[0].PrivKey, servers[1].PrivKey, servers[2].PrivKey, servers[3].PrivKey, servers[4].PrivKey}
require.Equal(t, len(privateKeys), threshold)
err = testSharesData(ops, 7, privateKeys, sharesDataSigned, pubkeyraw, owner, 0)
require.NoError(t, err)
})
t.Run("test 4 operators threshold", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions pkgs/crypto/crypto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ func testResults(t *testing.T, suite pairing.Suite, thr, n int, results []*dkg.R

secretPoly, err := share.RecoverPriPoly(suite.G1(), shares, thr, n)
coefs := secretPoly.Coefficients()
t.Logf("Ploly len %d", len(coefs))
t.Logf("Poly len %d", len(coefs))
for _, c := range coefs {
t.Logf("Ploly coef %s", c.String())
t.Logf("Poly coef %s", c.String())
}
require.NoError(t, err)
gotPub := secretPoly.Commit(suite.G1().Point().Base())
Expand Down
6 changes: 0 additions & 6 deletions pkgs/dkg/drand.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ var ErrAlreadyExists = errors.New("duplicate message")
type LocalOwner struct {
Logger *zap.Logger
startedDKG chan struct{}
ErrorChan chan error
ID uint64
data *DKGdata
board *board.Board
Expand All @@ -83,7 +82,6 @@ func New(opts *OwnerOpts) *LocalOwner {
owner := &LocalOwner{
Logger: opts.Logger,
startedDKG: make(chan struct{}, 1),
ErrorChan: make(chan error, 1),
ID: opts.ID,
broadcastF: opts.BroadcastF,
exchanges: make(map[uint64]*wire.Exchange),
Expand Down Expand Up @@ -534,10 +532,6 @@ func (o *LocalOwner) checkOperators() bool {
return true
}

func (o *LocalOwner) GetLocalOwner() *LocalOwner {
return o
}

// GetDKGNodes returns a slice of DKG node instances used for the protocol
func (o *LocalOwner) GetDKGNodes(ops []*spec.Operator) ([]kyber_dkg.Node, error) {
nodes := make([]kyber_dkg.Node, 0)
Expand Down
3 changes: 1 addition & 2 deletions pkgs/initiator/initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,7 @@ func (c *Initiator) StartDKG(id [24]byte, withdraw []byte, ids []uint64, network
instanceIDField := zap.String("Ceremony ID", hex.EncodeToString(id[:]))
c.Logger.Info("🚀 Starting init dkg ceremony", zap.Uint64s("operator IDs", ids))

// compute threshold (3f+1)
threshold := len(ids) - ((len(ids) - 1) / 3)
threshold := utils.GetThreshold(ids)
// make init message
init := &spec.Init{
Operators: ops,
Expand Down
2 changes: 1 addition & 1 deletion pkgs/initiator/initiator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestLoadOperators(t *testing.T) {
"ip": "wrongURL"
}
]`), &ops)
require.ErrorContains(t, err, "invalid operator URL")
require.ErrorContains(t, err, "invalid operator 1 URL")
})
}

Expand Down
24 changes: 24 additions & 0 deletions pkgs/operator/crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package operator

import (
"crypto/rand"
"crypto/rsa"

"github.com/bloxapp/ssv/utils/rsaencryption"
spec_crypto "github.com/ssvlabs/dkg-spec/crypto"
)

// Sign creates a RSA signature for the message at operator before sending it to initiator
func (s *Switch) Sign(msg []byte) ([]byte, error) {
return spec_crypto.SignRSA(s.PrivateKey, msg)
}

// Encrypt with RSA public key private DKG share key
func (s *Switch) Encrypt(msg []byte) ([]byte, error) {
return rsa.EncryptPKCS1v15(rand.Reader, &s.PrivateKey.PublicKey, msg)
}

// Decrypt with RSA private key private DKG share key
func (s *Switch) Decrypt(ciphertext []byte) ([]byte, error) {
return rsaencryption.DecodeKey(s.PrivateKey, ciphertext)
}
154 changes: 154 additions & 0 deletions pkgs/operator/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package operator

import (
"encoding/hex"
"fmt"
"io"
"net/http"

"github.com/bloxapp/ssv-dkg/pkgs/utils"
"github.com/bloxapp/ssv-dkg/pkgs/wire"
"github.com/pkg/errors"
"go.uber.org/zap"
)

func (s *Server) resultsHandler(writer http.ResponseWriter, request *http.Request) {
rawdata, err := io.ReadAll(request.Body)
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
signedResultMsg := &wire.SignedTransport{}
if err := signedResultMsg.UnmarshalSSZ(rawdata); err != nil {
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}

// Validate that incoming message is a result message
if signedResultMsg.Message.Type != wire.ResultMessageType {
utils.WriteErrorResponse(s.Logger, writer, errors.New("received wrong message type"), http.StatusBadRequest)
return
}
s.Logger.Debug("received a result message")
err = s.State.SaveResultData(signedResultMsg, s.OutputPath)
if err != nil {
err := &utils.SensitiveError{Err: err, PresentedErr: "failed to write results"}
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
writer.WriteHeader(http.StatusOK)
}

func (s *Server) healthHandler(writer http.ResponseWriter, request *http.Request) {
b, err := s.State.Pong()
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(b); err != nil {
s.Logger.Error("error writing health_check response: " + err.Error())
return
}
}

func (s *Server) dkgHandler(writer http.ResponseWriter, request *http.Request) {
s.Logger.Debug("received a dkg protocol message")
rawdata, err := io.ReadAll(request.Body)
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
b, err := s.State.ProcessMessage(rawdata)
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(b); err != nil {
s.Logger.Error("error writing dkg response: " + err.Error())
return
}
}

func (s *Server) initHandler(writer http.ResponseWriter, request *http.Request) {
s.Logger.Debug("incoming INIT msg")
signedInitMsg, err := processIncomingRequest(s.Logger, writer, request, wire.InitMessageType, s.State.OperatorID)
if err != nil {
s.Logger.Error("Error processing incoming init message", zap.Error(err))
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
reqid := signedInitMsg.Message.Identifier
logger := s.Logger.With(zap.String("reqid", hex.EncodeToString(reqid[:])))
logger.Debug("creating instance with init message data")
b, err := s.State.InitInstance(reqid, signedInitMsg.Message, signedInitMsg.Signer, signedInitMsg.Signature)
if err != nil {
s.Logger.Error("Error creating instance", zap.Error(err))
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, failed to initialize instance, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
logger.Info("✅ Instance started successfully")

writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(b); err != nil {
logger.Error("error writing init response: " + err.Error())
return
}
}

func (s *Server) resignHandler(writer http.ResponseWriter, request *http.Request) {
s.Logger.Debug("incoming RESIGN msg")
signedResignMsg, err := processIncomingRequest(s.Logger, writer, request, wire.ResignMessageType, s.State.OperatorID)
if err != nil {
s.Logger.Error("Error processing incoming init message", zap.Error(err))
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
reqid := signedResignMsg.Message.Identifier
logger := s.Logger.With(zap.String("reqid", hex.EncodeToString(reqid[:])))
b, err := s.State.HandleInstanceOperation(reqid, signedResignMsg.Message, signedResignMsg.Signer, signedResignMsg.Signature, "resign")
if err != nil {
s.Logger.Error("Error resigning instance", zap.Error(err))
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, failed to resign, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
logger.Info("✅ resigned data successfully")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(b); err != nil {
logger.Error("error writing resign response: " + err.Error())
return
}
}

func (s *Server) reshareHandler(writer http.ResponseWriter, request *http.Request) {
s.Logger.Debug("incoming RESHARE msg")
rawdata, err := io.ReadAll(request.Body)
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
signedReshareMsg := &wire.SignedTransport{}
if err := signedReshareMsg.UnmarshalSSZ(rawdata); err != nil {
utils.WriteErrorResponse(s.Logger, writer, err, http.StatusBadRequest)
return
}
// Validate that incoming message is an init message
if signedReshareMsg.Message.Type != wire.ReshareMessageType {
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, err: %v", s.State.OperatorID, errors.New("not reshare message to reshare route")), http.StatusBadRequest)
return
}
reqid := signedReshareMsg.Message.Identifier
logger := s.Logger.With(zap.String("reqid", hex.EncodeToString(reqid[:])))
b, err := s.State.HandleInstanceOperation(reqid, signedReshareMsg.Message, signedReshareMsg.Signer, signedReshareMsg.Signature, "reshare")
if err != nil {
utils.WriteErrorResponse(s.Logger, writer, fmt.Errorf("operator %d, err: %v", s.State.OperatorID, err), http.StatusBadRequest)
return
}
logger.Info("✅ Reshare instance created successfully")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(b); err != nil {
logger.Error("error writing reshare response: " + err.Error())
return
}
}
Loading
Loading