From ec3c98052a002109758f4607ff017ff5c3e76eef Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Fri, 26 Apr 2024 08:17:13 +0300 Subject: [PATCH] Fix coset ntt and have correctness --- backend/groth16/bn254/icicle/icicle.go | 187 +++++++------------------ 1 file changed, 47 insertions(+), 140 deletions(-) diff --git a/backend/groth16/bn254/icicle/icicle.go b/backend/groth16/bn254/icicle/icicle.go index 322bf0f7e..8f4d91444 100644 --- a/backend/groth16/bn254/icicle/icicle.go +++ b/backend/groth16/bn254/icicle/icicle.go @@ -6,10 +6,8 @@ import ( "fmt" "math/big" "math/bits" - "runtime" "time" - "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fp" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -32,7 +30,7 @@ import ( icicle_g2 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/g2" icicle_msm "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/msm" icicle_ntt "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/ntt" - // icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps" + icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps" fcs "github.com/consensys/gnark/frontend/cs" ) @@ -263,11 +261,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b } // H (witness reduction / FFT part) - // var h icicle_core.DeviceSlice - var hCPU []fr.Element + var h icicle_core.DeviceSlice chHDone := make(chan struct{}, 1) go func() { - hCPU = computeH(solution.A, solution.B, solution.C, &pk.Domain) + h = computeH(solution.A, solution.B, solution.C, pk) solution.A = nil solution.B = nil @@ -329,7 +326,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b // computes r[δ], s[δ], kr[δ] deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) - n := runtime.NumCPU() var bs1, ar curve.G1Jac @@ -368,31 +364,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b computeKRS := func() error { var krs, krs2, p1 curve.G1Jac - var krs2CPU curve.G1Jac sizeH := int(pk.Domain.Cardinality - 1) - // CPU START - - if _, err := krs2CPU.MultiExp(pk.G1.Z, hCPU[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { - panic("krs2CPU didn't complete") - } - - // CPU END - cfg := icicle_msm.GetDefaultMSMConfig() cfg.ArePointsMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true resKrs2 := make(icicle_core.HostSlice[icicle_bn254.Projective], 1) - // icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2) - icicle_msm.Msm(icicle_core.HostSliceFromElements(hCPU[:sizeH]), pk.G1Device.Z, &cfg, resKrs2) - + icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2) krs2 = g1ProjectiveToG1Jac(resKrs2[0]) - if krs2.Equal(&krs2CPU) { - fmt.Println("krs2 succeeded") - } else { - fmt.Println("krs2 failed correctness") - } // filter the wire values if needed // TODO Perf @Tabaie worst memory allocation offender toRemove := commitmentInfo.GetPrivateCommitted() @@ -498,68 +478,7 @@ func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr return } -// func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice { -// // H part of Krs -// // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) -// // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) -// // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c) -// // 3 - h = ifft_coset(ca o cb - cc) - -// n := len(a) - -// // add padding to ensure input length is domain cardinality -// padding := make([]fr.Element, int(pk.Domain.Cardinality)-n) -// a = append(a, padding...) -// b = append(b, padding...) -// c = append(c, padding...) -// n = len(a) - -// computeADone := make(chan icicle_core.DeviceSlice, 1) -// computeBDone := make(chan icicle_core.DeviceSlice, 1) -// computeCDone := make(chan icicle_core.DeviceSlice, 1) - -// computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) { -// cfg := icicle_ntt.GetDefaultNttConfig() -// scalarsStream, _ := icicle_cr.CreateStream() -// cfg.Ctx.Stream = &scalarsStream -// cfg.Ordering = icicle_core.KNR -// cfg.IsAsync = true -// scalarsHost := icicle_core.HostSliceFromElements(scalars) -// var scalarsDevice icicle_core.DeviceSlice -// scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true) -// icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice) -// cfg.Ordering = icicle_core.KRN -// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGen[:])) -// icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice) -// icicle_cr.SynchronizeStream(&scalarsStream) -// channel <-scalarsDevice -// } - -// go computeInttNttOnDevice(a, computeADone) -// go computeInttNttOnDevice(b, computeBDone) -// go computeInttNttOnDevice(c, computeCDone) - -// aDevice := <-computeADone -// bDevice := <-computeBDone -// cDevice := <-computeCDone - -// vecCfg := icicle_core.DefaultVecOpsConfig() -// icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul) -// icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub) -// icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul) - -// cfg := icicle_ntt.GetDefaultNttConfig() -// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGenInv[:])) -// cfg.Ordering = icicle_core.KNR -// icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice) - -// resHost := make(icicle_core.HostSlice[fr.Element], n) -// resHost.CopyFromDevice(&aDevice) - -// return aDevice -// } - -func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { +func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice { // H part of Krs // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) @@ -569,69 +488,57 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { n := len(a) // add padding to ensure input length is domain cardinality - padding := make([]fr.Element, int(domain.Cardinality)-n) + padding := make([]fr.Element, int(pk.Domain.Cardinality)-n) a = append(a, padding...) b = append(b, padding...) c = append(c, padding...) n = len(a) - aCopy := make([]fr.Element, n) - copy(aCopy, a) - - cfg := icicle_ntt.GetDefaultNttConfig() - cfg.Ordering = icicle_core.KNR - scalarsHost := icicle_core.HostSliceFromElements(aCopy) - scalarsHostOut := make(icicle_core.HostSlice[fr.Element], len(aCopy)) - icicle_ntt.Ntt(scalarsHost, icicle_core.KInverse, &cfg, scalarsHostOut) - - domain.FFTInverse(a, fft.DIF) - - for i, elem := range a { - if !elem.Equal(&scalarsHostOut[i]) { - fmt.Println("computeH: A failed") - } + computeADone := make(chan icicle_core.DeviceSlice, 1) + computeBDone := make(chan icicle_core.DeviceSlice, 1) + computeCDone := make(chan icicle_core.DeviceSlice, 1) + + cosetGenBits := pk.Domain.FrMultiplicativeGen.Bits() + cosetGen := icicle_core.ConvertUint64ArrToUint32Arr(cosetGenBits[:]) + var configCosetGen [8]uint32 + copy(configCosetGen[:], cosetGen[:8]) + + computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) { + cfg := icicle_ntt.GetDefaultNttConfig() + scalarsStream, _ := icicle_cr.CreateStream() + cfg.Ctx.Stream = &scalarsStream + cfg.Ordering = icicle_core.KNR + cfg.IsAsync = true + scalarsHost := icicle_core.HostSliceFromElements(scalars) + var scalarsDevice icicle_core.DeviceSlice + scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true) + icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice) + cfg.Ordering = icicle_core.KRN + cfg.CosetGen = configCosetGen + icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice) + icicle_cr.SynchronizeStream(&scalarsStream) + channel <-scalarsDevice } - domain.FFTInverse(b, fft.DIF) - domain.FFTInverse(c, fft.DIF) - + go computeInttNttOnDevice(a, computeADone) + go computeInttNttOnDevice(b, computeBDone) + go computeInttNttOnDevice(c, computeCDone) - gen, _ := fft.Generator(2 * domain.Cardinality) - // genBits := gen.Bits() - // limbs := icicle_core.ConvertUint64ArrToUint32Arr(genBits[:]) - // var rouIcicle icicle_bn254.ScalarField - // rouIcicle.FromLimbs(limbs) - cfgCustom := icicle_ntt.GetDefaultNttConfig() - cfg.CosetGen = ([8]uint32)(icicle_core.ConvertUint64ArrToUint32Arr(gen[:])) - cfgCustom.Ordering = icicle_core.KRN - icicle_ntt.Ntt(scalarsHostOut, icicle_core.KForward, &cfgCustom, scalarsHost) + aDevice := <-computeADone + bDevice := <-computeBDone + cDevice := <-computeCDone - domain.FFT(a, fft.DIT, fft.OnCoset()) - - if !scalarsHost[0].Equal(&a[0]) { - fmt.Println("computeH: A Forward failed") - } + vecCfg := icicle_core.DefaultVecOpsConfig() + icicle_bn254.FromMontgomery(&aDevice) + icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul) + icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub) + icicle_bn254.FromMontgomery(&aDevice) + icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul) - domain.FFT(b, fft.DIT, fft.OnCoset()) - domain.FFT(c, fft.DIT, fft.OnCoset()) - - var den, one fr.Element - one.SetOne() - den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) - den.Sub(&den, &one).Inverse(&den) - - // h = ifft_coset(ca o cb - cc) - // reusing a to avoid unnecessary memory allocation - utils.Parallelize(n, func(start, end int) { - for i := start; i < end; i++ { - a[i].Mul(&a[i], &b[i]). - Sub(&a[i], &c[i]). - Mul(&a[i], &den) - } - }) - - // ifft_coset - domain.FFTInverse(a, fft.DIF, fft.OnCoset()) + cfg := icicle_ntt.GetDefaultNttConfig() + cfg.CosetGen = configCosetGen + cfg.Ordering = icicle_core.KNR + icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice) - return a + return aDevice }