From 5b8506fb845e9da27beb5d3a7c38ea7b081e3679 Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Sun, 28 Apr 2024 07:35:59 +0300 Subject: [PATCH] Use double size domain instead of FrMultiplicativeGen, Add timings --- backend/groth16/bn254/icicle/icicle.go | 46 ++++++++++++------- backend/groth16/bn254/icicle/provingkey.go | 4 +- backend/groth16/bn254/prove.go | 52 +++++++++++++++++----- 3 files changed, 72 insertions(+), 30 deletions(-) diff --git a/backend/groth16/bn254/icicle/icicle.go b/backend/groth16/bn254/icicle/icicle.go index 8f4d914447..02e887c9ee 100644 --- a/backend/groth16/bn254/icicle/icicle.go +++ b/backend/groth16/bn254/icicle/icicle.go @@ -23,6 +23,7 @@ import ( "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" + "github.com/rs/zerolog" icicle_core "github.com/ingonyama-zk/icicle/v2/wrappers/golang/core" icicle_cr "github.com/ingonyama-zk/icicle/v2/wrappers/golang/cuda_runtime" @@ -42,11 +43,12 @@ func (pk *ProvingKey) setupDevicePointers() error { return nil } pk.deviceInfo = &deviceInfo{} + gen, _ := fft.Generator(2 * pk.Domain.Cardinality) /************************* Den ***************************/ n := int(pk.Domain.Cardinality) var denI, oneI fr.Element oneI.SetOne() - denI.Exp(pk.Domain.FrMultiplicativeGen, big.NewInt(int64(pk.Domain.Cardinality))) + denI.Exp(gen, big.NewInt(int64(pk.Domain.Cardinality))) denI.Sub(&denI, &oneI).Inverse(&denI) log2SizeFloor := bits.Len(uint(n)) - 1 @@ -63,6 +65,7 @@ func (pk *ProvingKey) setupDevicePointers() error { go func() { denIcicleArrHost := (icicle_core.HostSlice[fr.Element])(denIcicleArr) denIcicleArrHost.CopyToDevice(&pk.DenDevice, true) + icicle_bn254.FromMontgomery(&pk.DenDevice) copyDenDone <- true }() @@ -72,10 +75,9 @@ func (pk *ProvingKey) setupDevicePointers() error { panic("Couldn't create device context") // TODO } - gen, _ := fft.Generator(2 * pk.Domain.Cardinality) genBits := gen.Bits() limbs := icicle_core.ConvertUint64ArrToUint32Arr(genBits[:]) - pk.CosetGenerator = limbs + copy(pk.CosetGenerator[:], limbs[:fr.Limbs*2]) var rouIcicle icicle_bn254.ScalarField rouIcicle.FromLimbs(limbs) e := icicle_ntt.InitDomain(rouIcicle, ctx, false) @@ -264,7 +266,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var h icicle_core.DeviceSlice chHDone := make(chan struct{}, 1) go func() { - h = computeH(solution.A, solution.B, solution.C, pk) + h = computeH(solution.A, solution.B, solution.C, pk, log) solution.A = nil solution.B = nil @@ -336,7 +338,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b cfg.ArePointsMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true res := make(icicle_core.HostSlice[icicle_bn254.Projective], 1) + start := time.Now() icicle_msm.Msm(wireValuesBDevice, pk.G1Device.B, &cfg, res) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Bs1") bs1 = g1ProjectiveToG1Jac(res[0]) bs1.AddMixed(&pk.G1.Beta) @@ -352,7 +356,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b cfg.ArePointsMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true res := make(icicle_core.HostSlice[icicle_bn254.Projective], 1) + start := time.Now() icicle_msm.Msm(wireValuesADevice, pk.G1Device.A, &cfg, res) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Ar1") ar = g1ProjectiveToG1Jac(res[0]) ar.AddMixed(&pk.G1.Alpha) @@ -370,7 +376,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b cfg.ArePointsMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true resKrs2 := make(icicle_core.HostSlice[icicle_bn254.Projective], 1) + start := time.Now() icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Krs2") krs2 = g1ProjectiveToG1Jac(resKrs2[0]) // filter the wire values if needed @@ -380,7 +388,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) _wireValuesHost := (icicle_core.HostSlice[fr.Element])(_wireValues) resKrs := make(icicle_core.HostSlice[icicle_bn254.Projective], 1) + start = time.Now() icicle_msm.Msm(_wireValuesHost, pk.G1Device.K, &cfg, resKrs) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Krs") krs = g1ProjectiveToG1Jac(resKrs[0]) krs.AddMixed(&deltas[2]) @@ -408,7 +418,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b cfg.ArePointsMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true res := make(icicle_core.HostSlice[icicle_g2.G2Projective], 1) + start := time.Now() icicle_g2.G2Msm(wireValuesBDevice, pk.G2Device.B, &cfg, res) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Bs2 G2") Bs = g2ProjectiveToG2Jac(&res[0]) deltaS.FromAffine(&pk.G2.Delta) @@ -478,7 +490,7 @@ func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr return } -func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice { +func computeH(a, b, c []fr.Element, pk *ProvingKey, log zerolog.Logger) icicle_core.DeviceSlice { // H part of Krs // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) @@ -498,25 +510,22 @@ func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice { computeBDone := make(chan icicle_core.DeviceSlice, 1) computeCDone := make(chan icicle_core.DeviceSlice, 1) - cosetGenBits := pk.Domain.FrMultiplicativeGen.Bits() - cosetGen := icicle_core.ConvertUint64ArrToUint32Arr(cosetGenBits[:]) - var configCosetGen [8]uint32 - copy(configCosetGen[:], cosetGen[:8]) - computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) { cfg := icicle_ntt.GetDefaultNttConfig() scalarsStream, _ := icicle_cr.CreateStream() cfg.Ctx.Stream = &scalarsStream - cfg.Ordering = icicle_core.KNR + cfg.Ordering = icicle_core.KNM cfg.IsAsync = true scalarsHost := icicle_core.HostSliceFromElements(scalars) var scalarsDevice icicle_core.DeviceSlice scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true) + start := time.Now() icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice) - cfg.Ordering = icicle_core.KRN - cfg.CosetGen = configCosetGen + cfg.Ordering = icicle_core.KMN + cfg.CosetGen = pk.CosetGenerator icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice) icicle_cr.SynchronizeStream(&scalarsStream) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: NTT + INTT") channel <-scalarsDevice } @@ -529,16 +538,21 @@ func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice { cDevice := <-computeCDone vecCfg := icicle_core.DefaultVecOpsConfig() + start := time.Now() icicle_bn254.FromMontgomery(&aDevice) icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul) icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub) - icicle_bn254.FromMontgomery(&aDevice) icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: vecOps") + defer bDevice.Free() + defer cDevice.Free() cfg := icicle_ntt.GetDefaultNttConfig() - cfg.CosetGen = configCosetGen - cfg.Ordering = icicle_core.KNR + cfg.CosetGen = pk.CosetGenerator + cfg.Ordering = icicle_core.KNM + start = time.Now() icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: INTT final") return aDevice } diff --git a/backend/groth16/bn254/icicle/provingkey.go b/backend/groth16/bn254/icicle/provingkey.go index e0fa5a210a..501c2d22da 100644 --- a/backend/groth16/bn254/icicle/provingkey.go +++ b/backend/groth16/bn254/icicle/provingkey.go @@ -1,14 +1,14 @@ package icicle import ( - + "github.com/consensys/gnark-crypto/ecc/bn254/fr" groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" cs "github.com/consensys/gnark/constraint/bn254" icicle_core "github.com/ingonyama-zk/icicle/v2/wrappers/golang/core" ) type deviceInfo struct { - CosetGenerator []uint32 + CosetGenerator [fr.Limbs*2]uint32 G1Device struct { A, B, K, Z icicle_core.DeviceSlice } diff --git a/backend/groth16/bn254/prove.go b/backend/groth16/bn254/prove.go index 100f30e85a..8674a50651 100644 --- a/backend/groth16/bn254/prove.go +++ b/backend/groth16/bn254/prove.go @@ -18,6 +18,10 @@ package groth16 import ( "fmt" + "math/big" + "runtime" + "time" + "github.com/consensys/gnark-crypto/ecc" curve "github.com/consensys/gnark-crypto/ecc/bn254" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -32,9 +36,7 @@ import ( "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/logger" - "math/big" - "runtime" - "time" + "github.com/rs/zerolog" fcs "github.com/consensys/gnark/frontend/cs" ) @@ -132,7 +134,7 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b var h []fr.Element chHDone := make(chan struct{}, 1) go func() { - h = computeH(solution.A, solution.B, solution.C, &pk.Domain) + h = computeH(solution.A, solution.B, solution.C, &pk.Domain, log) solution.A = nil solution.B = nil solution.C = nil @@ -191,7 +193,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chBs1Done := make(chan error, 1) computeBS1 := func() { <-chWireValuesB - if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + start := time.Now() + _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Bs1") + if err != nil { chBs1Done <- err close(chBs1Done) return @@ -204,7 +209,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chArDone := make(chan error, 1) computeAR1 := func() { <-chWireValuesA - if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + start := time.Now() + _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Ar1") + if err != nil { chArDone <- err close(chArDone) return @@ -224,7 +232,9 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b chKrs2Done := make(chan error, 1) sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 go func() { + start := time.Now() _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Krs2") chKrs2Done <- err }() @@ -234,7 +244,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) - if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { + start := time.Now() + _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Krs") + if err != nil { chKrsDone <- err return } @@ -280,7 +293,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b nbTasks *= 2 } <-chWireValuesB - if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { + start := time.Now() + _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}) + log.Debug().Dur("took", time.Since(start)).Msg("MSM Bs2 G2") + if err != nil { return err } @@ -343,7 +359,7 @@ func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr return } -func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { +func computeH(a, b, c []fr.Element, domain *fft.Domain, log zerolog.Logger) []fr.Element { // H part of Krs // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) @@ -359,13 +375,21 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { c = append(c, padding...) n = len(a) + start := time.Now() domain.FFTInverse(a, fft.DIF) - domain.FFTInverse(b, fft.DIF) - domain.FFTInverse(c, fft.DIF) - domain.FFT(a, fft.DIT, fft.OnCoset()) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: NTT + INTT") + + start = time.Now() + domain.FFTInverse(b, fft.DIF) domain.FFT(b, fft.DIT, fft.OnCoset()) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: NTT + INTT") + + start = time.Now() + domain.FFTInverse(c, fft.DIF) domain.FFT(c, fft.DIT, fft.OnCoset()) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: NTT + INTT") + var den, one fr.Element one.SetOne() @@ -374,6 +398,7 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { // h = ifft_coset(ca o cb - cc) // reusing a to avoid unnecessary memory allocation + start = time.Now() utils.Parallelize(n, func(start, end int) { for i := start; i < end; i++ { a[i].Mul(&a[i], &b[i]). @@ -381,9 +406,12 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { Mul(&a[i], &den) } }) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: vecOps") // ifft_coset + start = time.Now() domain.FFTInverse(a, fft.DIF, fft.OnCoset()) + log.Debug().Dur("took", time.Since(start)).Msg("computeH: INTT final") return a }