diff --git a/packages/wasm-mps/src/lib.rs b/packages/wasm-mps/src/lib.rs index d822b897f14..b8a65a40b8b 100644 --- a/packages/wasm-mps/src/lib.rs +++ b/packages/wasm-mps/src/lib.rs @@ -112,6 +112,7 @@ mod mps { pub struct Share { pub share: Vec, pub pk: [u8; 32], + pub chaincode: [u8; 32], } fn internal_dkg_round0_process( @@ -279,6 +280,7 @@ mod mps { Ok(Share { share: bincode::serialize(&share).map_err(|_| MpsError::SerializationError)?, pk: share.public_key.compress().to_bytes(), + chaincode: share.root_chain_code, }) } @@ -525,14 +527,22 @@ mod tests { ) .unwrap(); - // Assert generated public keys are equal + // Assert generated public keychains are equal assert_eq!( p2_share.pk, p0_share.pk, - "Party 0 share differs from party 2 share" + "Party 0 public key differs from party 2 public key" ); assert_eq!( p2_share.pk, p1_share.pk, - "Party 1 share differs from party 2 share" + "Party 1 public key differs from party 2 public key" + ); + assert_eq!( + p2_share.chaincode, p0_share.chaincode, + "Party 0 chaincode differs from party 2 chaincode" + ); + assert_eq!( + p2_share.chaincode, p1_share.chaincode, + "Party 1 chaincode differs from party 2 chaincode" ); } @@ -697,6 +707,7 @@ impl MsgState { pub struct Share { share: Vec, pk: Vec, + chaincode: Vec, } #[wasm_bindgen] @@ -710,6 +721,11 @@ impl Share { pub fn pk(&self) -> Vec { self.pk.clone() } + + #[wasm_bindgen(getter)] + pub fn chaincode(&self) -> Vec { + self.chaincode.clone() + } } #[wasm_bindgen] @@ -730,6 +746,7 @@ impl MsgShare { Share { share: self.share.share.clone(), pk: self.share.pk.clone(), + chaincode: self.share.chaincode.clone(), } } } @@ -796,6 +813,7 @@ pub fn ed25519_dkg_round2_process(round2_messages: Array, state: &[u8]) -> Resul Ok(Share { share: result.share, pk: result.pk.to_vec(), + chaincode: result.chaincode.to_vec(), }) } diff --git a/packages/wasm-mps/test/mps.ts b/packages/wasm-mps/test/mps.ts index bf46129523f..cd3730ee45d 100644 --- a/packages/wasm-mps/test/mps.ts +++ b/packages/wasm-mps/test/mps.ts @@ -73,6 +73,9 @@ describe("mps", function () { ); for (let i = 0; i < 2; i++) { assert.ok(results3[i].pk.every((value, index) => value === results3[2].pk[index])); + assert.ok( + results3[i].chaincode.every((value, index) => value === results3[2].chaincode[index]), + ); } });