Skip to content

Commit

Permalink
Fix coset ntt and have correctness
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfelder committed Apr 26, 2024
1 parent d9713f8 commit ec3c980
Showing 1 changed file with 47 additions and 140 deletions.
187 changes: 47 additions & 140 deletions backend/groth16/bn254/icicle/icicle.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ import (
"fmt"
"math/big"
"math/bits"
"runtime"
"time"

"github.com/consensys/gnark-crypto/ecc"
curve "github.com/consensys/gnark-crypto/ecc/bn254"
"github.com/consensys/gnark-crypto/ecc/bn254/fp"
"github.com/consensys/gnark-crypto/ecc/bn254/fr"
Expand All @@ -32,7 +30,7 @@ import (
icicle_g2 "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/g2"
icicle_msm "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/msm"
icicle_ntt "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/ntt"
// icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps"
icicle_vecops "github.com/ingonyama-zk/icicle/v2/wrappers/golang/curves/bn254/vecOps"

fcs "github.com/consensys/gnark/frontend/cs"
)
Expand Down Expand Up @@ -263,11 +261,10 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b
}

// H (witness reduction / FFT part)
// var h icicle_core.DeviceSlice
var hCPU []fr.Element
var h icicle_core.DeviceSlice
chHDone := make(chan struct{}, 1)
go func() {
hCPU = computeH(solution.A, solution.B, solution.C, &pk.Domain)
h = computeH(solution.A, solution.B, solution.C, pk)

solution.A = nil
solution.B = nil
Expand Down Expand Up @@ -329,7 +326,6 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

// computes r[δ], s[δ], kr[δ]
deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr})
n := runtime.NumCPU()

var bs1, ar curve.G1Jac

Expand Down Expand Up @@ -368,31 +364,15 @@ func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...b

computeKRS := func() error {
var krs, krs2, p1 curve.G1Jac
var krs2CPU curve.G1Jac
sizeH := int(pk.Domain.Cardinality - 1)

// CPU START

if _, err := krs2CPU.MultiExp(pk.G1.Z, hCPU[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}); err != nil {
panic("krs2CPU didn't complete")
}

// CPU END

cfg := icicle_msm.GetDefaultMSMConfig()
cfg.ArePointsMontgomeryForm = true
cfg.AreScalarsMontgomeryForm = true
resKrs2 := make(icicle_core.HostSlice[icicle_bn254.Projective], 1)
// icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2)
icicle_msm.Msm(icicle_core.HostSliceFromElements(hCPU[:sizeH]), pk.G1Device.Z, &cfg, resKrs2)

icicle_msm.Msm(h.RangeTo(sizeH, false), pk.G1Device.Z, &cfg, resKrs2)
krs2 = g1ProjectiveToG1Jac(resKrs2[0])

if krs2.Equal(&krs2CPU) {
fmt.Println("krs2 succeeded")
} else {
fmt.Println("krs2 failed correctness")
}
// filter the wire values if needed
// TODO Perf @Tabaie worst memory allocation offender
toRemove := commitmentInfo.GetPrivateCommitted()
Expand Down Expand Up @@ -498,68 +478,7 @@ func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr
return
}

// func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice {
// // H part of Krs
// // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
// // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
// // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c)
// // 3 - h = ifft_coset(ca o cb - cc)

// n := len(a)

// // add padding to ensure input length is domain cardinality
// padding := make([]fr.Element, int(pk.Domain.Cardinality)-n)
// a = append(a, padding...)
// b = append(b, padding...)
// c = append(c, padding...)
// n = len(a)

// computeADone := make(chan icicle_core.DeviceSlice, 1)
// computeBDone := make(chan icicle_core.DeviceSlice, 1)
// computeCDone := make(chan icicle_core.DeviceSlice, 1)

// computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) {
// cfg := icicle_ntt.GetDefaultNttConfig()
// scalarsStream, _ := icicle_cr.CreateStream()
// cfg.Ctx.Stream = &scalarsStream
// cfg.Ordering = icicle_core.KNR
// cfg.IsAsync = true
// scalarsHost := icicle_core.HostSliceFromElements(scalars)
// var scalarsDevice icicle_core.DeviceSlice
// scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true)
// icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice)
// cfg.Ordering = icicle_core.KRN
// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGen[:]))
// icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice)
// icicle_cr.SynchronizeStream(&scalarsStream)
// channel <-scalarsDevice
// }

// go computeInttNttOnDevice(a, computeADone)
// go computeInttNttOnDevice(b, computeBDone)
// go computeInttNttOnDevice(c, computeCDone)

// aDevice := <-computeADone
// bDevice := <-computeBDone
// cDevice := <-computeCDone

// vecCfg := icicle_core.DefaultVecOpsConfig()
// icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul)
// icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub)
// icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul)

// cfg := icicle_ntt.GetDefaultNttConfig()
// cfg.CosetGen = [8]uint32(icicle_core.ConvertUint64ArrToUint32Arr(pk.Domain.FrMultiplicativeGenInv[:]))
// cfg.Ordering = icicle_core.KNR
// icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice)

// resHost := make(icicle_core.HostSlice[fr.Element], n)
// resHost.CopyFromDevice(&aDevice)

// return aDevice
// }

