diff --git a/.github/workflows/bump-minor-version.yml b/.github/workflows/minor-release-step1.yml similarity index 96% rename from .github/workflows/bump-minor-version.yml rename to .github/workflows/minor-release-step1.yml index 6cc929952..dbb85e400 100644 --- a/.github/workflows/bump-minor-version.yml +++ b/.github/workflows/minor-release-step1.yml @@ -1,4 +1,8 @@ -name: Bump Minor Version +name: Minor release step 1 - Create new release branch +description: | + Create a new `main-vX.Y` branch based on the latest release branch. + Create a pull request to update the version.go file from `vX.Y` to `vX.Y-rc`. + Next, developers should merge this PR. on: workflow_dispatch: diff --git a/.github/workflows/tag-minor-release-candidate.yml b/.github/workflows/minor-release-step2.yml similarity index 93% rename from .github/workflows/tag-minor-release-candidate.yml rename to .github/workflows/minor-release-step2.yml index 682ba8396..6d7bf4f09 100644 --- a/.github/workflows/tag-minor-release-candidate.yml +++ b/.github/workflows/minor-release-step2.yml @@ -1,4 +1,8 @@ -name: Tag Minor Release Candidate +name: Minor release step 2 - Tag release candidate +description: | + Tag `vX.Y.0-rcN`, where the version used is the one from the latest `main-vX.Y` release branch. + Next, developers should test the rc. + If a new rc is needed, PRs with changes should be merged to the `main-vX.Y` branch and developers should re-run this pipeline to create a new rc tag. on: workflow_dispatch: diff --git a/.github/workflows/prepare-minor-full-release.yml b/.github/workflows/minor-release-step3.yml similarity index 98% rename from .github/workflows/prepare-minor-full-release.yml rename to .github/workflows/minor-release-step3.yml index 21f61a207..095379558 100644 --- a/.github/workflows/prepare-minor-full-release.yml +++ b/.github/workflows/minor-release-step3.yml @@ -1,7 +1,6 @@ -name: Prepare Minor Full Release +name: Minor release step 3 - Release branch to stable version description: | - This workflow creates pull requests to update version files for a minor full release. - After the PRs are merged, a tag for the stable release should be created manually. + Create a pull request to update the version.go file in the latest `main-vX.Y` release branch from `vX.Y-rc` to `vX.Y`. on: workflow_dispatch: diff --git a/.github/workflows/tag-minor-full-release.yml b/.github/workflows/minor-release-step4.yml similarity index 97% rename from .github/workflows/tag-minor-full-release.yml rename to .github/workflows/minor-release-step4.yml index 4ba84ea59..4d52736a6 100644 --- a/.github/workflows/tag-minor-full-release.yml +++ b/.github/workflows/minor-release-step4.yml @@ -1,4 +1,6 @@ -name: Tag Minor Full Release +name: Minor release step 4 - Tag minor full release +description: | + Tag `vX.Y.0`, where the version used is the one from the latest `main-vX.Y` release branch. on: workflow_dispatch: diff --git a/.github/workflows/bump-patch-version.yml b/.github/workflows/patch-release-step1.yml similarity index 96% rename from .github/workflows/bump-patch-version.yml rename to .github/workflows/patch-release-step1.yml index 72e7648c0..5b39b1b15 100644 --- a/.github/workflows/bump-patch-version.yml +++ b/.github/workflows/patch-release-step1.yml @@ -1,4 +1,7 @@ -name: Bump Patch Version +name: Patch release step 1 - Create new release branch +description: | + Create a PR to update the version.go file from `vX.Y` to `vX.Y-rc` in the latest `main-vX.Y` release branch.
+ Next, developers should merge this PR and manually open PRs to `main-vX.Y` with the desired cherry-picked commits from `main`. on: workflow_dispatch: diff --git a/.github/workflows/tag-patch-release-candidate.yml b/.github/workflows/patch-release-step2.yml similarity index 95% rename from .github/workflows/tag-patch-release-candidate.yml rename to .github/workflows/patch-release-step2.yml index caa30b122..6632d2a2b 100644 --- a/.github/workflows/tag-patch-release-candidate.yml +++ b/.github/workflows/patch-release-step2.yml @@ -1,4 +1,8 @@ -name: Tag Patch Release Candidate +name: Patch release step 2 - Tag release candidate +description: | + Tag `vX.Y.Z-rcN`, where the version used is the one from the latest `main-vX.Y` release branch and the latest stable patch tag + 1. + Next, developers should test the rc. + If a new rc is needed, PRs with changes should be merged to the `main-vX.Y` branch and developers should re-run this pipeline to create a new rc tag. on: workflow_dispatch: diff --git a/.github/workflows/prepare-patch-full-release.yml b/.github/workflows/patch-release-step3.yml similarity index 98% rename from .github/workflows/prepare-patch-full-release.yml rename to .github/workflows/patch-release-step3.yml index a268b0c40..e3afd9ba7 100644 --- a/.github/workflows/prepare-patch-full-release.yml +++ b/.github/workflows/patch-release-step3.yml @@ -1,4 +1,6 @@ -name: Prepare Patch Full Release +name: Patch release step 3 - Release branch to stable version +description: | + Create a pull request to update the version.go file in the latest `main-vX.Y` release branch from `vX.Y-rc` to `vX.Y`. on: workflow_dispatch: diff --git a/.github/workflows/tag-patch-full-release.yml b/.github/workflows/patch-release-step4.yml similarity index 97% rename from .github/workflows/tag-patch-full-release.yml rename to .github/workflows/patch-release-step4.yml index 46ef6f70a..f75fab321 100644 --- a/.github/workflows/tag-patch-full-release.yml +++ b/.github/workflows/patch-release-step4.yml @@ -1,4 +1,6 @@ -name: Tag Patch Full Release +name: Patch release step 4 - Tag patch full release +description: | + Tag `vX.Y.Z`, where the version used is the one from the latest `main-vX.Y` release branch and the latest stable patch tag + 1. on: workflow_dispatch: diff --git a/cluster/helpers.go b/cluster/helpers.go index de22dfda2..c927bbf37 100644 --- a/cluster/helpers.go +++ b/cluster/helpers.go @@ -26,6 +26,11 @@ import ( "github.com/obolnetwork/charon/tbls" ) +const ( + // maxDefinitionSize is the maximum allowed size for a cluster definition file (16MB). + maxDefinitionSize = 16 * 1024 * 1024 +) + // FetchDefinition fetches cluster definition file from a remote URI.
func FetchDefinition(ctx context.Context, url string) (Definition, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*10) @@ -47,11 +52,17 @@ func FetchDefinition(ctx context.Context, url string) (Definition, error) { defer resp.Body.Close() - buf, err := io.ReadAll(resp.Body) + limitedReader := io.LimitReader(resp.Body, maxDefinitionSize+1) + + buf, err := io.ReadAll(limitedReader) if err != nil { return Definition{}, errors.Wrap(err, "read response body") } + if len(buf) > maxDefinitionSize { + return Definition{}, errors.New("definition file too large", z.Int("max_bytes", maxDefinitionSize)) + } + var res Definition if err := json.Unmarshal(buf, &res); err != nil { return Definition{}, errors.Wrap(err, "unmarshal definition") diff --git a/cluster/helpers_internal_test.go b/cluster/helpers_internal_test.go index 45377b583..d03d10718 100644 --- a/cluster/helpers_internal_test.go +++ b/cluster/helpers_internal_test.go @@ -94,6 +94,11 @@ func TestFetchDefinition(t *testing.T) { _, _ = w.Write(b) case "/nonok": w.WriteHeader(http.StatusInternalServerError) + case "/tooLarge": + // Simulate a response that exceeds maxDefinitionSize (16MB) + // Write 17MB of data to trigger the size limit + largeData := make([]byte, 17*1024*1024) + _, _ = w.Write(largeData) } })) defer server.Close() @@ -103,6 +108,7 @@ func TestFetchDefinition(t *testing.T) { url string want Definition wantErr bool + errMsg string }{ { name: "Fetch valid definition", @@ -122,12 +128,25 @@ func TestFetchDefinition(t *testing.T) { want: invalidDef, wantErr: true, }, + { + name: "Definition file too large (memory exhaustion protection)", + url: fmt.Sprintf("%s/%s", server.URL, "tooLarge"), + want: invalidDef, + wantErr: true, + errMsg: "definition file too large", + }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := FetchDefinition(context.Background(), tt.url) if tt.wantErr { require.Error(t, err) + + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + } + return } diff --git a/dkg/dkg_test.go b/dkg/dkg_test.go index d8d69e7e4..e1d5867c5 100644 --- a/dkg/dkg_test.go +++ b/dkg/dkg_test.go @@ -437,16 +437,17 @@ func verifyDKGResults(t *testing.T, def cluster.Definition, dir string) { } } - // Ensure keystores can generate valid tbls aggregate signature. + // Ensure keystores can generate valid tbls threshold aggregate signature. for i := range def.NumValidators { - var sigs []tbls.Signature + sigsByIdx := make(map[int]tbls.Signature) + msg := []byte("data") for j := range len(def.Operators) { - msg := []byte("data") sig, err := tbls.Sign(secretShares[i][j], msg) require.NoError(t, err) - sigs = append(sigs, sig) + // Use 1-based share indices as production does + sigsByIdx[j+1] = sig // Ensure all public shares can verify the partial signature for _, lock := range locks { @@ -461,7 +462,13 @@ func verifyDKGResults(t *testing.T, def cluster.Definition, dir string) { } } - _, err := tbls.Aggregate(sigs) + // Use ThresholdAggregate (Lagrange interpolation) instead of simple Aggregate + // to ensure share indices are correct - this is what production uses. 
+ aSig, err := tbls.ThresholdAggregate(sigsByIdx) + require.NoError(t, err) + + // Verify against the validator's full public key + err = tbls.Verify(tbls.PublicKey(locks[0].Validators[i].PubKey), msg, aSig) require.NoError(t, err) } } diff --git a/dkg/frostp2p.go b/dkg/frostp2p.go index 04b0c73a0..896332ced 100644 --- a/dkg/frostp2p.go +++ b/dkg/frostp2p.go @@ -140,8 +140,13 @@ func newBcastCallback(peers map[peer.ID]cluster.NodeIdx, round1CastsRecv chan *p return errors.New("invalid round 1 casts message") } + peerNode, ok := peers[pID] + if !ok { + return errors.New("unknown peer in round 1 cast", z.Any("peer", p2p.PeerName(pID))) + } + for _, cast := range msg.GetCasts() { - if int(cast.GetKey().GetSourceId()) != peers[pID].ShareIdx { + if int(cast.GetKey().GetSourceId()) != peerNode.ShareIdx { return errors.New("invalid round 1 cast source ID") } else if cast.GetKey().GetTargetId() != 0 { return errors.New("invalid round 1 cast target ID") @@ -174,8 +179,13 @@ func newBcastCallback(peers map[peer.ID]cluster.NodeIdx, round1CastsRecv chan *p return errors.New("invalid round 2 casts message") } + peerNode, ok := peers[pID] + if !ok { + return errors.New("unknown peer in round 2 cast", z.Any("peer", p2p.PeerName(pID))) + } + for _, cast := range msg.GetCasts() { - if int(cast.GetKey().GetSourceId()) != peers[pID].ShareIdx { + if int(cast.GetKey().GetSourceId()) != peerNode.ShareIdx { return errors.New("invalid round 2 cast source ID") } else if cast.GetKey().GetTargetId() != 0 { return errors.New("invalid round 2 cast target ID") @@ -209,10 +219,20 @@ func newP2PCallback(p2pNode host.Host, peers map[peer.ID]cluster.NodeIdx, round1 return nil, false, errors.New("invalid round 1 p2p message") } + sourcePeer, ok := peers[pID] + if !ok { + return nil, false, errors.New("unknown source peer in round 1 p2p", z.Any("peer", p2p.PeerName(pID))) + } + + targetPeer, ok := peers[p2pNode.ID()] + if !ok { + return nil, false, errors.New("unknown target peer in round 1 p2p", z.Any("peer", p2p.PeerName(p2pNode.ID()))) + } + for _, share := range msg.GetShares() { - if int(share.GetKey().GetSourceId()) != peers[pID].ShareIdx { + if int(share.GetKey().GetSourceId()) != sourcePeer.ShareIdx { return nil, false, errors.New("invalid round 1 p2p source ID") - } else if int(share.GetKey().GetTargetId()) != peers[p2pNode.ID()].ShareIdx { + } else if int(share.GetKey().GetTargetId()) != targetPeer.ShareIdx { return nil, false, errors.New("invalid round 1 p2p target ID") } else if int(share.GetKey().GetValIdx()) < 0 || int(share.GetKey().GetValIdx()) >= numVals { return nil, false, errors.New("invalid round 1 p2p validator index") @@ -220,7 +240,7 @@ func newP2PCallback(p2pNode host.Host, peers map[peer.ID]cluster.NodeIdx, round1 } if dedupRound1P2P[pID] { - log.Debug(ctx, "Ignoring duplicate round 2 message", z.Any("peer", p2p.PeerName(pID))) + log.Debug(ctx, "Ignoring duplicate round 1 message", z.Any("peer", p2p.PeerName(pID))) return nil, false, nil } @@ -452,7 +472,7 @@ func round1CastFromProto(cast *pb.FrostRound1Cast) (msgKey, frost.Round1Bcast, e ci, err := curve.Scalar.SetBytes(cast.GetCi()) if err != nil { - return msgKey{}, frost.Round1Bcast{}, errors.Wrap(err, "decode c1 scalar") + return msgKey{}, frost.Round1Bcast{}, errors.Wrap(err, "decode ci scalar") } var comms []curves.Point @@ -498,7 +518,7 @@ func round2CastFromProto(cast *pb.FrostRound2Cast) (msgKey, frost.Round2Bcast, e vkShare, err := curve.Point.FromAffineCompressed(cast.GetVkShare()) if err != nil { - return msgKey{}, frost.Round2Bcast{}, 
errors.Wrap(err, "decode c1 scalar") + return msgKey{}, frost.Round2Bcast{}, errors.Wrap(err, "decode verification key share") } key, err := keyFromProto(cast.GetKey()) diff --git a/dkg/nodesigs.go b/dkg/nodesigs.go index 0954d9911..7789077da 100644 --- a/dkg/nodesigs.go +++ b/dkg/nodesigs.go @@ -118,7 +118,7 @@ func (n *nodeSigBcast) setSig(sig []byte, slot int) { } // broadcastCallback is the default bcast.Callback for nodeSigBcast. -func (n *nodeSigBcast) broadcastCallback(ctx context.Context, _ peer.ID, _ string, msg proto.Message) error { +func (n *nodeSigBcast) broadcastCallback(ctx context.Context, senderID peer.ID, _ string, msg proto.Message) error { nodeSig, ok := msg.(*dkgpb.MsgNodeSig) if !ok { return errors.New("invalid node sig type") @@ -138,6 +138,11 @@ func (n *nodeSigBcast) broadcastCallback(ctx context.Context, _ peer.ID, _ strin return errors.New("invalid peer index") } + // Verify that the actual sender's peer ID matches the claimed peer index + if n.peers[msgPeerIdx].ID != senderID { + return errors.New("sender peer ID does not match claimed peer index") + } + lockHash, err := n.lockHash(ctx) if err != nil { return errors.Wrap(err, "lock hash wait") diff --git a/dkg/nodesigs_internal_test.go b/dkg/nodesigs_internal_test.go index 6d90eeea2..78aec5c69 100644 --- a/dkg/nodesigs_internal_test.go +++ b/dkg/nodesigs_internal_test.go @@ -227,6 +227,24 @@ func TestSigsCallbacks(t *testing.T) { require.ErrorContains(t, err, "invalid node sig type") }) + t.Run("sender peer ID mismatch", func(t *testing.T) { + ns.lockHashData = bytes.Repeat([]byte{42}, 32) + + msg := &dkgpb.MsgNodeSig{ + Signature: bytes.Repeat([]byte{42}, 65), + PeerIndex: uint32(2), // Claims to be from peer 2 + } + + // But actually sent by peer 1 + err := ns.broadcastCallback(context.Background(), + peers[1], + "", + msg, + ) + + require.ErrorContains(t, err, "sender peer ID does not match claimed peer index") + }) + t.Run("signature verification failed", func(t *testing.T) { ns.lockHashData = bytes.Repeat([]byte{42}, 32) @@ -236,7 +254,7 @@ func TestSigsCallbacks(t *testing.T) { } err := ns.broadcastCallback(context.Background(), - peers[0], + peers[2], "", msg, ) @@ -251,7 +269,7 @@ func TestSigsCallbacks(t *testing.T) { } err := ns.broadcastCallback(context.Background(), - peers[0], + peers[2], "", msg, ) diff --git a/dkg/pedersen/dkg.go b/dkg/pedersen/dkg.go index 31803ae87..ae89f6665 100644 --- a/dkg/pedersen/dkg.go +++ b/dkg/pedersen/dkg.go @@ -13,6 +13,8 @@ import ( "github.com/obolnetwork/charon/app/errors" "github.com/obolnetwork/charon/app/log" + "github.com/obolnetwork/charon/app/z" + "github.com/obolnetwork/charon/cluster" "github.com/obolnetwork/charon/dkg/share" "github.com/obolnetwork/charon/tbls" ) @@ -51,17 +53,22 @@ func RunDKG(ctx context.Context, config *Config, board *Board, numVals int) ([]s return int(a.Index) - int(b.Index) }) - nonce, err := generateNonce(nodes) - if err != nil { + threshold := config.Threshold + if threshold <= 0 { + threshold = cluster.Threshold(len(nodes)) + + log.Info(ctx, "Using default threshold", z.Int("threshold", threshold)) + } + + if err := validateThreshold(len(nodes), threshold); err != nil { return nil, err } dkgConfig := &kdkg.Config{ Longterm: nodePrivateKey, - Nonce: nonce, Suite: config.Suite, NewNodes: nodes, - Threshold: config.Threshold, + Threshold: threshold, FastSync: true, Auth: drandbls.NewSchemeOnG2(kbls.NewBLS12381Suite()), Log: newLogger(log.WithTopic(ctx, "pedersen")), @@ -71,7 +78,14 @@ func RunDKG(ctx context.Context, config *Config, 
board *Board, numVals int) ([]s shares := make([]share.Share, 0, numVals) - for range numVals { + for i := range numVals { + nonce, err := generateNonce(nodes, i) + if err != nil { + return nil, err + } + + dkgConfig.Nonce = nonce + phaser := kdkg.NewTimePhaser(config.PhaseDuration) protocol, err := kdkg.NewProtocol( @@ -181,10 +195,10 @@ func processKey(ctx context.Context, config *Config, board *Board, key *kdkg.Dis publicShares := make(map[int]tbls.PublicKey) - for i, oi := range oldShareIndices { + for _, oi := range oldShareIndices { var pk tbls.PublicKey copy(pk[:], oldShareRevMap[oi]) - publicShares[i+1] = pk + publicShares[oi] = pk } return share.Share{ @@ -208,3 +222,17 @@ func readBoardChannel[T any](ctx context.Context, ch <-chan T, count int) ([]T, return pubKeys, nil } + +// validateThreshold verifies that the threshold is between 1 and nodeCount. +// Note that in case of rotation we cannot increase the threshold beyond the original cluster size. +func validateThreshold(nodeCount, threshold int) error { + if threshold < 1 { + return errors.New("threshold is too low", z.Int("threshold", threshold)) + } + + if threshold > nodeCount { + return errors.New("threshold exceeds node count", z.Int("threshold", threshold), z.Int("nodes", nodeCount)) + } + + return nil +} diff --git a/dkg/pedersen/reshare.go b/dkg/pedersen/reshare.go index d1b3d66c9..d8ddbee03 100644 --- a/dkg/pedersen/reshare.go +++ b/dkg/pedersen/reshare.go @@ -5,7 +5,6 @@ package pedersen import ( "bytes" "context" - "crypto/sha256" "slices" "github.com/drand/kyber" @@ -13,16 +12,18 @@ import ( kshare "github.com/drand/kyber/share" kdkg "github.com/drand/kyber/share/dkg" drandbls "github.com/drand/kyber/sign/bdn" + ssz "github.com/ferranbt/fastssz" "github.com/obolnetwork/charon/app/errors" "github.com/obolnetwork/charon/app/log" "github.com/obolnetwork/charon/app/z" + "github.com/obolnetwork/charon/cluster" "github.com/obolnetwork/charon/dkg/share" "github.com/obolnetwork/charon/tbls" ) // RunReshareDKG runs the core reshare protocol for add/remove operators or just reshare. -func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []share.Share) ([]share.Share, error) { +func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []share.Share, expectedValidatorPubKeys []tbls.PublicKey) ([]share.Share, error) { if config.Reshare == nil { return nil, errors.New("reshare config is nil") } @@ -158,13 +159,19 @@ func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []s } } - // Validate node classification - if len(config.Reshare.RemovedPeers) > 0 && len(oldNodes) == 0 { - return nil, errors.New("remove operation requires at least one node with existing shares to participate") + // For remove-only operations, reassign compact indices to newNodes. + // This ensures new shares are evaluated at x=1,2,3,...,n (compact) rather than + // the original gapped indices. The oldNodes keep their original indices for + // correct Lagrange interpolation of contributed shares. + if len(config.Reshare.RemovedPeers) > 0 && len(config.Reshare.AddedPeers) == 0 { + for i := range newNodes { + newNodes[i].Index = kdkg.Index(i) + } } - if len(config.Reshare.AddedPeers) > 0 && len(newNodes) <= len(oldNodes) { - return nil, errors.New("add operation requires new nodes to join, but all nodes already exist in the cluster") + // Validate old/new node counts against the reshare operation being performed. 
+ if err := validateReshareNodeCounts(len(oldNodes), len(newNodes), config.Threshold, config.Reshare); err != nil { + return nil, err } // In add operations, old nodes must have shares to contribute @@ -173,18 +180,46 @@ func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []s return nil, errors.New("existing node in add operation must have shares to contribute") } - nonce, err := generateNonce(nodes) - if err != nil { - return nil, err + // Validate that at least one old node remains after removal operation + if len(config.Reshare.RemovedPeers) > 0 { + oldNodesRemaining := 0 + + for _, oldNode := range oldNodes { + // Check if this old node is in the new cluster + for _, newNode := range newNodes { + if oldNode.Index == newNode.Index { + oldNodesRemaining++ + break + } + } + } + + if oldNodesRemaining == 0 { + return nil, errors.New("remove operation would remove all nodes from original cluster, at least one original node must remain", + z.Int("old_nodes", len(oldNodes)), + z.Int("removed_peers", len(config.Reshare.RemovedPeers)), + ) + } + } + + newThreshold := config.Reshare.NewThreshold + if newThreshold <= 0 { + newThreshold = cluster.Threshold(len(newNodes)) + + log.Info(ctx, "Using default new threshold", z.Int("new_threshold", newThreshold)) + } + + // Validate new threshold against resulting cluster size + if err := validateThreshold(len(newNodes), newThreshold); err != nil { + return nil, errors.Wrap(err, "invalid new threshold") } reshareConfig := &kdkg.Config{ Longterm: nodePrivateKey, - Nonce: nonce, Suite: config.Suite, NewNodes: newNodes, OldNodes: oldNodes, - Threshold: config.Reshare.NewThreshold, + Threshold: newThreshold, OldThreshold: config.Threshold, FastSync: true, Auth: drandbls.NewSchemeOnG2(kbls.NewBLS12381Suite()), @@ -193,12 +228,19 @@ func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []s log.Info(ctx, "Starting pedersen reshare...", z.Int("oldNodes", len(oldNodes)), z.Int("newNodes", len(newNodes)), - z.Int("oldThreshold", config.Threshold), z.Int("newThreshold", config.Reshare.NewThreshold), + z.Int("oldThreshold", config.Threshold), z.Int("newThreshold", newThreshold), z.Bool("thisIsOldNode", thisIsOldNode), z.Bool("thisIsRemovedNode", thisIsRemovedNode)) newShares := make([]share.Share, 0, config.Reshare.TotalShares) for shareNum := range config.Reshare.TotalShares { + nonce, err := generateNonce(nodes, shareNum) + if err != nil { + return nil, err + } + + reshareConfig.Nonce = nonce + phaser := kdkg.NewTimePhaser(config.PhaseDuration) // Nodes with existing shares provide their share to the reshare protocol. @@ -211,7 +253,13 @@ func RunReshareDKG(ctx context.Context, config *Config, board *Board, shares []s reshareConfig.PublicCoeffs = nil } else { // This is a new node - restore public coefficients from exchanged public key shares - commits, err := restoreCommits(pubKeyShares, shareNum, config.Threshold) + // Validate that the recovered group public key matches the expected validator public key + var expectedPubKey *tbls.PublicKey + if shareNum < len(expectedValidatorPubKeys) { + expectedPubKey = &expectedValidatorPubKeys[shareNum] + } + + commits, err := restoreCommits(pubKeyShares, shareNum, config.Threshold, expectedPubKey) if err != nil { return nil, errors.Wrap(err, "restore commits") } @@ -274,7 +322,8 @@ func broadcastNoneKey(ctx context.Context, config *Config, board *Board) error { // restoreCommitsFromPubShares recovers public polynomial commits from a map of public key shares. 
// The nodeIdx in the map is 0-indexed. -func restoreCommitsFromPubShares(pubSharesBytes map[int][]byte, threshold int) ([]kyber.Point, error) { +// If expectedValidatorPubKey is provided, validates that the recovered group public key matches. +func restoreCommitsFromPubShares(pubSharesBytes map[int][]byte, threshold int, expectedValidatorPubKey *tbls.PublicKey) ([]kyber.Point, error) { var ( suite = kbls.NewBLS12381Suite() kyberPubShares []*kshare.PubShare @@ -300,6 +349,22 @@ func restoreCommitsFromPubShares(pubSharesBytes map[int][]byte, threshold int) ( _, commits := pubPoly.Info() + // Validate the recovered group public key against the expected validator public key. + if expectedValidatorPubKey != nil { + if len(commits) == 0 { + return nil, errors.New("no commits recovered") + } + + recoveredPubKeyBytes, err := commits[0].MarshalBinary() + if err != nil { + return nil, errors.Wrap(err, "marshal recovered public key") + } + + if !bytes.Equal(recoveredPubKeyBytes, expectedValidatorPubKey[:]) { + return nil, errors.New("recovered group public key does not match expected validator public key") + } + } + return commits, nil } @@ -310,7 +375,7 @@ func restoreDistKeyShare(keyShare share.Share, threshold int, nodeIdx int) (*kdk pubSharesBytes[shareIdx-1] = pks[:] } - commits, err := restoreCommitsFromPubShares(pubSharesBytes, threshold) + commits, err := restoreCommitsFromPubShares(pubSharesBytes, threshold, nil) if err != nil { return nil, errors.Wrap(err, "restore commits") } @@ -345,32 +410,78 @@ func restoreDistKeyShare(keyShare share.Share, threshold int, nodeIdx int) (*kdk return dks, nil } -func restoreCommits(publicShares map[int][][]byte, shareNum, threshold int) ([]kyber.Point, error) { +func restoreCommits(publicShares map[int][][]byte, shareNum, threshold int, expectedValidatorPubKey *tbls.PublicKey) ([]kyber.Point, error) { + // Validate that all nodes have sufficient shares before accessing + for nodeIdx, pks := range publicShares { + if shareNum >= len(pks) { + return nil, errors.New("insufficient public key shares from node", + z.Int("node_index", nodeIdx), + z.Int("share_num", shareNum), + z.Int("available_shares", len(pks)), + ) + } + } + // Extract the specific share's public keys for all nodes pubSharesBytes := make(map[int][]byte) for nodeIdx, pks := range publicShares { pubSharesBytes[nodeIdx] = pks[shareNum] } - return restoreCommitsFromPubShares(pubSharesBytes, threshold) + return restoreCommitsFromPubShares(pubSharesBytes, threshold, expectedValidatorPubKey) } -func generateNonce(nodes []kdkg.Node) ([]byte, error) { - var buf bytes.Buffer +func generateNonce(nodes []kdkg.Node, iteration int) ([]byte, error) { + hh := ssz.DefaultHasherPool.Get() + defer ssz.DefaultHasherPool.Put(hh) - for _, node := range nodes { - pkBytes, err := node.Public.MarshalBinary() - if err != nil { - return nil, errors.Wrap(err, "marshal node public key") - } + indx := hh.Index() - _, err = buf.Write(pkBytes) - if err != nil { - return nil, errors.Wrap(err, "hash node public key") + // Field (0) 'iteration' + hh.PutUint32(uint32(iteration)) + + // Field (1) 'nodes' - list of (index, pubkey) pairs + { + subIndx := hh.Index() + + for _, node := range nodes { + elemIndx := hh.Index() + + hh.PutUint32(node.Index) + + pkBytes, err := node.Public.MarshalBinary() + if err != nil { + return nil, errors.Wrap(err, "marshal node public key") + } + + hh.PutBytes(pkBytes) + hh.Merkleize(elemIndx) } + + hh.MerkleizeWithMixin(subIndx, uint64(len(nodes)), uint64(len(nodes))) } - hash := 
sha256.Sum256(buf.Bytes()) + hh.Merkleize(indx) + + hash, err := hh.HashRoot() + if err != nil { + return nil, errors.Wrap(err, "hash root") + } return hash[:], nil } + +// validateReshareNodeCounts validates that there are enough nodes to complete the reshare. +// KDKG requires at least oldThreshold old nodes to reconstruct the secret polynomial. +func validateReshareNodeCounts(oldNodesCount, newNodesCount, oldThreshold int, reshare *ReshareConfig) error { + if len(reshare.RemovedPeers) > 0 && oldNodesCount < oldThreshold { + return errors.New("remove operation requires at least threshold nodes with existing shares to participate", + z.Int("old_nodes", oldNodesCount), z.Int("threshold", oldThreshold)) + } + + if len(reshare.AddedPeers) > 0 && newNodesCount <= oldNodesCount { + return errors.New("add operation requires new nodes to join, but all nodes already exist in the cluster") + } + + return nil +} diff --git a/dkg/pedersen/reshare_internal_test.go b/dkg/pedersen/reshare_internal_test.go index 57fafad54..430666634 100644 --- a/dkg/pedersen/reshare_internal_test.go +++ b/dkg/pedersen/reshare_internal_test.go @@ -5,6 +5,9 @@ package pedersen import ( "testing" + kbls "github.com/drand/kyber-bls12381" + kdkg "github.com/drand/kyber/share/dkg" + "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "github.com/obolnetwork/charon/dkg/share" @@ -54,3 +57,186 @@ func TestRestoreDistKeyShare(t *testing.T) { require.Error(t, err) }) } + +func TestValidateReshareNodeCounts(t *testing.T) { + tests := []struct { + name string + oldNodesCount int + newNodesCount int + oldThreshold int + reshare *ReshareConfig + expectError bool + errorContains string + }{ + { + name: "no removals or additions - always valid", + oldNodesCount: 4, + newNodesCount: 4, + oldThreshold: 3, + reshare: &ReshareConfig{}, + expectError: false, + }, + { + name: "removals with enough old nodes", + oldNodesCount: 3, + newNodesCount: 3, + oldThreshold: 3, + reshare: &ReshareConfig{RemovedPeers: []peer.ID{"peer1"}}, + expectError: false, + }, + { + name: "removals with more than threshold old nodes", + oldNodesCount: 4, + newNodesCount: 3, + oldThreshold: 3, + reshare: &ReshareConfig{RemovedPeers: []peer.ID{"peer1"}}, + expectError: false, + }, + { + name: "removals with insufficient old nodes", + oldNodesCount: 2, + newNodesCount: 2, + oldThreshold: 3, + reshare: &ReshareConfig{RemovedPeers: []peer.ID{"peer1"}}, + expectError: true, + errorContains: "remove operation requires at least threshold nodes", + }, + { + name: "removals with zero old nodes (complete replacement)", + oldNodesCount: 0, + newNodesCount: 5, + oldThreshold: 3, + reshare: &ReshareConfig{RemovedPeers: []peer.ID{"peer1"}, AddedPeers: []peer.ID{"peer2"}}, + expectError: true, + errorContains: "remove operation requires at least threshold nodes", + }, + { + name: "additions with new nodes joining", + oldNodesCount: 4, + newNodesCount: 5, + oldThreshold: 3, + reshare: &ReshareConfig{AddedPeers: []peer.ID{"peer1"}}, + expectError: false, + }, + { + name: "additions without new nodes joining", + oldNodesCount: 4, + newNodesCount: 4, + oldThreshold: 3, + reshare: &ReshareConfig{AddedPeers: []peer.ID{"peer1"}}, + expectError: true, + errorContains: "add operation requires new nodes to join", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateReshareNodeCounts(tc.oldNodesCount, tc.newNodesCount, tc.oldThreshold, tc.reshare) + if tc.expectError { + require.Error(t, err) + require.Contains(t, 
err.Error(), tc.errorContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestRestoreCommitsOutOfBounds(t *testing.T) { + tests := []struct { + name string + publicShares map[int][][]byte + shareNum int + threshold int + expectError bool + }{ + { + name: "share number exceeds available shares", + publicShares: map[int][][]byte{ + 0: {[]byte("share0_0"), []byte("share0_1")}, + 1: {[]byte("share1_0"), []byte("share1_1")}, + 2: {[]byte("share2_0"), []byte("share2_1")}, + }, + shareNum: 2, // Requesting index 2, but only 0 and 1 exist + threshold: 2, + expectError: true, + }, + { + name: "one node has insufficient shares", + publicShares: map[int][][]byte{ + 0: {[]byte("share0_0"), []byte("share0_1"), []byte("share0_2")}, + 1: {[]byte("share1_0"), []byte("share1_1")}, // Only 2 shares + 2: {[]byte("share2_0"), []byte("share2_1"), []byte("share2_2")}, + }, + shareNum: 2, // Node 1 doesn't have index 2 + threshold: 2, + expectError: true, + }, + { + name: "empty shares with non-zero shareNum", + publicShares: map[int][][]byte{ + 0: {}, + 1: {}, + }, + shareNum: 0, + threshold: 1, + expectError: true, + }, + { + name: "valid access within bounds", + publicShares: map[int][][]byte{ + 0: {[]byte("share0_0"), []byte("share0_1"), []byte("share0_2")}, + 1: {[]byte("share1_0"), []byte("share1_1"), []byte("share1_2")}, + }, + shareNum: 1, + threshold: 2, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := restoreCommits(tt.publicShares, tt.shareNum, tt.threshold, nil) + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), "insufficient public key shares from node") + } else if err != nil { + // Valid cases might still error due to invalid key data, + // but should not error with bounds message + require.NotContains(t, err.Error(), "insufficient public key shares") + } + }) + } +} + +func TestGenerateNonce(t *testing.T) { + suite := kbls.NewBLS12381Suite().G1().(kdkg.Suite) + _, pub1 := randomKeyPair(suite) + _, pub2 := randomKeyPair(suite) + _, pub3 := randomKeyPair(suite) + + nodes := []kdkg.Node{ + {Index: 1, Public: pub1}, + {Index: 2, Public: pub2}, + {Index: 3, Public: pub3}, + } + + nonce1, err := generateNonce(nodes, 0) + require.NoError(t, err) + + nonce2, err := generateNonce(nodes, 1) + require.NoError(t, err) + + require.NotEqual(t, nonce1, nonce2) + + nodes = []kdkg.Node{ + {Index: 1, Public: pub1}, + {Index: 2, Public: pub2}, + } + + nonce3, err := generateNonce(nodes, 1) + require.NoError(t, err) + + require.NotEqual(t, nonce2, nonce3) +} diff --git a/dkg/pedersen/reshare_test.go b/dkg/pedersen/reshare_test.go index 60769e203..f92985851 100644 --- a/dkg/pedersen/reshare_test.go +++ b/dkg/pedersen/reshare_test.go @@ -79,10 +79,20 @@ func TestRunReshare(t *testing.T) { group, gctx := errgroup.WithContext(t.Context()) + // Extract expected validator public keys from old shares + // All nodes should have the same validator public keys + var expectedValidatorPubKeys []tbls.PublicKey + if len(oldShares) > 0 && len(oldShares[0]) > 0 { + expectedValidatorPubKeys = make([]tbls.PublicKey, len(oldShares[0])) + for i, share := range oldShares[0] { + expectedValidatorPubKeys[i] = share.PubKey + } + } + for n := range nodes { group.Go(func() error { nodes[n].Config.Reshare = &pedersen.ReshareConfig{TotalShares: numVals, NewThreshold: threshold} - shares, err := pedersen.RunReshareDKG(gctx, nodes[n].Config, nodes[n].Board, oldShares[n]) + shares, err := pedersen.RunReshareDKG(gctx, nodes[n].Config, 
nodes[n].Board, oldShares[n], expectedValidatorPubKeys) nodes[n].Shares = shares return err diff --git a/dkg/pedersen/testutils.go b/dkg/pedersen/testutils.go index 3c385b496..3ae158577 100644 --- a/dkg/pedersen/testutils.go +++ b/dkg/pedersen/testutils.go @@ -130,9 +130,8 @@ func VerifyShares(t *testing.T, nodes []*TestNode, numVals, threshold int) { for v := range numVals { var ( - sigs []tbls.Signature - pshares []tbls.PublicKey - secrets = make(map[int]tbls.PrivateKey) + sigsByIdx = make(map[int]tbls.Signature) + secrets = make(map[int]tbls.PrivateKey) ) for _, node := range nodes { @@ -145,15 +144,16 @@ func VerifyShares(t *testing.T, nodes []*TestNode, numVals, threshold int) { err = tbls.Verify(pubKeyShare, msg, sig) require.NoError(t, err) - sigs = append(sigs, sig) - pshares = append(pshares, pubKeyShare) + sigsByIdx[node.NodeIdx.ShareIdx] = sig secrets[node.NodeIdx.ShareIdx] = node.Shares[v].SecretShare } - aggSig, err := tbls.Aggregate(sigs) + // Use ThresholdAggregate (Lagrange interpolation) instead of simple Aggregate + // to ensure share indices are correct - this is what production uses. + aggSig, err := tbls.ThresholdAggregate(sigsByIdx) require.NoError(t, err) - err = tbls.VerifyAggregate(pshares, aggSig, msg) + err = tbls.Verify(nodes[0].Shares[v].PubKey, msg, aggSig) require.NoError(t, err) recSecret, err := tbls.RecoverSecret(secrets, uint(len(nodes)), uint(threshold)) diff --git a/dkg/pedersen/utils_internal_test.go b/dkg/pedersen/utils_internal_test.go index 3a2867842..b597bbb6c 100644 --- a/dkg/pedersen/utils_internal_test.go +++ b/dkg/pedersen/utils_internal_test.go @@ -86,3 +86,45 @@ func TestKeyShareToBLS(t *testing.T) { require.Equal(t, pubKeyBytes, pubKey[:]) }) } + +func TestValidateThreshold(t *testing.T) { + tests := []struct { + name string + nodeCount int + threshold int + wantErr bool + errMsg string + }{ + { + name: "valid threshold at maximum (equals node count)", + nodeCount: 5, + threshold: 5, + wantErr: false, + }, + { + name: "valid threshold between minimum and maximum", + nodeCount: 7, + threshold: 6, + wantErr: false, + }, + { + name: "invalid threshold below minimum", + nodeCount: 4, + threshold: 0, // minimum is 1 + wantErr: true, + errMsg: "threshold is too low", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateThreshold(tt.nodeCount, tt.threshold) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/dkg/protocol_removeoperators.go b/dkg/protocol_removeoperators.go index da8fbdfd6..a7d7b779e 100644 --- a/dkg/protocol_removeoperators.go +++ b/dkg/protocol_removeoperators.go @@ -149,7 +149,7 @@ func (p *removeOperatorsProtocol) PostInit(ctx context.Context, pctx *ProtocolCo nodeIdx := slices.Index(newPeerIDs, pctx.ThisPeerID) pctx.ThisNodeIdx = cluster.NodeIdx{ PeerIdx: nodeIdx, - ShareIdx: nodeIdx + 1, + ShareIdx: peerMap[pctx.ThisPeerID].ShareIdx, } pctx.SigExchanger = newExchanger(pctx.ThisNode, nodeIdx, newPeerIDs, []sigType{sigLock}, pctx.Config.Timeout) } diff --git a/dkg/protocol_test.go b/dkg/protocol_test.go index 7d7d218c0..931daab8f 100644 --- a/dkg/protocol_test.go +++ b/dkg/protocol_test.go @@ -10,6 +10,7 @@ import ( "slices" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -147,6 +148,65 @@ func TestRemoveOperatorsProtocol_MoreThanF(t *testing.T) { verifyClusterValidators(t, numValidators, outputNodeDirs) } +func TestRemoveOperatorsProtocol_AllNodes(t *testing.T) { + const ( + 
numValidators = 3 + numNodes = 4 + threshold = 3 + ) + + srcClusterDir := createTestCluster(t, numNodes, threshold, numValidators) + dstClusterDir := t.TempDir() + + lockFilePath := path.Join(nodeDir(srcClusterDir, 0), clusterLockFile) + lock, err := dkg.LoadAndVerifyClusterLock(t.Context(), lockFilePath, "", false) + require.NoError(t, err) + + // Attempt to remove all nodes from the original cluster + oldENRs := []string{ + lock.Operators[0].ENR, + lock.Operators[1].ENR, + lock.Operators[2].ENR, + lock.Operators[3].ENR, + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + var errorReported atomic.Bool + + runProtocol(t, numNodes, func(relayAddr string, n int) error { + dkgConfig := createDKGConfig(t, relayAddr) + ndir := nodeDir(srcClusterDir, n) + removeConfig := dkg.RemoveOperatorsConfig{ + LockFilePath: path.Join(ndir, clusterLockFile), + PrivateKeyPath: p2p.KeyPath(ndir), + ValidatorKeysDir: path.Join(ndir, validatorKeysDir), + OutputDir: nodeDir(dstClusterDir, n), + RemovingENRs: oldENRs, + ParticipatingENRs: []string{ + lock.Operators[0].ENR, + lock.Operators[1].ENR, + lock.Operators[2].ENR, + lock.Operators[3].ENR, + }, + } + + err := dkg.RunRemoveOperatorsProtocol(ctx, removeConfig, dkgConfig) + if err != nil { + if strings.Contains(err.Error(), "remove operation would remove all nodes from original cluster") { + errorReported.Store(true) + } + + return nil + } + + return err + }) + + require.True(t, errorReported.Load(), "Expected error when attempting to remove all nodes from original cluster") +} + func TestRunAddOperatorsProtocol(t *testing.T) { const ( numValidators = 3 @@ -482,35 +542,37 @@ func verifyClusterValidators(t *testing.T, numVals int, nodeDirs []string) { //n } data := []byte("test data") - allSigs := make([][]tbls.Signature, numVals) - clusterPubKeys := make([][]tbls.PublicKey, numVals) + allSigs := make([]map[int]tbls.Signature, numVals) + validatorPubKeys := make([]tbls.PublicKey, numVals) for valIdx := range numVals { - sigs := make([]tbls.Signature, numNodes) + sigsByIdx := make(map[int]tbls.Signature) for nodeIdx := range numNodes { sig, err := tbls.Sign(clusterSecrets[nodeIdx][valIdx], data) require.NoError(t, err) - sigs[nodeIdx] = sig + // Use 1-based share indices as production does + sigsByIdx[nodeIdx+1] = sig } - allSigs[valIdx] = sigs - clusterPubKeys[valIdx] = make([]tbls.PublicKey, numNodes) + allSigs[valIdx] = sigsByIdx - for nodeIdx := range numNodes { - pubKey, err := tbls.SecretToPublicKey(clusterSecrets[nodeIdx][valIdx]) - require.NoError(t, err) + // Get the validator's full public key for verification + lockFilePath := path.Join(nodeDirs[0], clusterLockFile) + lock, err := dkg.LoadAndVerifyClusterLock(t.Context(), lockFilePath, "", false) + require.NoError(t, err) - clusterPubKeys[valIdx][nodeIdx] = pubKey - } + validatorPubKeys[valIdx] = tbls.PublicKey(lock.Validators[valIdx].PubKey) } for valIdx := range numVals { - aSig, err := tbls.Aggregate(allSigs[valIdx]) + // Use ThresholdAggregate (Lagrange interpolation) instead of simple Aggregate + // to ensure share indices are correct - this is what production uses. 
+ aSig, err := tbls.ThresholdAggregate(allSigs[valIdx]) require.NoError(t, err) - err = tbls.VerifyAggregate(clusterPubKeys[valIdx], aSig, data) + err = tbls.Verify(validatorPubKeys[valIdx], data, aSig) require.NoError(t, err) } } diff --git a/dkg/protocolsteps.go b/dkg/protocolsteps.go index c523fcd63..0fdd16dfd 100644 --- a/dkg/protocolsteps.go +++ b/dkg/protocolsteps.go @@ -33,7 +33,18 @@ type reshareProtocolStep struct { } func (s *reshareProtocolStep) Run(ctx context.Context, pctx *ProtocolContext) error { - shares, err := pedersen.RunReshareDKG(ctx, s.config, s.board, pctx.Shares) + // Extract expected validator public keys from the cluster lock + expectedValidatorPubKeys := make([]tbls.PublicKey, len(pctx.Lock.Validators)) + for i, validator := range pctx.Lock.Validators { + pubKey, err := validator.PublicKey() + if err != nil { + return errors.Wrap(err, "get validator public key", z.Int("validator_index", i)) + } + + expectedValidatorPubKeys[i] = pubKey + } + + shares, err := pedersen.RunReshareDKG(ctx, s.config, s.board, pctx.Shares, expectedValidatorPubKeys) if err != nil { return err } diff --git a/dkg/protocolsteps_internal_test.go b/dkg/protocolsteps_internal_test.go index 62eab4994..6571d0f76 100644 --- a/dkg/protocolsteps_internal_test.go +++ b/dkg/protocolsteps_internal_test.go @@ -99,6 +99,16 @@ func TestReshareProtocolStep(t *testing.T) { group, gctx := errgroup.WithContext(t.Context()) + // Create a cluster lock with the expected validator public keys + lock := &cluster.Lock{ + Validators: make([]cluster.DistValidator, numVals), + } + for i, pubKey := range oldPubKeys { + lock.Validators[i] = cluster.DistValidator{ + PubKey: pedersen.MustDecodeHex(t, pubKey), + } + } + for n := range nodes { group.Go(func() error { nodes[n].Config.Reshare = &pedersen.ReshareConfig{TotalShares: numVals, NewThreshold: threshold} @@ -109,6 +119,7 @@ func TestReshareProtocolStep(t *testing.T) { } pctx := &ProtocolContext{ Shares: oldShares[n], + Lock: lock, } return step.Run(gctx, pctx) diff --git a/dkg/sync/server.go b/dkg/sync/server.go index c250a26ba..5f1869bb5 100644 --- a/dkg/sync/server.go +++ b/dkg/sync/server.go @@ -27,7 +27,11 @@ import ( "github.com/obolnetwork/charon/p2p" ) -const protocolID = "/charon/dkg/sync/1.0.0/" +const ( + protocolID = "/charon/dkg/sync/1.0.0/" + + maxMessageSize = 32 * 1024 * 1024 // 32 MB +) // Protocols returns the list of supported Protocols in order of precedence. 
func Protocols() []protocol.ID { @@ -348,7 +352,7 @@ func writeSizedProto(writer io.Writer, msg proto.Message) error { err = binary.Write(writer, binary.LittleEndian, size) if err != nil { - return errors.Wrap(err, "read size") + return errors.Wrap(err, "write size") } n, err := writer.Write(buf) @@ -370,6 +374,10 @@ func readSizedProto(reader io.Reader, msg proto.Message) error { return errors.Wrap(err, "read size") } + if size <= 0 || size > maxMessageSize { + return errors.New("invalid message size") + } + buf := make([]byte, size) n, err := reader.Read(buf) diff --git a/dkg/sync/server_internal_test.go b/dkg/sync/server_internal_test.go index ee71b02c1..cec3d9a6a 100644 --- a/dkg/sync/server_internal_test.go +++ b/dkg/sync/server_internal_test.go @@ -3,12 +3,15 @@ package sync import ( + "bytes" + "encoding/binary" "testing" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "github.com/obolnetwork/charon/app/version" + pb "github.com/obolnetwork/charon/dkg/dkgpb/v1" "github.com/obolnetwork/charon/testutil" ) @@ -57,3 +60,45 @@ func TestUpdateStep(t *testing.T) { require.ErrorContains(t, err, "peer reported step is ahead the last known step") }) } + +func TestReadWriteSizedProto(t *testing.T) { + t.Run("valid message", func(t *testing.T) { + msg := &pb.MsgSync{ + Version: "v0.1.0", + Step: 1, + } + + var buf bytes.Buffer + + err := writeSizedProto(&buf, msg) + require.NoError(t, err) + + reader := bytes.NewReader(buf.Bytes()) + result := &pb.MsgSync{} + err = readSizedProto(reader, result) + require.NoError(t, err) + require.Equal(t, msg.GetVersion(), result.GetVersion()) + require.Equal(t, msg.GetStep(), result.GetStep()) + }) + + t.Run("size too large", func(t *testing.T) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf[:8], maxMessageSize+1) + + reader := bytes.NewReader(buf) + result := &pb.MsgSync{} + err := readSizedProto(reader, result) + require.ErrorContains(t, err, "invalid message size") + }) + + t.Run("unexpected message length", func(t *testing.T) { + // Create a buffer with size prefix indicating 100 bytes but only 10 bytes of data + buf := make([]byte, 8+10) + binary.LittleEndian.PutUint64(buf[:8], 100) + + reader := bytes.NewReader(buf) + result := &pb.MsgSync{} + err := readSizedProto(reader, result) + require.ErrorContains(t, err, "unexpected message length") + }) +}
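Note on the aggregation changes above (dkg_test.go, pedersen/testutils.go, protocol_test.go): they switch from tbls.Aggregate/tbls.VerifyAggregate to tbls.ThresholdAggregate keyed by 1-based share indices, verified against the validator's group public key. Below is a minimal sketch of that pattern, using only the tbls calls already shown in this diff; the package and helper names are illustrative and not part of the change.

// Package example is a hypothetical illustration package, not part of this change.
package example

import "github.com/obolnetwork/charon/tbls"

// verifyThresholdSig (illustrative name) shows the pattern adopted above: partial
// signatures are keyed by their 1-based share index, combined via Lagrange
// interpolation with ThresholdAggregate, and verified against the validator's
// full (group) public key rather than the per-node public shares.
func verifyThresholdSig(secretShares map[int]tbls.PrivateKey, valPubKey tbls.PublicKey, msg []byte) error {
	partials := make(map[int]tbls.Signature)

	for shareIdx, secret := range secretShares { // shareIdx is 1-based, as in production
		sig, err := tbls.Sign(secret, msg)
		if err != nil {
			return err
		}

		partials[shareIdx] = sig
	}

	// Lagrange-interpolate the partial signatures into the group signature.
	aggSig, err := tbls.ThresholdAggregate(partials)
	if err != nil {
		return err
	}

	// Verify against the validator's group public key.
	return tbls.Verify(valPubKey, msg, aggSig)
}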
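The FetchDefinition and readSizedProto changes apply the same guard: cap the number of bytes read and reject oversized payloads before unmarshalling. Below is a standalone sketch of that guard using only the standard library; the package and function names are illustrative and not part of the change.

// Package example is a hypothetical illustration package, not part of this change.
package example

import (
	"fmt"
	"io"
)

// readCapped (illustrative name) reads at most maxSize bytes from r. Reading one
// extra byte through io.LimitReader lets an oversized payload be detected and
// rejected, rather than silently truncated, before any unmarshalling happens.
func readCapped(r io.Reader, maxSize int64) ([]byte, error) {
	buf, err := io.ReadAll(io.LimitReader(r, maxSize+1))
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}

	if int64(len(buf)) > maxSize {
		return nil, fmt.Errorf("payload exceeds %d bytes", maxSize)
	}

	return buf, nil
}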