From 5d85cd5af60cbeb3789b281346c6a97dc09ae2d8 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Thu, 30 Apr 2026 23:44:58 +0800 Subject: [PATCH] Optimize default set operations --- sjsonnet/src/sjsonnet/stdlib/SetModule.scala | 277 +++++++++++++++---- 1 file changed, 223 insertions(+), 54 deletions(-) diff --git a/sjsonnet/src/sjsonnet/stdlib/SetModule.scala b/sjsonnet/src/sjsonnet/stdlib/SetModule.scala index 3135e059..072a8fd0 100644 --- a/sjsonnet/src/sjsonnet/stdlib/SetModule.scala +++ b/sjsonnet/src/sjsonnet/stdlib/SetModule.scala @@ -53,6 +53,165 @@ object SetModule extends AbstractFunctionModule { } } + @inline private def trimSetOutput(out: Array[Eval], len: Int): Array[Eval] = { + if (len == out.length) out + else java.util.Arrays.copyOf(out, len) + } + + @inline private def compareDefaultSetKeys(ev: EvalScope, a: Val, b: Val): Int = { + a match { + case aNum: Val.Num => + b match { + case bNum: Val.Num => java.lang.Double.compare(aNum.asDouble, bNum.asDouble) + case _ => ev.compare(a, b) + } + case aStr: Val.Str => + b match { + case bStr: Val.Str => Util.compareStringsByCodepoint(aStr.str, bStr.str) + case _ => ev.compare(a, b) + } + case _ => ev.compare(a, b) + } + } + + // WHY: default-key set operations are hot in numeric/string workloads. Keep the current + // forced key at each cursor and compare primitive-like values directly; this avoids repeated + // Eval.value calls and generic ev.compare dispatch without changing keyF semantics. + private def setUnionDefaultKeyF( + pos: Position, + ev: EvalScope, + a: Array[? <: Eval], + b: Array[? <: Eval]): Val.Arr = { + val out = new Array[Eval](a.length + b.length) + + var outLen = 0 + var idxA = 0 + var idxB = 0 + var elemA = a(idxA).value + var elemB = b(idxB).value + + while (idxA < a.length && idxB < b.length) { + val cmp = compareDefaultSetKeys(ev, elemA, elemB) + if (cmp < 0) { + out(outLen) = a(idxA) + outLen += 1 + idxA += 1 + if (idxA < a.length) elemA = a(idxA).value + } else if (cmp > 0) { + out(outLen) = b(idxB) + outLen += 1 + idxB += 1 + if (idxB < b.length) elemB = b(idxB).value + } else { + out(outLen) = a(idxA) + outLen += 1 + idxA += 1 + idxB += 1 + if (idxA < a.length && idxB < b.length) { + elemA = a(idxA).value + elemB = b(idxB).value + } + } + } + + while (idxA < a.length) { + out(outLen) = a(idxA) + outLen += 1 + idxA += 1 + } + while (idxB < b.length) { + out(outLen) = b(idxB) + outLen += 1 + idxB += 1 + } + + Val.Arr(pos, trimSetOutput(out, outLen)) + } + + private def setInterDefaultKeyF( + pos: Position, + ev: EvalScope, + a: Array[? <: Eval], + b: Array[? <: Eval]): Val.Arr = { + if (a.isEmpty || b.isEmpty) return Val.Arr(pos, Array.empty[Eval]) + + val out = new Array[Eval](math.min(a.length, b.length)) + + var outLen = 0 + var idxA = 0 + var idxB = 0 + var elemA = a(idxA).value + var elemB = b(idxB).value + + while (idxA < a.length && idxB < b.length) { + val cmp = compareDefaultSetKeys(ev, elemA, elemB) + if (cmp < 0) { + idxA += 1 + if (idxA < a.length) elemA = a(idxA).value + } else if (cmp > 0) { + idxB += 1 + if (idxB < b.length) elemB = b(idxB).value + } else { + out(outLen) = a(idxA) + outLen += 1 + idxA += 1 + idxB += 1 + if (idxA < a.length && idxB < b.length) { + elemA = a(idxA).value + elemB = b(idxB).value + } + } + } + + Val.Arr(pos, trimSetOutput(out, outLen)) + } + + private def setDiffDefaultKeyF( + pos: Position, + ev: EvalScope, + a: Array[? <: Eval], + b: Array[? <: Eval]): Val.Arr = { + if (a.isEmpty) return Val.Arr(pos, Array.empty[Eval]) + + val out = new Array[Eval](a.length) + + var outLen = 0 + var idxA = 0 + var idxB = 0 + var elemB: Val = null + + while (idxA < a.length) { + val elemA = a(idxA).value + var foundEqual = false + var continue = true + + while (idxB < b.length && continue) { + if (elemB == null) elemB = b(idxB).value + val cmp = compareDefaultSetKeys(ev, elemA, elemB) + if (cmp <= 0) { + foundEqual = cmp == 0 + if (foundEqual) { + idxB += 1 + elemB = null + } + continue = false + } else { + idxB += 1 + elemB = null + } + } + + if (!foundEqual) { + out(outLen) = a(idxA) + outLen += 1 + } + + idxA += 1 + } + + Val.Arr(pos, trimSetOutput(out, outLen)) + } + private def validateSet(ev: EvalScope, pos: Position, keyF: Val, arr: Val): Unit = { if (ev.settings.throwErrorForInvalidSets) { val sorted = uniqArr(pos.noOffset, ev, sortArr(pos.noOffset, ev, arr, keyF), keyF) @@ -325,6 +484,8 @@ object SetModule extends AbstractFunctionModule { args(1) } else if (b.isEmpty) { args(0) + } else if (isDefaultKeyF(keyF)) { + setUnionDefaultKeyF(pos, ev, a, b) } else { val out = new mutable.ArrayBuilder.ofRef[Eval] out.sizeHint(a.length + b.length) @@ -385,36 +546,40 @@ object SetModule extends AbstractFunctionModule { val a = toArrOrString(args(0), pos, ev) val b = toArrOrString(args(1), pos, ev) - val out = new mutable.ArrayBuilder.ofRef[Eval] - // Set a reasonable size hint - intersection will be at most the size of the smaller set - out.sizeHint(math.min(a.length, b.length)) + if (isDefaultKeyF(keyF)) { + setInterDefaultKeyF(pos, ev, a, b) + } else { + val out = new mutable.ArrayBuilder.ofRef[Eval] + // Set a reasonable size hint - intersection will be at most the size of the smaller set + out.sizeHint(math.min(a.length, b.length)) - var idxA = 0 - var idxB = 0 + var idxA = 0 + var idxB = 0 - while (idxA < a.length && idxB < b.length) { - val elemA = a(idxA).value - val elemB = b(idxB).value + while (idxA < a.length && idxB < b.length) { + val elemA = a(idxA).value + val elemB = b(idxB).value - val keyA = applyKeyFunc(elemA, keyF, pos, ev) - val keyB = applyKeyFunc(elemB, keyF, pos, ev) + val keyA = applyKeyFunc(elemA, keyF, pos, ev) + val keyB = applyKeyFunc(elemB, keyF, pos, ev) - val cmp = ev.compare(keyA, keyB) - if (cmp < 0) { - // keyA < keyB, elemA not in intersection - idxA += 1 - } else if (cmp > 0) { - // keyA > keyB, elemB not in intersection - idxB += 1 - } else { - // keyA == keyB, found intersection element - out.+=(a(idxA)) - idxA += 1 - idxB += 1 + val cmp = ev.compare(keyA, keyB) + if (cmp < 0) { + // keyA < keyB, elemA not in intersection + idxA += 1 + } else if (cmp > 0) { + // keyA > keyB, elemB not in intersection + idxB += 1 + } else { + // keyA == keyB, found intersection element + out.+=(a(idxA)) + idxA += 1 + idxB += 1 + } } - } - Val.Arr(pos, out.result()) + Val.Arr(pos, out.result()) + } }, /** * [[https://jsonnet.org/ref/stdlib.html#std-setDiff std.setDiff(a, b, keyF=id)]]. @@ -431,45 +596,49 @@ object SetModule extends AbstractFunctionModule { val a = toArrOrString(args(0), pos, ev) val b = toArrOrString(args(1), pos, ev) - val out = new mutable.ArrayBuilder.ofRef[Eval] - // Set a reasonable size hint - difference will be at most the size of the first set - out.sizeHint(a.length) + if (isDefaultKeyF(keyF)) { + setDiffDefaultKeyF(pos, ev, a, b) + } else { + val out = new mutable.ArrayBuilder.ofRef[Eval] + // Set a reasonable size hint - difference will be at most the size of the first set + out.sizeHint(a.length) - var idxA = 0 - var idxB = 0 + var idxA = 0 + var idxB = 0 - while (idxA < a.length) { - val elemA = a(idxA).value - val keyA = applyKeyFunc(elemA, keyF, pos, ev) + while (idxA < a.length) { + val elemA = a(idxA).value + val keyA = applyKeyFunc(elemA, keyF, pos, ev) - // Advance idxB to find first element >= keyA - var foundEqual = false - var continue = true - while (idxB < b.length && continue) { - val elemB = b(idxB).value - val keyB = applyKeyFunc(elemB, keyF, pos, ev) + // Advance idxB to find first element >= keyA + var foundEqual = false + var continue = true + while (idxB < b.length && continue) { + val elemB = b(idxB).value + val keyB = applyKeyFunc(elemB, keyF, pos, ev) + + val cmp = ev.compare(keyA, keyB) + if (cmp <= 0) { + // keyA <= keyB, found position + foundEqual = (cmp == 0) + if (foundEqual) idxB += 1 // Move past the match + continue = false + } else { + // keyA > keyB, keep advancing in b + idxB += 1 + } + } - val cmp = ev.compare(keyA, keyB) - if (cmp <= 0) { - // keyA <= keyB, found position - foundEqual = (cmp == 0) - if (foundEqual) idxB += 1 // Move past the match - continue = false - } else { - // keyA > keyB, keep advancing in b - idxB += 1 + // Add elemA if we didn't find it in b + if (!foundEqual) { + out.+=(a(idxA)) } - } - // Add elemA if we didn't find it in b - if (!foundEqual) { - out.+=(a(idxA)) + idxA += 1 } - idxA += 1 + Val.Arr(pos, out.result()) } - - Val.Arr(pos, out.result()) }, /** * [[https://jsonnet.org/ref/stdlib.html#std-setMember std.setMember(x, arr, keyF=id)]].