func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
func computeH(a, b, c []fr.Element, pk *ProvingKey) icicle_core.DeviceSlice {
// H part of Krs
// Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1))
// 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c)
Expand All @@ -569,69 +488,57 @@ func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element {
n := len(a)

// add padding to ensure input length is domain cardinality
padding := make([]fr.Element, int(domain.Cardinality)-n)
padding := make([]fr.Element, int(pk.Domain.Cardinality)-n)
a = append(a, padding...)
b = append(b, padding...)
c = append(c, padding...)
n = len(a)

aCopy := make([]fr.Element, n)
copy(aCopy, a)

cfg := icicle_ntt.GetDefaultNttConfig()
cfg.Ordering = icicle_core.KNR
scalarsHost := icicle_core.HostSliceFromElements(aCopy)
scalarsHostOut := make(icicle_core.HostSlice[fr.Element], len(aCopy))
icicle_ntt.Ntt(scalarsHost, icicle_core.KInverse, &cfg, scalarsHostOut)

domain.FFTInverse(a, fft.DIF)

for i, elem := range a {
if !elem.Equal(&scalarsHostOut[i]) {
fmt.Println("computeH: A failed")
}
computeADone := make(chan icicle_core.DeviceSlice, 1)
computeBDone := make(chan icicle_core.DeviceSlice, 1)
computeCDone := make(chan icicle_core.DeviceSlice, 1)

cosetGenBits := pk.Domain.FrMultiplicativeGen.Bits()
cosetGen := icicle_core.ConvertUint64ArrToUint32Arr(cosetGenBits[:])
var configCosetGen [8]uint32
copy(configCosetGen[:], cosetGen[:8])

computeInttNttOnDevice := func(scalars []fr.Element, channel chan icicle_core.DeviceSlice) {
cfg := icicle_ntt.GetDefaultNttConfig()
scalarsStream, _ := icicle_cr.CreateStream()
cfg.Ctx.Stream = &scalarsStream
cfg.Ordering = icicle_core.KNR
cfg.IsAsync = true
scalarsHost := icicle_core.HostSliceFromElements(scalars)
var scalarsDevice icicle_core.DeviceSlice
scalarsHost.CopyToDeviceAsync(&scalarsDevice, scalarsStream, true)
icicle_ntt.Ntt(scalarsDevice, icicle_core.KInverse, &cfg, scalarsDevice)
cfg.Ordering = icicle_core.KRN
cfg.CosetGen = configCosetGen
icicle_ntt.Ntt(scalarsDevice, icicle_core.KForward, &cfg, scalarsDevice)
icicle_cr.SynchronizeStream(&scalarsStream)
channel <-scalarsDevice
}

domain.FFTInverse(b, fft.DIF)
domain.FFTInverse(c, fft.DIF)

go computeInttNttOnDevice(a, computeADone)
go computeInttNttOnDevice(b, computeBDone)
go computeInttNttOnDevice(c, computeCDone)

gen, _ := fft.Generator(2 * domain.Cardinality)
// genBits := gen.Bits()
// limbs := icicle_core.ConvertUint64ArrToUint32Arr(genBits[:])
// var rouIcicle icicle_bn254.ScalarField
// rouIcicle.FromLimbs(limbs)
cfgCustom := icicle_ntt.GetDefaultNttConfig()
cfg.CosetGen = ([8]uint32)(icicle_core.ConvertUint64ArrToUint32Arr(gen[:]))
cfgCustom.Ordering = icicle_core.KRN
icicle_ntt.Ntt(scalarsHostOut, icicle_core.KForward, &cfgCustom, scalarsHost)
aDevice := <-computeADone
bDevice := <-computeBDone
cDevice := <-computeCDone

domain.FFT(a, fft.DIT, fft.OnCoset())

if !scalarsHost[0].Equal(&a[0]) {
fmt.Println("computeH: A Forward failed")
}
vecCfg := icicle_core.DefaultVecOpsConfig()
icicle_bn254.FromMontgomery(&aDevice)
icicle_vecops.VecOp(aDevice, bDevice, aDevice, vecCfg, icicle_core.Mul)
icicle_vecops.VecOp(aDevice, cDevice, aDevice, vecCfg, icicle_core.Sub)
icicle_bn254.FromMontgomery(&aDevice)
icicle_vecops.VecOp(aDevice, pk.DenDevice, aDevice, vecCfg, icicle_core.Mul)

domain.FFT(b, fft.DIT, fft.OnCoset())
domain.FFT(c, fft.DIT, fft.OnCoset())

var den, one fr.Element
one.SetOne()
den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality)))
den.Sub(&den, &one).Inverse(&den)

// h = ifft_coset(ca o cb - cc)
// reusing a to avoid unnecessary memory allocation
utils.Parallelize(n, func(start, end int) {
for i := start; i < end; i++ {
a[i].Mul(&a[i], &b[i]).
Sub(&a[i], &c[i]).
Mul(&a[i], &den)
}
})

// ifft_coset
domain.FFTInverse(a, fft.DIF, fft.OnCoset())
cfg := icicle_ntt.GetDefaultNttConfig()
cfg.CosetGen = configCosetGen
cfg.Ordering = icicle_core.KNR
icicle_ntt.Ntt(aDevice, icicle_core.KInverse, &cfg, aDevice)

return a
return aDevice
}

0 comments on commit ec3c980

Please sign in to comment.