mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
FEAT: MultiGPU for golang bindings (#417)
## Describe the changes This PR adds multi gpu support in the golang bindings. Tha main changes are to DeviceSlice which now includes a `deviceId` attribute specifying which device the underlying data resides on and checks for correct deviceId and current device when using DeviceSlices in any operation. In Go, most concurrency can be done via Goroutines (described as lightweight threads - in reality, more of a threadpool manager), however, there is no guarantee that a goroutine stays on a specific host thread. Therefore, a function `RunOnDevice` was added to the cuda_runtime package which locks a goroutine into a specific host thread, sets a current GPU device, runs a provided function, and unlocks the goroutine from the host thread after the provided function finishes. While the goroutine is locked to the hsot thread, the Go runtime will not assign other goroutines to that host thread
This commit is contained in:
@@ -43,8 +43,17 @@ func (d DeviceSlice) IsOnDevice() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: change signature to be Malloc(element, numElements)
|
||||
// calc size internally
|
||||
func (d DeviceSlice) GetDeviceId() int {
|
||||
return cr.GetDeviceFromPointer(d.inner)
|
||||
}
|
||||
|
||||
// CheckDevice is used to ensure that the DeviceSlice about to be used resides on the currently set device
|
||||
func (d DeviceSlice) CheckDevice() {
|
||||
if currentDeviceId, err := cr.GetDevice(); err != cr.CudaSuccess || d.GetDeviceId() != currentDeviceId {
|
||||
panic("Attempt to use DeviceSlice on a different device")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) {
|
||||
dp, err := cr.Malloc(uint(size))
|
||||
d.inner = dp
|
||||
@@ -62,6 +71,7 @@ func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cr.CudaStream)
|
||||
}
|
||||
|
||||
func (d *DeviceSlice) Free() cr.CudaError {
|
||||
d.CheckDevice()
|
||||
err := cr.Free(d.inner)
|
||||
if err == cr.CudaSuccess {
|
||||
d.length, d.capacity = 0, 0
|
||||
@@ -70,6 +80,16 @@ func (d *DeviceSlice) Free() cr.CudaError {
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError {
|
||||
d.CheckDevice()
|
||||
err := cr.FreeAsync(d.inner, stream)
|
||||
if err == cr.CudaSuccess {
|
||||
d.length, d.capacity = 0, 0
|
||||
d.inner = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type HostSliceInterface interface {
|
||||
Size() int
|
||||
}
|
||||
@@ -117,6 +137,7 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic
|
||||
if shouldAllocate {
|
||||
dst.Malloc(size, h.SizeOfElement())
|
||||
}
|
||||
dst.CheckDevice()
|
||||
if size > dst.Cap() {
|
||||
panic("Number of bytes to copy is too large for destination")
|
||||
}
|
||||
@@ -133,6 +154,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
|
||||
if shouldAllocate {
|
||||
dst.MallocAsync(size, h.SizeOfElement(), stream)
|
||||
}
|
||||
dst.CheckDevice()
|
||||
if size > dst.Cap() {
|
||||
panic("Number of bytes to copy is too large for destination")
|
||||
}
|
||||
@@ -144,6 +166,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
|
||||
src.CheckDevice()
|
||||
if h.Len() != src.Len() {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
@@ -152,6 +175,7 @@ func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cr.Stream) {
|
||||
src.CheckDevice()
|
||||
if h.Len() != src.Len() {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ package cuda_runtime
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -17,20 +19,28 @@ type DeviceContext struct {
|
||||
Stream *Stream // Assuming the type is provided by a CUDA binding crate
|
||||
|
||||
/// Index of the currently used GPU. Default value: 0.
|
||||
DeviceId uint
|
||||
deviceId uint
|
||||
|
||||
/// Mempool to use. Default value: 0.
|
||||
// TODO: use cuda_bindings.CudaMemPool as type
|
||||
Mempool uint // Assuming the type is provided by a CUDA binding crate
|
||||
Mempool MemPool // Assuming the type is provided by a CUDA binding crate
|
||||
}
|
||||
|
||||
func (d DeviceContext) GetDeviceId() int {
|
||||
return int(d.deviceId)
|
||||
}
|
||||
|
||||
func GetDefaultDeviceContext() (DeviceContext, CudaError) {
|
||||
device, err := GetDevice()
|
||||
if err != CudaSuccess {
|
||||
panic(fmt.Sprintf("Could not get current device due to %v", err))
|
||||
}
|
||||
var defaultStream Stream
|
||||
var defaultMempool MemPool
|
||||
|
||||
return DeviceContext{
|
||||
&defaultStream,
|
||||
0,
|
||||
0,
|
||||
uint(device),
|
||||
defaultMempool,
|
||||
}, CudaSuccess
|
||||
}
|
||||
|
||||
@@ -47,3 +57,78 @@ func GetDeviceCount() (int, CudaError) {
|
||||
err := C.cudaGetDeviceCount(cCount)
|
||||
return count, (CudaError)(err)
|
||||
}
|
||||
|
||||
func GetDevice() (int, CudaError) {
|
||||
var device int
|
||||
cDevice := (*C.int)(unsafe.Pointer(&device))
|
||||
err := C.cudaGetDevice(cDevice)
|
||||
return device, (CudaError)(err)
|
||||
}
|
||||
|
||||
func GetDeviceFromPointer(ptr unsafe.Pointer) int {
|
||||
var cCudaPointerAttributes CudaPointerAttributes
|
||||
err := C.cudaPointerGetAttributes(&cCudaPointerAttributes, ptr)
|
||||
if (CudaError)(err) != CudaSuccess {
|
||||
panic("Could not get attributes of pointer")
|
||||
}
|
||||
return int(cCudaPointerAttributes.device)
|
||||
}
|
||||
|
||||
// RunOnDevice forces the provided function to run all GPU related calls within it
|
||||
// on the same host thread and therefore the same GPU device.
|
||||
//
|
||||
// NOTE: Goroutines launched within funcToRun are not bound to the
|
||||
// same host thread as funcToRun and therefore not to the same GPU device.
|
||||
// If that is a requirement, RunOnDevice should be called for each with the
|
||||
// same deviceId as the original call.
|
||||
//
|
||||
// As an example:
|
||||
//
|
||||
// cr.RunOnDevice(i, func(args ...any) {
|
||||
// defer wg.Done()
|
||||
// cfg := GetDefaultMSMConfig()
|
||||
// stream, _ := cr.CreateStream()
|
||||
// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
// size := 1 << power
|
||||
//
|
||||
// // This will always print "Inner goroutine device: 0"
|
||||
// // go func () {
|
||||
// // device, _ := cr.GetDevice()
|
||||
// // fmt.Println("Inner goroutine device: ", device)
|
||||
// // }()
|
||||
// // To force the above goroutine to same device as the wrapping function:
|
||||
// // RunOnDevice(i, func(arg ...any) {
|
||||
// // device, _ := cr.GetDevice()
|
||||
// // fmt.Println("Inner goroutine device: ", device)
|
||||
// // })
|
||||
//
|
||||
// scalars := GenerateScalars(size)
|
||||
// points := GenerateAffinePoints(size)
|
||||
//
|
||||
// var p Projective
|
||||
// var out core.DeviceSlice
|
||||
// _, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
// cfg.Ctx.Stream = &stream
|
||||
// cfg.IsAsync = true
|
||||
//
|
||||
// e = Msm(scalars, points, &cfg, out)
|
||||
// assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
//
|
||||
// outHost := make(core.HostSlice[Projective], 1)
|
||||
//
|
||||
// cr.SynchronizeStream(&stream)
|
||||
// outHost.CopyFromDevice(&out)
|
||||
// out.Free()
|
||||
// // Check with gnark-crypto
|
||||
// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
// }
|
||||
// }, i)
|
||||
func RunOnDevice(deviceId int, funcToRun func(args ...any), args ...any) {
|
||||
go func(id int) {
|
||||
defer runtime.UnlockOSThread()
|
||||
runtime.LockOSThread()
|
||||
SetDevice(id)
|
||||
funcToRun(args...)
|
||||
}(deviceId)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type MemPool = CudaMemPool
|
||||
|
||||
func Malloc(size uint) (unsafe.Pointer, CudaError) {
|
||||
if size == 0 {
|
||||
return nil, CudaErrorMemoryAllocation
|
||||
|
||||
@@ -17,3 +17,6 @@ type CudaEvent C.cudaEvent_t
|
||||
|
||||
// CudaMemPool as declared in include/driver_types.h:2928
|
||||
type CudaMemPool C.cudaMemPool_t
|
||||
|
||||
// CudaMemPool as declared in include/driver_types.h:2928
|
||||
type CudaPointerAttributes = C.struct_cudaPointerAttributes
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bls12381
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bls12381
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -55,6 +58,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -67,12 +71,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -138,3 +144,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -150,10 +150,12 @@ func convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points *core.DeviceSlice
|
||||
}
|
||||
|
||||
func {{if .IsG2}}G2{{end}}AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func {{if .IsG2}}G2{{end}}AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -169,10 +171,12 @@ func convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points *core.DeviceS
|
||||
}
|
||||
|
||||
func {{if .IsG2}}G2{{end}}ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func {{if .IsG2}}G2{{end}}ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
{{end}}
|
||||
@@ -22,7 +22,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -30,7 +32,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0])
|
||||
}
|
||||
@@ -38,7 +42,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0])
|
||||
}
|
||||
|
||||
@@ -5,9 +5,12 @@
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -102,6 +105,7 @@ func testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars core.HostSlice[Scala
|
||||
|
||||
func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -114,12 +118,14 @@ func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -185,3 +191,41 @@ func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p {{if .IsG2}}G2{{end}}Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
|
||||
@@ -33,9 +33,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}{{- end}}
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user