diff --git a/discojs/package.json b/discojs/package.json index 4e3b63a5f..2cc709fca 100644 --- a/discojs/package.json +++ b/discojs/package.json @@ -34,6 +34,7 @@ "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", "nodemon": "3", - "ts-node": "10" + "ts-node": "10", + "fast-check": "3" } } diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index d300fbb4c..4770947ef 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -1,5 +1,6 @@ import { Set } from "immutable"; import { describe, expect, it } from "vitest"; +import fc from "fast-check"; import { WeightsContainer } from "../index.js"; import { ByzantineRobustAggregator } from "./byzantine.js"; @@ -31,8 +32,8 @@ describe("ByzantineRobustAggregator", () => { expect(arr).to.deep.equal([[2], [3]]); }); - it("clips a single outlier with small radius", async () => { - const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + it("reduces influence of a single outlier", async () => { + const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 10, 0); const [c1, c2, bad] = ["c1", "c2", "bad"]; agg.setNodes(Set.of(c1, c2, bad)); @@ -43,21 +44,40 @@ describe("ByzantineRobustAggregator", () => { const out = await p; const arr = await WSIntoArrays(out); - expect(arr[0][0]).to.be.closeTo(1, 1e-6); + + const result = arr[0][0]; + const mean = (1 + 1 + 100) / 3; + + expect(Math.abs(result - 1)).to.be.lessThan(Math.abs(mean - 1)); }); - it("applies multiple clipping iterations (maxIterations > 1)", async () => { - const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1.0, 3, 0); - const [c1, bad] = ["c1", "bad"]; - agg.setNodes(Set.of(c1, bad)); + it("multiple iterations improve the estimate", async () => { + const [c1, c2, bad] = ["c1", "c2", "bad"]; - const p = agg.getPromiseForAggregation(); - agg.add(c1, WeightsContainer.of([0]), 0); - agg.add(bad, WeightsContainer.of([10]), 0); + const agg1 = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + agg1.setNodes(Set.of(c1, c2, bad)); - const out = await p; - const arr = await WSIntoArrays(out); - expect(arr[0][0]).to.be.lessThan(1); // clipped closer to 0 + const p1 = agg1.getPromiseForAggregation(); + agg1.add(c1, WeightsContainer.of([0]), 0); + agg1.add(c2, WeightsContainer.of([0]), 0); + agg1.add(bad, WeightsContainer.of([10]), 0); + const out1 = await p1; + const arr1 = await WSIntoArrays(out1); + + const agg3 = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 3, 0); + agg3.setNodes(Set.of(c1, c2, bad)); + + const p3 = agg3.getPromiseForAggregation(); + agg3.add(c1, WeightsContainer.of([0]), 0); + agg3.add(c2, WeightsContainer.of([0]), 0); + agg3.add(bad, WeightsContainer.of([10]), 0); + const out3 = await p3; + const arr3 = await WSIntoArrays(out3); + + const honest = 0; + + expect(Math.abs(arr3[0][0] - honest)) + .to.be.lessThanOrEqual(Math.abs(arr1[0][0] - honest)); }); it("uses momentum when beta > 0", async () => { @@ -65,21 +85,25 @@ describe("ByzantineRobustAggregator", () => { const [c1, c2] = ["c1", "c2"]; agg.setNodes(Set.of(c1, c2)); + // Round 1 const p1 = agg.getPromiseForAggregation(); agg.add(c1, WeightsContainer.of([2]), 0); agg.add(c2, WeightsContainer.of([2]), 0); const out1 = await p1; const arr1 = await WSIntoArrays(out1); - expect(arr1[0][0]).to.equal(2); + // m₀ = (1 - β) * g = 1 + expect(arr1[0][0]).to.be.closeTo(1, 1e-6); + + // Round 2 const p2 = agg.getPromiseForAggregation(); agg.add(c1, WeightsContainer.of([4]), 1); agg.add(c2, WeightsContainer.of([4]), 1); const out2 = await p2; const arr2 = await WSIntoArrays(out2); - // With momentum = 0.5, result = 0.5 * prev + 0.5 * current = 3.0 - expect(arr2[0][0]).to.be.closeTo(3, 1e-6); + // m₁ = 0.5*4 + 0.5*1 = 2.5 → avg = 2.5 + expect(arr2[0][0]).to.be.closeTo(2.5, 1e-6); }); it("respects roundCutoff — ignores old contributions", async () => { @@ -100,4 +124,186 @@ describe("ByzantineRobustAggregator", () => { const arr2 = await WSIntoArrays(out2); expect(arr2[0][0]).to.equal(20); }); + + it("remains robust with 30% Byzantine clients", async () => { + const honest = Array(7).fill(1); + const byzantine = Array(3).fill(100); + + const agg = new ByzantineRobustAggregator(0, 10, 'absolute', 1.0, 5, 0); + const ids = [...honest, ...byzantine].map((_, i) => `c${i}`); + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + honest.forEach((v, i) => agg.add(`c${i}`, WeightsContainer.of([v]), 0)); + byzantine.forEach((v, i) => agg.add(`c${i + honest.length}`, WeightsContainer.of([v]), 0)); + + const out = await p; + const arr = await WSIntoArrays(out); + + const honestMean = honest.reduce((a, b) => a + b, 0) / honest.length; + const rawMean = [...honest, ...byzantine].reduce((a, b) => a + b, 0) / (honest.length + byzantine.length); + + expect(Math.abs(arr[0][0] - honestMean)).to.be.lessThan(Math.abs(rawMean - honestMean)); + }); + + it("moves closer to the honest signal under constant input", async () => { + const honest = 1; + + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([10]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + const mean = (1 + 1 + 1 + 10) / 4; + + expect(Math.abs(v - honest)).to.be.lessThan(Math.abs(mean - honest)); + }); + + it("does not significantly worsen deviation compared to mean", async () => { + const clipRadius = 1.0; + + await fc.assert( + fc.asyncProperty( + fc.array( + fc.double({ + min: -1, + max: 1, + noNaN: true, + noDefaultInfinity: true + }), + { minLength: 3, maxLength: 10 } + ) + // avoid degenerate constant arrays (no signal) + .filter(arr => arr.some(v => Math.abs(v - arr[0]) > 1e-8)), + + async (honest) => { + const n = honest.length + 1; + + // clean aggregation + const aggClean = new ByzantineRobustAggregator(0, honest.length, "absolute", clipRadius, 1, 0); + const honestIds = honest.map((_, i) => `h${i}`); + aggClean.setNodes(Set(honestIds)); + + const pClean = aggClean.getPromiseForAggregation(); + honest.forEach((v, i) => aggClean.add(`h${i}`, WeightsContainer.of([v]), 0)); + const cleanOut = await pClean; + const clean = (await cleanOut.weights[0].data())[0]; + + // aggregation with Byzantine + const aggByz = new ByzantineRobustAggregator(0, n, "absolute", clipRadius, 1, 0); + const ids = honestIds.concat("byz"); + aggByz.setNodes(Set(ids)); + + const pByz = aggByz.getPromiseForAggregation(); + honest.forEach((v, i) => aggByz.add(`h${i}`, WeightsContainer.of([v]), 0)); + aggByz.add("byz", WeightsContainer.of([1e9]), 0); + + const byzOut = await pByz; + const byz = (await byzOut.weights[0].data())[0]; + + const deviation = Math.abs(byz - clean); + const mean = [...honest, 1e9].reduce((a, b) => a + b, 0) / n; + const baseline = Math.abs(mean - clean); + + // combined tolerance (absolute + relative) + const ABS_EPS = 1e-6; + const REL_EPS = 1e-6; + + expect(deviation).toBeLessThanOrEqual( + baseline * (1 + REL_EPS) + ABS_EPS + ); + } + ), + { numRuns: 500 } + ); + }); + + it("is invariant to client ordering", async () => { + const values = [0, 1, 100]; + const ids1 = ["a", "b", "c"]; + const ids2 = ["c", "a", "b"]; + + const run = async (ids: string[]) => { + const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 3, 0); + agg.setNodes(Set(ids)); + const p = agg.getPromiseForAggregation(); + ids.forEach((id, i) => + agg.add(id, WeightsContainer.of([values[i]]), 0) + ); + return (await (await p).weights[0].data())[0]; + }; + + const out1 = await run(ids1); + const out2 = await run(ids2); + + expect(out1).to.be.closeTo(out2, 1e-6); + }); + + it("is idempotent when all inputs are identical and within clipping radius", async () => { + const agg = new ByzantineRobustAggregator(0, 5, "absolute", 10.0, 5, 0); + const ids = ["a", "b", "c", "d", "e"]; + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + ids.forEach(id => agg.add(id, WeightsContainer.of([3.14]), 0)); + const out = await p; + + const v = (await out.weights[0].data())[0]; + expect(v).to.be.closeTo(3.14, 1e-6); + }); + + it("limits bias under symmetric Byzantine attacks", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["h1", "h2", "b1", "b2"])); + + const p = agg.getPromiseForAggregation(); + agg.add("h1", WeightsContainer.of([1]), 0); + agg.add("h2", WeightsContainer.of([1]), 0); + agg.add("b1", WeightsContainer.of([100]), 0); + agg.add("b2", WeightsContainer.of([-100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(Math.abs(v - 1)).to.be.lessThan(Math.abs((1 + 1 + 100 - 100)/4 - 1)); + }); + + it("reduces influence of extreme outliers", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([0.5]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + const mean = (0 + 0.5 + 1 + 100) / 4; + const honestCenter = (0 + 0.5 + 1) / 3; + + expect(Math.abs(v - honestCenter)).to.be.lessThan(Math.abs(mean - honestCenter)); + }); + + it("reset state when starting fresh aggregator", async () => { + const run = async () => { + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0.9); + agg.setNodes(Set(["a", "b"])); + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + return (await (await p).weights[0].data())[0]; + }; + + expect(await run()).to.be.closeTo(await run(), 1e-6); + }); }); diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 64d5cbe43..2d5c25e77 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -9,8 +9,33 @@ import { aggregation } from "../index.js"; * Byzantine-robust aggregator using Centered Clipping (CC), based on the * "Learning from History for Byzantine Robust Optimization" paper: https://arxiv.org/abs/2012.10333 * - * This class implements a gradient aggregation rule that clips updates - * in an iterative fashion to mitigate the influence of Byzantine nodes, as well as momentum calculations. + * This class implements Centered Clipping (Algorithm 1) with an additional + * server-side per-client momentum mechanism inspired by Algorithm 2. + * + * We initialize using the mean of contributions when no previous + * aggregate exists. This improves convergence compared to zero initialization. + * + * NOTE: + * - Momentum: + * m_i^t = (1 - β) g_i^t + β m_i^{t-1} + * - Aggregation is then performed on {m_i} + * + * WARNING: + * This implementation requires stable client identities and is not + * compatible with secure aggregation, since per-client momentum + * must be tracked on the server. + * + * Use Case: + * + * Designed for federated or distributed learning with potentially malicious + * (Byzantine) clients. Centered Clipping limits the influence of extreme or + * corrupted updates by bounding each client's contribution. + * + * CC alone can be sensitive to poor initialization (e.g., early extreme + * Byzantine updates), as clipping limits updates but does not correct a + * bad initial estimate. The added per-client momentum helps stabilize + * training over time by leveraging historical information. + * */ export class ByzantineRobustAggregator extends MultiRoundAggregator { private readonly clippingRadius: number; @@ -35,7 +60,7 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { * - Type: `number` * - Must be between 0 and 1. * - Used to compute the exponential moving average of past aggregates (i.e., momentum vector). - * The update typically looks like: `v_t = beta * v_{t-1} + (1 - beta) * g_t`, where `g_t` is the current clipped average. + * The update typically looks like: `m_i^t = (1 - β) g_i^t + β m_i^{t-1}`. * - A higher beta gives more weight to past rounds (more smoothing), while a lower beta makes the aggregator more responsive to new updates. */ @@ -60,7 +85,7 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { const prevMomentum = this.historyMomentums.get(nodeId); const newMomentum = prevMomentum ? contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) - : contribution; // no scaling on first momentum + : contribution.map(g => g.mul(1 - this.beta)); this.historyMomentums = this.historyMomentums.set(nodeId, newMomentum); this.contributions = this.contributions.setIn([0, nodeId], newMomentum); @@ -77,34 +102,55 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { return aggregation.avg(currentContributions.values()); } - // Step 1: Initialize v to average of previous aggregations + // Step 1: Initialize v using previous aggregate or mean of contributions let v: WeightsContainer; if (this.prevAggregate) { - v = this.prevAggregate; + v = this.prevAggregate.map(t => tf.clone(t)); // Clone to avoid in-place modifications } else { - // Use shape of the first contribution to create zero vector - const first = currentContributions.values().next(); - if (first.done) throw new Error("zero sized contribution") - v = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); + v = aggregation.avg(currentContributions.values()); } + + const eps = tf.scalar(1e-12); + const one = tf.scalar(1); + const radius = tf.scalar(this.clippingRadius); + // Step 2: Iterative Centered Clipping for (let l = 0; l < this.maxIterations; l++) { const clippedDiffs = Array.from(currentContributions.values()).map(m => { const diff = m.sub(v); - const norm = tf.tidy(() => euclideanNorm(diff)); - const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); + + const norm = euclideanNorm(diff); + + const safeNorm = tf.maximum(norm, eps); + + const scale = tf.minimum( + one, + tf.div(radius, safeNorm) + ); + const clipped = diff.mul(scale); - norm.dispose(); scale.dispose(); + + norm.dispose(); + safeNorm.dispose(); + scale.dispose(); + return clipped; }); const avgClip = aggregation.avg(clippedDiffs); const newV = v.add(avgClip); + clippedDiffs.forEach(d => d.dispose()); - v.dispose(); // Safe if v is no longer needed + + const oldV = v; v = newV; + oldV.dispose(); } - // Step 3: Update momentum history + + eps.dispose(); + one.dispose(); + radius.dispose(); + // Step 3: Update history this.prevAggregate = v; return v; } @@ -119,8 +165,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. return tf.tidy(() => { - const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); - const total = norms.reduce((a, b) => tf.add(a, b)); - return tf.sqrt(total); + const squaredSums = w.weights.map(t => tf.sum(tf.square(t))); + const total = tf.addN(squaredSums); + return tf.sqrt(total) as tf.Scalar; }); } \ No newline at end of file diff --git a/discojs/src/aggregator/byzantine_vs_percentile.spec.ts b/discojs/src/aggregator/byzantine_vs_percentile.spec.ts new file mode 100644 index 000000000..3c141e53d --- /dev/null +++ b/discojs/src/aggregator/byzantine_vs_percentile.spec.ts @@ -0,0 +1,362 @@ +import { Set } from "immutable"; +import { describe, expect, it } from "vitest"; + +import { WeightsContainer } from "../index.js"; +import { ByzantineRobustAggregator } from "./byzantine.js"; +import { PercentileClippingAggregator } from "./percentile_clipping.js"; + +// Helper to convert WeightsContainer → number[][] for easy assertions +async function WSIntoArrays(ws: WeightsContainer): Promise { + return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); +} + +// Timing measurement helper +interface TimingResult { + name: string; + time: number; + result: number; +} + +async function measureAggregation( + aggregator: ByzantineRobustAggregator | PercentileClippingAggregator, + name: string, + peers: { id: string; value: number }[] +): Promise { + const promise = aggregator.getPromiseForAggregation(); + const currentRound = aggregator.round; + + const startTime = performance.now(); + peers.forEach(peer => { + aggregator.add(peer.id, WeightsContainer.of([peer.value]), currentRound); + }); + + const result = await promise; + const endTime = performance.now(); + + const arr = await WSIntoArrays(result); + const aggregatedValue = arr[0][0]; + + return { + name, + time: endTime - startTime, + result: aggregatedValue, + }; +} + +function formatTiming(timings: TimingResult[]): string { + const maxNameLen = Math.max(...timings.map(t => t.name.length)); + return timings + .map(t => ` ${t.name.padEnd(maxNameLen)} | ${t.time.toFixed(3)}ms | result: ${t.result.toFixed(4)}`) + .join('\n'); +} + +describe("Comparison: Centered Clipping vs Percentile Clipping", () => { +/** + * ============================================================ + * Comparison: Centered Clipping (CC) vs Percentile Clipping + * ============================================================ + * + * These tests highlight the fundamental differences between two + * aggregation strategies used in adversarial / federated settings. + * + * Centered Clipping (CC): + * - Iterative, principled aggregation rule with bounded updates + * - Provides theoretical robustness against Byzantine clients + * - Gradually refines the estimate over multiple iterations + * - More stable across rounds and symmetric/adversarial scenarios + * - Computationally more expensive (multiple passes over data) + * - Converges slowly if initialized far from the true signal + * + * Percentile Clipping: + * - Single-pass, heuristic aggregation based on norm thresholds + * - Fast and simple, with low computational overhead + * - Works well when outliers are clearly separable + * - Highly sensitive to data distribution and chosen percentile (tau) + * - Can behave like simple averaging in moderate/noisy settings + * - Can fail when Byzantine clients dominate or blend with honest ones + * + * When to use which: + * + * - Use Centered Clipping when: + * - Robustness is critical (adversarial or unreliable clients) + * - You can afford additional computation + * - You expect persistent or structured Byzantine behavior + * + * - Use Percentile Clipping when: + * - You need fast, lightweight aggregation + * - Data is mostly clean with occasional outliers + * - Strong robustness guarantees are not required + * + * Summary: + * CC - slower but principled and robust + * Percentile - faster but heuristic and less reliable + * + * The tests below illustrate these trade-offs across different + * regimes (extreme outliers, moderate attacks, multi-round behavior, etc.). + */ + it("CC improves with more iterations", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 1000 }, + ]; + + const ids = peers.map(p => p.id); + + const cc1 = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 1, 0); + const cc50 = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 50, 0); + + cc1.setNodes(Set(ids)); + cc50.setNodes(Set(ids)); + + const r1 = await measureAggregation(cc1, "cc1", peers); + const r50 = await measureAggregation(cc50, "cc50", peers); + + const honest = 1; + + expect(Math.abs(r50.result - honest)) + .to.be.lessThan(Math.abs(r1.result - honest)); + }); + + it("percentile behaves like mean under moderate Byzantine values", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 3 }, + { id: "b2", value: 3 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 5, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + const mean = (1 + 1 + 1 + 3 + 3) / 5; + + // percentile behaves close to mean + expect(Math.abs(resPC.result - mean)).to.be.lessThan(0.5); + + // both are biased away from honest + expect(Math.abs(resPC.result - honest)).to.be.greaterThan(0.3); + }); + + it("iterations improve CC but not percentile", async () => { + const peers = [ + { id: "h1", value: 0 }, + { id: "h2", value: 0 }, + { id: "b1", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc1 = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 1, 0); + const cc5 = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 5, 0); + + cc1.setNodes(Set(ids)); + cc5.setNodes(Set(ids)); + + const r1 = await measureAggregation(cc1, "cc1", peers); + const r5 = await measureAggregation(cc5, "cc5", peers); + + expect(Math.abs(r5.result)) + .to.be.lessThanOrEqual(Math.abs(r1.result)); + }); + + it("percentile sensitivity to tau parameter", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 100 }, + ]; + + const ids = peers.map(p => p.id); + + const lowTau = new PercentileClippingAggregator(0, 3, "absolute", 0.1); + const highTau = new PercentileClippingAggregator(0, 3, "absolute", 0.9); + + lowTau.setNodes(Set(ids)); + highTau.setNodes(Set(ids)); + + const rLow = await measureAggregation(lowTau, "low", peers); + const rHigh = await measureAggregation(highTau, "high", peers); + + expect(Math.abs(rLow.result - 1)) + .to.be.lessThan(Math.abs(rHigh.result - 1)); + }); + + it("CC is at least as stable across rounds as percentile", async () => { + const ids = ["h1", "h2", "b1"]; + + const cc = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + // Round 1 + const prev = 5; + + await measureAggregation(cc, "cc1", [ + { id: "h1", value: prev }, + { id: "h2", value: prev }, + { id: "b1", value: prev }, + ]); + + await measureAggregation(pc, "pc1", [ + { id: "h1", value: prev }, + { id: "h2", value: prev }, + { id: "b1", value: prev }, + ]); + + // Round 2 (attack) + const rCC = await measureAggregation(cc, "cc2", [ + { id: "h1", value: 10 }, + { id: "h2", value: 10 }, + { id: "b1", value: 100 }, + ]); + + const rPC = await measureAggregation(pc, "pc2", [ + { id: "h1", value: 10 }, + { id: "h2", value: 10 }, + { id: "b1", value: 100 }, + ]); + + const deltaCC = Math.abs(rCC.result - prev); + const deltaPC = Math.abs(rPC.result - prev); + + expect(deltaCC).to.be.at.most(deltaPC + 1e-6); + }); + + it("percentile breaks when Byzantine dominate percentile", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 10 }, + { id: "b2", value: 10 }, + { id: "b3", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 10, 0); + const pc = new PercentileClippingAggregator(0, 5, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + + // Percentile clearly drifts + expect(resPC.result).to.be.greaterThan(3); + + // Both are worse than honest + expect(Math.abs(resCC.result - honest)).to.be.greaterThan(1); + expect(Math.abs(resPC.result - honest)).to.be.greaterThan(1); + }); + + it("both aggregators behave similarly without Byzantine clients", async () => { + const peers = [ + { id: "a", value: 1 }, + { id: "b", value: 2 }, + { id: "c", value: 3 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 3, "absolute", 10, 1, 0); + const pc = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + expect(resCC.result).to.be.closeTo(resPC.result, 1e-6); + }); + + it("prints timing comparison (CC vs Percentile)", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 1000 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 20, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "CC", peers); + const resPC = await measureAggregation(pc, "Percentile", peers); + + console.log("\nTiming comparison:\n" + formatTiming([resCC, resPC])); + + expect(resPC.time).to.be.lessThan(resCC.time); + }); + + it("CC handles symmetric attacks better than percentile", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 100 }, + { id: "b2", value: -100 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + + expect(Math.abs(resCC.result - honest)) + .to.be.lessThan(Math.abs(resPC.result - honest)); + }); + + it("percentile is sensitive to honest variance", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 2 }, + { id: "h3", value: 3 }, + { id: "b1", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honestMean = (1 + 2 + 3) / 3; + + expect(Math.abs(resCC.result - honestMean)) + .to.be.lessThan(Math.abs(resPC.result - honestMean)); + }); +}); \ No newline at end of file diff --git a/discojs/src/aggregator/percentile_clipping.spec.ts b/discojs/src/aggregator/percentile_clipping.spec.ts new file mode 100644 index 000000000..1ba39f8c3 --- /dev/null +++ b/discojs/src/aggregator/percentile_clipping.spec.ts @@ -0,0 +1,169 @@ +import { Set } from "immutable"; +import { describe, expect, it } from "vitest"; + +import { WeightsContainer } from "../index.js"; +import { PercentileClippingAggregator } from "./percentile_clipping.js"; + +async function WSIntoArrays(ws: WeightsContainer): Promise { + return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); +} + +describe("PercentileClippingAggregator", () => { + + it("throws on invalid constructor parameters", () => { + expect(() => new PercentileClippingAggregator(0, 1, "absolute", 0)).to.throw(); + expect(() => new PercentileClippingAggregator(0, 1, "absolute", 1)).to.throw(); + expect(() => new PercentileClippingAggregator(0, 1, "absolute", -0.1)).to.throw(); + }); + + it("behaves like mean when no clipping occurs", async () => { + const agg = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([2]), 0); + agg.add("c", WeightsContainer.of([3]), 0); + + const out = await p; + const arr = await WSIntoArrays(out); + + expect(arr[0][0]).to.be.closeTo(2, 1e-6); + }); + + it("reduces influence of a large outlier (heuristically)", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + const mean = (1 + 1 + 1 + 100) / 4; + + expect(Math.abs(v - 1)).to.be.lessThan(Math.abs(mean - 1)); + }); + + it("is idempotent when all inputs are identical", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + ["a", "b", "c", "d"].forEach(id => agg.add(id, WeightsContainer.of([5]), 0)); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(v).to.be.closeTo(5, 1e-6); + }); + + it("is invariant to client ordering", async () => { + const values = [1, 2, 100]; + const ids1 = ["a", "b", "c"]; + const ids2 = ["c", "a", "b"]; + + const run = async (ids: string[]) => { + const agg = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + agg.setNodes(Set(ids)); + const p = agg.getPromiseForAggregation(); + ids.forEach((id, i) => + agg.add(id, WeightsContainer.of([values[i]]), 0) + ); + return (await (await p).weights[0].data())[0]; + }; + + const out1 = await run(ids1); + const out2 = await run(ids2); + + expect(out1).to.be.closeTo(out2, 1e-6); + }); + + it("lower percentile increases clipping strength", async () => { + const nodes = ["a", "b", "c", "d"]; + const inputs = [1, 1, 1, 100]; + + const aggLow = new PercentileClippingAggregator(0, 4, "absolute", 0.1); + const aggHigh = new PercentileClippingAggregator(0, 4, "absolute", 0.9); + + aggLow.setNodes(Set(nodes)); + aggHigh.setNodes(Set(nodes)); + + const pLow = aggLow.getPromiseForAggregation(); + const pHigh = aggHigh.getPromiseForAggregation(); + + nodes.forEach((n, i) => { + aggLow.add(n, WeightsContainer.of([inputs[i]]), 0); + aggHigh.add(n, WeightsContainer.of([inputs[i]]), 0); + }); + + const vLow = (await (await pLow).weights[0].data())[0]; + const vHigh = (await (await pHigh).weights[0].data())[0]; + + expect(Math.abs(vLow - 1)).to.be.lessThan(Math.abs(vHigh - 1)); + }); + + it("handles zero-norm inputs without NaN", async () => { + const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.5); + agg.setNodes(Set(["a", "b"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([0]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(Number.isFinite(v)).to.be.true; + }); + + it("respects roundCutoff", async () => { + const agg = new PercentileClippingAggregator(1, 1, "absolute", 0.5); + agg.setNodes(Set(["a"])); + + const p0 = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([10]), 0); + const v0 = (await (await p0).weights[0].data())[0]; + expect(v0).to.equal(10); + + const p2 = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([20]), 2); + const v2 = (await (await p2).weights[0].data())[0]; + expect(v2).to.equal(20); + }); + + it("can fail under strong Byzantine attack (documented limitation)", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([50]), 0); + agg.add("d", WeightsContainer.of([100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + // We don't assert correctness — only that it doesn't explode + expect(Number.isFinite(v)).to.be.true; + }); + + it("reset state when starting fresh aggregator", async () => { + const run = async () => { + const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.5); + agg.setNodes(Set(["a", "b"])); + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + return (await (await p).weights[0].data())[0]; + }; + + expect(await run()).to.be.closeTo(await run(), 1e-6); + }); + +}); \ No newline at end of file diff --git a/discojs/src/aggregator/percentile_clipping.ts b/discojs/src/aggregator/percentile_clipping.ts new file mode 100644 index 000000000..5f2c6a575 --- /dev/null +++ b/discojs/src/aggregator/percentile_clipping.ts @@ -0,0 +1,140 @@ +import { Map } from "immutable"; +import * as tf from '@tensorflow/tfjs'; +import { AggregationStep } from "./aggregator.js"; +import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; +import { WeightsContainer, client } from "../index.js"; +import { aggregation } from "../index.js"; + +/** + * Percentile-based clipping aggregator. + * + * This method clips updates using a threshold τ computed as a percentile + * of update norms. Unlike Centered Clipping, this is a single-pass heuristic + * and does not provide formal Byzantine robustness guarantees. + * + * Use Case: + * Suitable for mitigating mild outliers or noisy updates when most clients + * are honest. Not suitable for adversarial Byzantine settings, as the + * percentile threshold can be influenced by malicious clients. + * + * Algorithm: + * 1. Center all peer weights w.r.t. the previous aggregation + * 2. Compute Frobenius norm for each centered weight + * 3. Compute tau as the percentile of the norm array + * 4. Clip each centered weight: clip = centeredWeight * min(1, tau / norm) + * 5. Average all clipped weights + */ + +export class PercentileClippingAggregator extends MultiRoundAggregator { + private readonly tauPercentile: number; + private prevAggregate: WeightsContainer | null = null; + + /** + * @property tauPercentile The percentile (0 < tau < 1) used to compute the clipping threshold. + * - Type: `number` + * - Determines which percentile of the Frobenius norms to use as the clipping threshold. + * - For example, 0.1 clips at the 10th percentile of norms. + * - Smaller values are more aggressive (clip more updates). + * - Default value is 0.1. + */ + + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, tauPercentile = 0.1) { + super(roundCutoff, threshold, thresholdType); + if (tauPercentile <= 0 || tauPercentile >= 1) { + throw new Error("Tau percentile must be between 0 and 1 (exclusive)."); + } + this.tauPercentile = tauPercentile; + } + + override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { + this.log( + this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, + nodeId, + ); + // Store contribution as is, without client-side momentum + this.contributions = this.contributions.setIn([0, nodeId], contribution); + } + + override aggregate(): WeightsContainer { + const currentContributions = this.contributions.get(0); + if (!currentContributions || currentContributions.size === 0) throw new Error("aggregating without any contribution"); + + this.log(AggregationStep.AGGREGATE); + + // Step 1: Get the centering reference (previous aggregation or initial avg vector) + let centerReference: WeightsContainer; + if (this.prevAggregate) { + centerReference = this.prevAggregate.map(t => tf.clone(t)); + } else { + centerReference = aggregation.avg(currentContributions.values()).map(t => tf.clone(t)); + } + + // Step 2: Center the weights with respect to the reference + const centeredWeights = Array.from(currentContributions.values()).map(w => + w.sub(centerReference) + ); + + // Step 3: Compute Frobenius norms for each centered weight + const normArray = centeredWeights.map(w => frobeniusNorm(w)); + + // Step 4: Compute tau as the percentile of the norm array + const tau = this.computePercentile(normArray, this.tauPercentile); + + // Step 5: Clip weights based on tau + // Each peer gets one scale factor based on their Frobenius norm + const clippedWeights = centeredWeights.map((w, peerIdx) => { + //const scaleFactor = Math.min(1, tau / normArray[peerIdx]); + const norm = normArray[peerIdx]; + const safeNorm = Math.max(norm, 1e-12); + + const scaleFactor = Math.min(1, tau / safeNorm); + return w.map((t: tf.Tensor) => t.mul(scaleFactor)); + }); + + centeredWeights.forEach(w => w.dispose()); + + // Step 6: Average the clipped weights and add back the reference + const clippedAvg = aggregation.avg(clippedWeights); + const result = centerReference.add(clippedAvg); + + centerReference.dispose(); + clippedWeights.forEach(w => w.dispose()); + clippedAvg.dispose(); + + // Step 7: Store result for next round + this.prevAggregate = result; + return result; + } + + private computePercentile(array: number[], percentile: number): number { + // Linear interpolation for percentile calculation + const clean = array.filter(Number.isFinite); + if (clean.length === 0) return 0; + + const sorted = [...clean].sort((a, b) => a - b); + const pos = (sorted.length - 1) * percentile; + const base = Math.floor(pos); + const rest = pos - base; + + if (sorted[base + 1] !== undefined) { + return sorted[base] + rest * (sorted[base + 1] - sorted[base]); + } else { + return sorted[base]; + } + } + + override makePayloads(weights: WeightsContainer): Map { + return this.nodes.toMap().map(() => weights); + } +} + +function frobeniusNorm(w: WeightsContainer): number { + // Computes the Frobenius (L2) norm of all tensors in a WeightsContainer + return tf.tidy(() => { + const total = w.weights + .map(t => tf.sum(tf.square(t))) + .reduce((a, b) => tf.add(a, b), tf.scalar(0)); + + return tf.sqrt(total).dataSync()[0]; + }); +} diff --git a/package-lock.json b/package-lock.json index 61a448f66..16939dca0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -65,6 +65,7 @@ "devDependencies": { "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", + "fast-check": "3", "nodemon": "3", "ts-node": "10" } @@ -7412,6 +7413,29 @@ ], "license": "MIT" }, + "node_modules/fast-check": { + "version": "3.23.2", + "resolved": "https://registry.npmjs.org/fast-check/-/fast-check-3.23.2.tgz", + "integrity": "sha512-h5+1OzzfCC3Ef7VbtKdcv7zsstUQwUDlYpUTvjeUsJAssPgLn7QzbboPtL5ro04Mq0rPOsMzl7q5hIbRs2wD1A==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "dependencies": { + "pure-rand": "^6.1.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -9079,6 +9103,7 @@ "os": [ "android" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9099,6 +9124,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9119,6 +9145,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9139,6 +9166,7 @@ "os": [ "freebsd" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9159,6 +9187,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9179,6 +9208,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9199,6 +9229,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9219,6 +9250,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9239,6 +9271,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9259,6 +9292,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9279,6 +9313,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -10895,6 +10930,23 @@ "node": ">=6" } }, + "node_modules/pure-rand": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/pure-rand/-/pure-rand-6.1.0.tgz", + "integrity": "sha512-bVWawvoZoBYpp6yIoQtQXHZjmz35RSVHnUOTefl8Vcjr8snTPY1wnpSPMWekcFwbxI6gtmT7rSYPFvz71ldiOA==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT" + }, "node_modules/qs": { "version": "6.14.1", "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz",