Skip to content

Commit

Permalink
Add MSM
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmillr committed Aug 30, 2024
1 parent c1eb761 commit 3b81611
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 4 deletions.
58 changes: 57 additions & 1 deletion src/abstract/curve.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Abelian group utilities
import { IField, validateField, nLength } from './modular.js';
import { validateObject } from './utils.js';
import { validateObject, bitLen } from './utils.js';
const _0n = BigInt(0);
const _1n = BigInt(1);

Expand Down Expand Up @@ -181,6 +181,62 @@ export function wNAF<T extends Group<T>>(c: GroupConstructor<T>, bits: number) {
};
}

/**
* Pippenger algorithm for multi-scalar multiplication (MSM).
* MSM is basically (Pa + Qb + Rc + ...).
* 30x faster vs naive addition on L=4096, 10x faster with precomputes.
* For N=254bit, L=1, it does: 1024 ADD + 254 DBL. For L=5: 1536 ADD + 254 DBL.
* Algorithmically constant-time (for same L), even when 1 point + scalar, or when scalar = 0.
* @param c Curve Point constructor
* @param field field over CURVE.N - important that it's not over CURVE.P
* @param points array of L curve points
* @param scalars array of L scalars (aka private keys / bigints)
*/
export function pippenger<T extends Group<T>>(
c: GroupConstructor<T>,
field: IField<bigint>,
points: T[],
scalars: bigint[]
): T {
// If we split scalars by some window (let's say 8 bits), every chunk will only
// take 256 buckets even if there are 4096 scalars, also re-uses double.
// TODO:
// - https://eprint.iacr.org/2024/750.pdf
// - https://tches.iacr.org/index.php/TCHES/article/view/10287
// 0 is accepted in scalars
if (!Array.isArray(points) || !Array.isArray(scalars) || scalars.length !== points.length)
throw new Error('arrays of scalars and points must have equal length');
scalars.forEach((s, i) => {
if (!field.isValid(s)) throw new Error(`wrong scalar at index ${i}`);
});
points.forEach((p, i) => {
if (!(p instanceof (c as any))) throw new Error(`wrong point at index ${i}`);
});
const wbits = bitLen(BigInt(points.length));
const windowSize = wbits > 12 ? wbits - 3 : wbits > 4 ? wbits - 2 : wbits ? 2 : 1; // in bits
const MASK = (1 << windowSize) - 1;
const buckets = new Array(MASK + 1).fill(c.ZERO); // +1 for zero array
const lastBits = Math.floor((field.BITS - 1) / windowSize) * windowSize;
let sum = c.ZERO;
for (let i = lastBits; i >= 0; i -= windowSize) {
buckets.fill(c.ZERO);
for (let j = 0; j < scalars.length; j++) {
const scalar = scalars[j];
const wbits = Number((scalar >> BigInt(i)) & BigInt(MASK));
buckets[wbits] = buckets[wbits].add(points[j]);
}
let resI = c.ZERO; // not using this will do small speed-up, but will lose ct
// Skip first bucket, because it is zero
for (let j = buckets.length - 1, sumI = c.ZERO; j > 0; j--) {
sumI = sumI.add(buckets[j]);
resI = resI.add(sumI);
}
sum = sum.add(resI);
if (i !== 0) for (let j = 0; j < windowSize; j++) sum = sum.double();
}
return sum as T;
}

// Generic BasicCurve interface: works even for polynomial fields (BLS): P, n, h would be ok.
// Though generator can be different (Fp2 / Fp6 for BLS).
export type BasicCurve<T> = {
Expand Down
18 changes: 16 additions & 2 deletions src/abstract/edwards.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Twisted Edwards curve. The formula is: ax² + y² = 1 + dx²y²
import { AffinePoint, BasicCurve, Group, GroupConstructor, validateBasic, wNAF } from './curve.js';
import { mod } from './modular.js';
import {
AffinePoint,
BasicCurve,
Group,
GroupConstructor,
validateBasic,
wNAF,
pippenger,
} from './curve.js';
import { mod, Field } from './modular.js';
import * as ut from './utils.js';
import { ensureBytes, FHash, Hex, memoized, abool } from './utils.js';

Expand Down Expand Up @@ -70,6 +78,7 @@ export interface ExtPointConstructor extends GroupConstructor<ExtPointType> {
fromAffine(p: AffinePoint<bigint>): ExtPointType;
fromHex(hex: Hex): ExtPointType;
fromPrivateKey(privateKey: Hex): ExtPointType;
msm(points: ExtPointType[], scalars: bigint[]): ExtPointType;
}

/**
Expand Down Expand Up @@ -218,6 +227,10 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
const toInv = Fp.invertBatch(points.map((p) => p.ez));
return points.map((p, i) => p.toAffine(toInv[i])).map(Point.fromAffine);
}
// Multiscalar Multiplication
static msm(points: Point[], scalars: bigint[]) {
return pippenger(Point, Fn, points, scalars);
}

// "Private method", don't use it directly
_setWindowSize(windowSize: number) {
Expand Down Expand Up @@ -419,6 +432,7 @@ export function twistedEdwards(curveDef: CurveType): CurveFn {
}
const { BASE: G, ZERO: I } = Point;
const wnaf = wNAF(Point, nByteLength * 8);
const Fn = Field(CURVE.n, CURVE.nBitLength);

function modN(a: bigint) {
return mod(a, CURVE_ORDER);
Expand Down
17 changes: 16 additions & 1 deletion src/abstract/weierstrass.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
/*! noble-curves - MIT License (c) 2022 Paul Miller (paulmillr.com) */
// Short Weierstrass curve. The formula is: y² = x³ + ax + b
import { AffinePoint, BasicCurve, Group, GroupConstructor, validateBasic, wNAF } from './curve.js';
import {
AffinePoint,
BasicCurve,
Group,
GroupConstructor,
validateBasic,
wNAF,
pippenger,
} from './curve.js';
import * as mod from './modular.js';
import * as ut from './utils.js';
import { CHash, Hex, PrivKey, ensureBytes, memoized, abool } from './utils.js';
Expand Down Expand Up @@ -85,6 +93,7 @@ export interface ProjConstructor<T> extends GroupConstructor<ProjPointType<T>> {
fromHex(hex: Hex): ProjPointType<T>;
fromPrivateKey(privateKey: PrivKey): ProjPointType<T>;
normalizeZ(points: ProjPointType<T>[]): ProjPointType<T>[];
msm(points: ProjPointType<T>[], scalars: bigint[]): ProjPointType<T>;
}

export type CurvePointsType<T> = BasicWCurve<T> & {
Expand Down Expand Up @@ -412,6 +421,11 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>): CurvePointsRes<T
return Point.BASE.multiply(normPrivateKeyToScalar(privateKey));
}

// Multiscalar Multiplication
static msm(points: Point[], scalars: bigint[]) {
return pippenger(Point, Fn, points, scalars);
}

// "Private method", don't use it directly
_setWindowSize(windowSize: number) {
wnaf.setWindowSize(this, windowSize);
Expand Down Expand Up @@ -665,6 +679,7 @@ export function weierstrassPoints<T>(opts: CurvePointsType<T>): CurvePointsRes<T
}
const _bits = CURVE.nBitLength;
const wnaf = wNAF(Point, CURVE.endo ? Math.ceil(_bits / 2) : _bits);
const Fn = mod.Field(CURVE.n, _bits); // TODO: export/re-use maybe?
// Validate if generator point is on curve
return {
CURVE,
Expand Down
27 changes: 27 additions & 0 deletions test/basic.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,33 @@ for (const name in CURVES) {
{ numRuns: NUM_RUNS }
)
);
should('MSM (basic)', () => {
equal(p.msm([p.BASE], [0n]), p.ZERO, '0*G');
equal(p.msm([], []), p.ZERO, 'empty');
equal(p.msm([p.ZERO], [123n]), p.ZERO, '123 * Infinity');
equal(p.msm([p.BASE], [123n]), p.BASE.multiply(123n), '123 * G');
const points = [p.BASE, p.BASE.multiply(2n), p.BASE.multiply(4n), p.BASE.multiply(8n)];
// 1*3 + 5*2 + 4*7 + 11*8 = 129
equal(p.msm(points, [3n, 5n, 7n, 11n]), p.BASE.multiply(129n), '129 * G');
});
should('MSM (rand)', () =>
fc.assert(
fc.property(fc.array(fc.tuple(FC_BIGINT, FC_BIGINT)), FC_BIGINT, (pairs) => {
let total = 0n;
const scalars = [];
const points = [];
for (const [ps, s] of pairs) {
points.push(p.BASE.multiply(ps));
scalars.push(s);
total += ps * s;
}
total = mod.mod(total, CURVE_ORDER);
const exp = total ? p.BASE.multiply(total) : p.ZERO;
equal(p.msm(points, scalars), exp, 'total');
}),
{ numRuns: NUM_RUNS }
)
);
});

for (const op of ['add', 'subtract']) {
Expand Down

0 comments on commit 3b81611

Please sign in to comment.