diff --git a/wrappers/golang/core/slice.go b/wrappers/golang/core/slice.go index 6f884867..8ec1c1f9 100644 --- a/wrappers/golang/core/slice.go +++ b/wrappers/golang/core/slice.go @@ -43,8 +43,17 @@ func (d DeviceSlice) IsOnDevice() bool { return true } -// TODO: change signature to be Malloc(element, numElements) -// calc size internally +func (d DeviceSlice) GetDeviceId() int { + return cr.GetDeviceFromPointer(d.inner) +} + +// CheckDevice is used to ensure that the DeviceSlice about to be used resides on the currently set device +func (d DeviceSlice) CheckDevice() { + if currentDeviceId, err := cr.GetDevice(); err != cr.CudaSuccess || d.GetDeviceId() != currentDeviceId { + panic("Attempt to use DeviceSlice on a different device") + } +} + func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) { dp, err := cr.Malloc(uint(size)) d.inner = dp @@ -62,6 +71,7 @@ func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cr.CudaStream) } func (d *DeviceSlice) Free() cr.CudaError { + d.CheckDevice() err := cr.Free(d.inner) if err == cr.CudaSuccess { d.length, d.capacity = 0, 0 @@ -70,6 +80,16 @@ func (d *DeviceSlice) Free() cr.CudaError { return err } +func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError { + d.CheckDevice() + err := cr.FreeAsync(d.inner, stream) + if err == cr.CudaSuccess { + d.length, d.capacity = 0, 0 + d.inner = nil + } + return err +} + type HostSliceInterface interface { Size() int } @@ -117,6 +137,7 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic if shouldAllocate { dst.Malloc(size, h.SizeOfElement()) } + dst.CheckDevice() if size > dst.Cap() { panic("Number of bytes to copy is too large for destination") } @@ -133,6 +154,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream, if shouldAllocate { dst.MallocAsync(size, h.SizeOfElement(), stream) } + dst.CheckDevice() if size > dst.Cap() { panic("Number of bytes to copy is too large for destination") } @@ -144,6 +166,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream, } func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) { + src.CheckDevice() if h.Len() != src.Len() { panic("destination and source slices have different lengths") } @@ -152,6 +175,7 @@ func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) { } func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cr.Stream) { + src.CheckDevice() if h.Len() != src.Len() { panic("destination and source slices have different lengths") } diff --git a/wrappers/golang/cuda_runtime/device_context.go b/wrappers/golang/cuda_runtime/device_context.go index 56c72641..5df5bead 100644 --- a/wrappers/golang/cuda_runtime/device_context.go +++ b/wrappers/golang/cuda_runtime/device_context.go @@ -9,6 +9,8 @@ package cuda_runtime */ import "C" import ( + "fmt" + "runtime" "unsafe" ) @@ -17,20 +19,28 @@ type DeviceContext struct { Stream *Stream // Assuming the type is provided by a CUDA binding crate /// Index of the currently used GPU. Default value: 0. - DeviceId uint + deviceId uint /// Mempool to use. Default value: 0. - // TODO: use cuda_bindings.CudaMemPool as type - Mempool uint // Assuming the type is provided by a CUDA binding crate + Mempool MemPool // Assuming the type is provided by a CUDA binding crate +} + +func (d DeviceContext) GetDeviceId() int { + return int(d.deviceId) } func GetDefaultDeviceContext() (DeviceContext, CudaError) { + device, err := GetDevice() + if err != CudaSuccess { + panic(fmt.Sprintf("Could not get current device due to %v", err)) + } var defaultStream Stream + var defaultMempool MemPool return DeviceContext{ &defaultStream, - 0, - 0, + uint(device), + defaultMempool, }, CudaSuccess } @@ -47,3 +57,78 @@ func GetDeviceCount() (int, CudaError) { err := C.cudaGetDeviceCount(cCount) return count, (CudaError)(err) } + +func GetDevice() (int, CudaError) { + var device int + cDevice := (*C.int)(unsafe.Pointer(&device)) + err := C.cudaGetDevice(cDevice) + return device, (CudaError)(err) +} + +func GetDeviceFromPointer(ptr unsafe.Pointer) int { + var cCudaPointerAttributes CudaPointerAttributes + err := C.cudaPointerGetAttributes(&cCudaPointerAttributes, ptr) + if (CudaError)(err) != CudaSuccess { + panic("Could not get attributes of pointer") + } + return int(cCudaPointerAttributes.device) +} + +// RunOnDevice forces the provided function to run all GPU related calls within it +// on the same host thread and therefore the same GPU device. +// +// NOTE: Goroutines launched within funcToRun are not bound to the +// same host thread as funcToRun and therefore not to the same GPU device. +// If that is a requirement, RunOnDevice should be called for each with the +// same deviceId as the original call. +// +// As an example: +// +// cr.RunOnDevice(i, func(args ...any) { +// defer wg.Done() +// cfg := GetDefaultMSMConfig() +// stream, _ := cr.CreateStream() +// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { +// size := 1 << power +// +// // This will always print "Inner goroutine device: 0" +// // go func () { +// // device, _ := cr.GetDevice() +// // fmt.Println("Inner goroutine device: ", device) +// // }() +// // To force the above goroutine to same device as the wrapping function: +// // RunOnDevice(i, func(arg ...any) { +// // device, _ := cr.GetDevice() +// // fmt.Println("Inner goroutine device: ", device) +// // }) +// +// scalars := GenerateScalars(size) +// points := GenerateAffinePoints(size) +// +// var p Projective +// var out core.DeviceSlice +// _, e := out.MallocAsync(p.Size(), p.Size(), stream) +// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") +// cfg.Ctx.Stream = &stream +// cfg.IsAsync = true +// +// e = Msm(scalars, points, &cfg, out) +// assert.Equal(t, e, cr.CudaSuccess, "Msm failed") +// +// outHost := make(core.HostSlice[Projective], 1) +// +// cr.SynchronizeStream(&stream) +// outHost.CopyFromDevice(&out) +// out.Free() +// // Check with gnark-crypto +// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) +// } +// }, i) +func RunOnDevice(deviceId int, funcToRun func(args ...any), args ...any) { + go func(id int) { + defer runtime.UnlockOSThread() + runtime.LockOSThread() + SetDevice(id) + funcToRun(args...) + }(deviceId) +} diff --git a/wrappers/golang/cuda_runtime/memory.go b/wrappers/golang/cuda_runtime/memory.go index 7460833d..47f02bd7 100644 --- a/wrappers/golang/cuda_runtime/memory.go +++ b/wrappers/golang/cuda_runtime/memory.go @@ -12,6 +12,8 @@ import ( "unsafe" ) +type MemPool = CudaMemPool + func Malloc(size uint) (unsafe.Pointer, CudaError) { if size == 0 { return nil, CudaErrorMemoryAllocation diff --git a/wrappers/golang/cuda_runtime/types.go b/wrappers/golang/cuda_runtime/types.go index 1b7e2612..cdeb3365 100644 --- a/wrappers/golang/cuda_runtime/types.go +++ b/wrappers/golang/cuda_runtime/types.go @@ -17,3 +17,6 @@ type CudaEvent C.cudaEvent_t // CudaMemPool as declared in include/driver_types.h:2928 type CudaMemPool C.cudaMemPool_t + +// CudaMemPool as declared in include/driver_types.h:2928 +type CudaPointerAttributes = C.struct_cudaPointerAttributes diff --git a/wrappers/golang/curves/bls12377/curve.go b/wrappers/golang/curves/bls12377/curve.go index f000eda2..15624fbb 100644 --- a/wrappers/golang/curves/bls12377/curve.go +++ b/wrappers/golang/curves/bls12377/curve.go @@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud } func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, true) } func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, false) } @@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr } func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, true) } func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bls12377/g2_curve.go b/wrappers/golang/curves/bls12377/g2_curve.go index 394724de..69eebb33 100644 --- a/wrappers/golang/curves/bls12377/g2_curve.go +++ b/wrappers/golang/curves/bls12377/g2_curve.go @@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C } func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, true) } func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, false) } @@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) } func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, true) } func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bls12377/g2_msm.go b/wrappers/golang/curves/bls12377/g2_msm.go index 28fb546b..fdb10690 100644 --- a/wrappers/golang/curves/bls12377/g2_msm.go +++ b/wrappers/golang/curves/bls12377/g2_msm.go @@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) } @@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0]) } diff --git a/wrappers/golang/curves/bls12377/g2_msm_test.go b/wrappers/golang/curves/bls12377/g2_msm_test.go index eb8432d6..e6193d81 100644 --- a/wrappers/golang/curves/bls12377/g2_msm_test.go +++ b/wrappers/golang/curves/bls12377/g2_msm_test.go @@ -3,9 +3,12 @@ package bls12377 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor func TestMSMG2(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = G2Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[G2Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } @@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } } + +func TestMSMG2MultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := G2GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p G2Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = G2Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bls12377/msm.go b/wrappers/golang/curves/bls12377/msm.go index bad05683..5036396e 100644 --- a/wrappers/golang/curves/bls12377/msm.go +++ b/wrappers/golang/curves/bls12377/msm.go @@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) } @@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) } diff --git a/wrappers/golang/curves/bls12377/msm_test.go b/wrappers/golang/curves/bls12377/msm_test.go index 5c41051f..1ee36f8d 100644 --- a/wrappers/golang/curves/bls12377/msm_test.go +++ b/wrappers/golang/curves/bls12377/msm_test.go @@ -1,9 +1,12 @@ package bls12377 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core. func TestMSM(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -65,12 +69,14 @@ func TestMSM(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } @@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } } + +func TestMSMMultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bls12377/ntt.go b/wrappers/golang/curves/bls12377/ntt.go index 8c15128a..7d8f58d1 100644 --- a/wrappers/golang/curves/bls12377/ntt.go +++ b/wrappers/golang/curves/bls12377/ntt.go @@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0]) } diff --git a/wrappers/golang/curves/bls12377/scalar_field.go b/wrappers/golang/curves/bls12377/scalar_field.go index d0676397..6725428b 100644 --- a/wrappers/golang/curves/bls12377/scalar_field.go +++ b/wrappers/golang/curves/bls12377/scalar_field.go @@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr } func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, true) } func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, false) } diff --git a/wrappers/golang/curves/bls12377/vec_ops.go b/wrappers/golang/curves/bls12377/vec_ops.go index c6abcc14..35c1367f 100644 --- a/wrappers/golang/curves/bls12377/vec_ops.go +++ b/wrappers/golang/curves/bls12377/vec_ops.go @@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V var cA, cB, cOut *C.scalar_t if a.IsOnDevice() { - cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer()) + aDevice := a.(core.DeviceSlice) + aDevice.CheckDevice() + cA = (*C.scalar_t)(aDevice.AsPointer()) } else { cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0])) } if b.IsOnDevice() { - cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer()) + bDevice := b.(core.DeviceSlice) + bDevice.CheckDevice() + cB = (*C.scalar_t)(bDevice.AsPointer()) } else { cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0])) } if out.IsOnDevice() { - cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer()) + outDevice := out.(core.DeviceSlice) + outDevice.CheckDevice() + cOut = (*C.scalar_t)(outDevice.AsPointer()) } else { cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0])) } diff --git a/wrappers/golang/curves/bls12381/curve.go b/wrappers/golang/curves/bls12381/curve.go index 2eed613a..97bd5b11 100644 --- a/wrappers/golang/curves/bls12381/curve.go +++ b/wrappers/golang/curves/bls12381/curve.go @@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud } func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, true) } func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, false) } @@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr } func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, true) } func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bls12381/g2_curve.go b/wrappers/golang/curves/bls12381/g2_curve.go index 8bcde14f..68697b37 100644 --- a/wrappers/golang/curves/bls12381/g2_curve.go +++ b/wrappers/golang/curves/bls12381/g2_curve.go @@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C } func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, true) } func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, false) } @@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) } func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, true) } func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bls12381/g2_msm.go b/wrappers/golang/curves/bls12381/g2_msm.go index 4ecae93c..433bc14e 100644 --- a/wrappers/golang/curves/bls12381/g2_msm.go +++ b/wrappers/golang/curves/bls12381/g2_msm.go @@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) } @@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0]) } diff --git a/wrappers/golang/curves/bls12381/g2_msm_test.go b/wrappers/golang/curves/bls12381/g2_msm_test.go index 9ce77512..f27acbaf 100644 --- a/wrappers/golang/curves/bls12381/g2_msm_test.go +++ b/wrappers/golang/curves/bls12381/g2_msm_test.go @@ -3,9 +3,12 @@ package bls12381 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor func TestMSMG2(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = G2Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[G2Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } @@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } } + +func TestMSMG2MultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := G2GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p G2Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = G2Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bls12381/msm.go b/wrappers/golang/curves/bls12381/msm.go index 10d9f286..47923008 100644 --- a/wrappers/golang/curves/bls12381/msm.go +++ b/wrappers/golang/curves/bls12381/msm.go @@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) } @@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) } diff --git a/wrappers/golang/curves/bls12381/msm_test.go b/wrappers/golang/curves/bls12381/msm_test.go index 7f2d4935..0868d281 100644 --- a/wrappers/golang/curves/bls12381/msm_test.go +++ b/wrappers/golang/curves/bls12381/msm_test.go @@ -1,9 +1,12 @@ package bls12381 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core. func TestMSM(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -65,12 +69,14 @@ func TestMSM(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } @@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } } + +func TestMSMMultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bls12381/ntt.go b/wrappers/golang/curves/bls12381/ntt.go index e2ef983b..320603eb 100644 --- a/wrappers/golang/curves/bls12381/ntt.go +++ b/wrappers/golang/curves/bls12381/ntt.go @@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0]) } diff --git a/wrappers/golang/curves/bls12381/scalar_field.go b/wrappers/golang/curves/bls12381/scalar_field.go index f5362531..fc395a8c 100644 --- a/wrappers/golang/curves/bls12381/scalar_field.go +++ b/wrappers/golang/curves/bls12381/scalar_field.go @@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr } func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, true) } func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, false) } diff --git a/wrappers/golang/curves/bls12381/vec_ops.go b/wrappers/golang/curves/bls12381/vec_ops.go index 9786bbf1..e6ed4438 100644 --- a/wrappers/golang/curves/bls12381/vec_ops.go +++ b/wrappers/golang/curves/bls12381/vec_ops.go @@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V var cA, cB, cOut *C.scalar_t if a.IsOnDevice() { - cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer()) + aDevice := a.(core.DeviceSlice) + aDevice.CheckDevice() + cA = (*C.scalar_t)(aDevice.AsPointer()) } else { cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0])) } if b.IsOnDevice() { - cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer()) + bDevice := b.(core.DeviceSlice) + bDevice.CheckDevice() + cB = (*C.scalar_t)(bDevice.AsPointer()) } else { cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0])) } if out.IsOnDevice() { - cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer()) + outDevice := out.(core.DeviceSlice) + outDevice.CheckDevice() + cOut = (*C.scalar_t)(outDevice.AsPointer()) } else { cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0])) } diff --git a/wrappers/golang/curves/bn254/curve.go b/wrappers/golang/curves/bn254/curve.go index e1216903..4430c00b 100644 --- a/wrappers/golang/curves/bn254/curve.go +++ b/wrappers/golang/curves/bn254/curve.go @@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud } func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, true) } func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, false) } @@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr } func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, true) } func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bn254/g2_curve.go b/wrappers/golang/curves/bn254/g2_curve.go index 95881d03..63e793d8 100644 --- a/wrappers/golang/curves/bn254/g2_curve.go +++ b/wrappers/golang/curves/bn254/g2_curve.go @@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C } func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, true) } func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, false) } @@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) } func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, true) } func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bn254/g2_msm.go b/wrappers/golang/curves/bn254/g2_msm.go index b8a9a2b7..11e4e449 100644 --- a/wrappers/golang/curves/bn254/g2_msm.go +++ b/wrappers/golang/curves/bn254/g2_msm.go @@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) } @@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0]) } diff --git a/wrappers/golang/curves/bn254/g2_msm_test.go b/wrappers/golang/curves/bn254/g2_msm_test.go index c01b3e74..4dd1feef 100644 --- a/wrappers/golang/curves/bn254/g2_msm_test.go +++ b/wrappers/golang/curves/bn254/g2_msm_test.go @@ -3,9 +3,12 @@ package bn254 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor func TestMSMG2(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = G2Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[G2Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } @@ -165,3 +171,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } } + +func TestMSMG2MultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := G2GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p G2Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = G2Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bn254/msm.go b/wrappers/golang/curves/bn254/msm.go index 5ffea612..cec2da12 100644 --- a/wrappers/golang/curves/bn254/msm.go +++ b/wrappers/golang/curves/bn254/msm.go @@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) } @@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) } diff --git a/wrappers/golang/curves/bn254/msm_test.go b/wrappers/golang/curves/bn254/msm_test.go index 39834dfb..d1853504 100644 --- a/wrappers/golang/curves/bn254/msm_test.go +++ b/wrappers/golang/curves/bn254/msm_test.go @@ -1,9 +1,12 @@ package bn254 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core. func TestMSM(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -65,12 +69,14 @@ func TestMSM(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } @@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } } + +func TestMSMMultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bn254/ntt.go b/wrappers/golang/curves/bn254/ntt.go index 3b9c54a1..360d3f63 100644 --- a/wrappers/golang/curves/bn254/ntt.go +++ b/wrappers/golang/curves/bn254/ntt.go @@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0]) } diff --git a/wrappers/golang/curves/bn254/scalar_field.go b/wrappers/golang/curves/bn254/scalar_field.go index 9edf9ee8..1b21d061 100644 --- a/wrappers/golang/curves/bn254/scalar_field.go +++ b/wrappers/golang/curves/bn254/scalar_field.go @@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr } func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, true) } func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, false) } diff --git a/wrappers/golang/curves/bn254/vec_ops.go b/wrappers/golang/curves/bn254/vec_ops.go index 8373cf4e..78f9b917 100644 --- a/wrappers/golang/curves/bn254/vec_ops.go +++ b/wrappers/golang/curves/bn254/vec_ops.go @@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V var cA, cB, cOut *C.scalar_t if a.IsOnDevice() { - cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer()) + aDevice := a.(core.DeviceSlice) + aDevice.CheckDevice() + cA = (*C.scalar_t)(aDevice.AsPointer()) } else { cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0])) } if b.IsOnDevice() { - cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer()) + bDevice := b.(core.DeviceSlice) + bDevice.CheckDevice() + cB = (*C.scalar_t)(bDevice.AsPointer()) } else { cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0])) } if out.IsOnDevice() { - cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer()) + outDevice := out.(core.DeviceSlice) + outDevice.CheckDevice() + cOut = (*C.scalar_t)(outDevice.AsPointer()) } else { cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0])) } diff --git a/wrappers/golang/curves/bw6761/curve.go b/wrappers/golang/curves/bw6761/curve.go index 0cb5fe20..3c492e66 100644 --- a/wrappers/golang/curves/bw6761/curve.go +++ b/wrappers/golang/curves/bw6761/curve.go @@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud } func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, true) } func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertAffinePointsMontgomery(points, false) } @@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr } func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, true) } func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bw6761/g2_curve.go b/wrappers/golang/curves/bw6761/g2_curve.go index fc872f26..035bbd94 100644 --- a/wrappers/golang/curves/bw6761/g2_curve.go +++ b/wrappers/golang/curves/bw6761/g2_curve.go @@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C } func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, true) } func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2AffinePointsMontgomery(points, false) } @@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) } func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, true) } func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convertG2ProjectivePointsMontgomery(points, false) } diff --git a/wrappers/golang/curves/bw6761/g2_msm.go b/wrappers/golang/curves/bw6761/g2_msm.go index 8d9a320a..e3b31c1e 100644 --- a/wrappers/golang/curves/bw6761/g2_msm.go +++ b/wrappers/golang/curves/bw6761/g2_msm.go @@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0]) } @@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0]) } diff --git a/wrappers/golang/curves/bw6761/g2_msm_test.go b/wrappers/golang/curves/bw6761/g2_msm_test.go index 388e7dbc..ba916dfc 100644 --- a/wrappers/golang/curves/bw6761/g2_msm_test.go +++ b/wrappers/golang/curves/bw6761/g2_msm_test.go @@ -3,9 +3,12 @@ package bw6761 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -55,6 +58,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor func TestMSMG2(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -67,12 +71,14 @@ func TestMSMG2(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = G2Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[G2Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } @@ -138,3 +144,41 @@ func TestMSMG2SkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) } } + +func TestMSMG2MultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := G2GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p G2Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = G2Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[G2Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bw6761/msm.go b/wrappers/golang/curves/bw6761/msm.go index 9bd2a8ce..6b4eb882 100644 --- a/wrappers/golang/curves/bw6761/msm.go +++ b/wrappers/golang/curves/bw6761/msm.go @@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0]) } @@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) } diff --git a/wrappers/golang/curves/bw6761/msm_test.go b/wrappers/golang/curves/bw6761/msm_test.go index 415f88cb..943a8718 100644 --- a/wrappers/golang/curves/bw6761/msm_test.go +++ b/wrappers/golang/curves/bw6761/msm_test.go @@ -1,9 +1,12 @@ package bw6761 import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core. func TestMSM(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -65,12 +69,14 @@ func TestMSM(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } @@ -136,3 +142,41 @@ func TestMSMSkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) } } + +func TestMSMMultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/curves/bw6761/ntt.go b/wrappers/golang/curves/bw6761/ntt.go index 009fff23..64bf1bfa 100644 --- a/wrappers/golang/curves/bw6761/ntt.go +++ b/wrappers/golang/curves/bw6761/ntt.go @@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0]) } diff --git a/wrappers/golang/curves/bw6761/scalar_field.go b/wrappers/golang/curves/bw6761/scalar_field.go index cabefd2c..7d2e65b4 100644 --- a/wrappers/golang/curves/bw6761/scalar_field.go +++ b/wrappers/golang/curves/bw6761/scalar_field.go @@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr } func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, true) } func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, false) } diff --git a/wrappers/golang/curves/bw6761/vec_ops.go b/wrappers/golang/curves/bw6761/vec_ops.go index 123716f8..ba3167aa 100644 --- a/wrappers/golang/curves/bw6761/vec_ops.go +++ b/wrappers/golang/curves/bw6761/vec_ops.go @@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V var cA, cB, cOut *C.scalar_t if a.IsOnDevice() { - cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer()) + aDevice := a.(core.DeviceSlice) + aDevice.CheckDevice() + cA = (*C.scalar_t)(aDevice.AsPointer()) } else { cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0])) } if b.IsOnDevice() { - cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer()) + bDevice := b.(core.DeviceSlice) + bDevice.CheckDevice() + cB = (*C.scalar_t)(bDevice.AsPointer()) } else { cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0])) } if out.IsOnDevice() { - cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer()) + outDevice := out.(core.DeviceSlice) + outDevice.CheckDevice() + cOut = (*C.scalar_t)(outDevice.AsPointer()) } else { cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0])) } diff --git a/wrappers/golang/internal/generator/templates/curve.go.tmpl b/wrappers/golang/internal/generator/templates/curve.go.tmpl index 33f319ac..ced2244c 100644 --- a/wrappers/golang/internal/generator/templates/curve.go.tmpl +++ b/wrappers/golang/internal/generator/templates/curve.go.tmpl @@ -150,10 +150,12 @@ func convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points *core.DeviceSlice } func {{if .IsG2}}G2{{end}}AffineToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, true) } func {{if .IsG2}}G2{{end}}AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, false) } @@ -169,10 +171,12 @@ func convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points *core.DeviceS } func {{if .IsG2}}G2{{end}}ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, true) } func {{if .IsG2}}G2{{end}}ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError { + points.CheckDevice() return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, false) } {{end}} \ No newline at end of file diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index f86f4b02..0eec7d81 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -22,7 +22,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr core.MsmCheck(scalars, points, cfg, results) var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -30,7 +32,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr var pointsPointer unsafe.Pointer if points.IsOnDevice() { - pointsPointer = points.(core.DeviceSlice).AsPointer() + pointsDevice := points.(core.DeviceSlice) + pointsDevice.CheckDevice() + pointsPointer = pointsDevice.AsPointer() } else { pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0]) } @@ -38,7 +42,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0]) } diff --git a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl index 9c3bbd20..7dcb478f 100644 --- a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl @@ -5,9 +5,12 @@ package {{.PackageName}} import ( - "github.com/stretchr/testify/assert" + "fmt" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/ingonyama-zk/icicle/wrappers/golang/core" cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime" @@ -102,6 +105,7 @@ func testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars core.HostSlice[Scala func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) { cfg := GetDefaultMSMConfig() + cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { size := 1 << power @@ -114,12 +118,14 @@ func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) { _, e := out.MallocAsync(p.Size(), p.Size(), stream) assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") cfg.Ctx.Stream = &stream + e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out) assert.Equal(t, e, cr.CudaSuccess, "Msm failed") outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1) - outHost.CopyFromDevice(&out) - out.Free() + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + cr.SynchronizeStream(&stream) // Check with gnark-crypto assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0])) } @@ -185,3 +191,41 @@ func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) { assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0])) } } + +func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) { + numDevices, _ := cr.GetDeviceCount() + fmt.Println("There are ", numDevices, " devices available") + wg := sync.WaitGroup{} + + for i := 0; i < numDevices; i++ { + wg.Add(1) + cr.RunOnDevice(i, func(args ...any) { + defer wg.Done() + cfg := GetDefaultMSMConfig() + cfg.IsAsync = true + for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { + size := 1 << power + scalars := GenerateScalars(size) + points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(size) + + stream, _ := cr.CreateStream() + var p {{if .IsG2}}G2{{end}}Projective + var out core.DeviceSlice + _, e := out.MallocAsync(p.Size(), p.Size(), stream) + assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed") + cfg.Ctx.Stream = &stream + + e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out) + assert.Equal(t, e, cr.CudaSuccess, "Msm failed") + outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1) + outHost.CopyFromDeviceAsync(&out, stream) + out.FreeAsync(stream) + + cr.SynchronizeStream(&stream) + // Check with gnark-crypto + assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0])) + } + }) + } + wg.Wait() +} diff --git a/wrappers/golang/internal/generator/templates/ntt.go.tmpl b/wrappers/golang/internal/generator/templates/ntt.go.tmpl index 12cd11f2..2fb6623a 100644 --- a/wrappers/golang/internal/generator/templates/ntt.go.tmpl +++ b/wrappers/golang/internal/generator/templates/ntt.go.tmpl @@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var scalarsPointer unsafe.Pointer if scalars.IsOnDevice() { - scalarsPointer = scalars.(core.DeviceSlice).AsPointer() + scalarsDevice := scalars.(core.DeviceSlice) + scalarsDevice.CheckDevice() + scalarsPointer = scalarsDevice.AsPointer() } else { scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0]) } @@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo var resultsPointer unsafe.Pointer if results.IsOnDevice() { - resultsPointer = results.(core.DeviceSlice).AsPointer() + resultsDevice := results.(core.DeviceSlice) + resultsDevice.CheckDevice() + resultsPointer = resultsDevice.AsPointer() } else { resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0]) } diff --git a/wrappers/golang/internal/generator/templates/scalar_field.go.tmpl b/wrappers/golang/internal/generator/templates/scalar_field.go.tmpl index 0402787e..723eb3fa 100644 --- a/wrappers/golang/internal/generator/templates/scalar_field.go.tmpl +++ b/wrappers/golang/internal/generator/templates/scalar_field.go.tmpl @@ -33,9 +33,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr } func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, true) } func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError { + scalars.CheckDevice() return convertScalarsMontgomery(scalars, false) }{{- end}} \ No newline at end of file diff --git a/wrappers/golang/internal/generator/templates/vec_ops.go.tmpl b/wrappers/golang/internal/generator/templates/vec_ops.go.tmpl index ed9ab38f..0e887229 100644 --- a/wrappers/golang/internal/generator/templates/vec_ops.go.tmpl +++ b/wrappers/golang/internal/generator/templates/vec_ops.go.tmpl @@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V var cA, cB, cOut *C.scalar_t if a.IsOnDevice() { - cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer()) + aDevice := a.(core.DeviceSlice) + aDevice.CheckDevice() + cA = (*C.scalar_t)(aDevice.AsPointer()) } else { cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0])) } if b.IsOnDevice() { - cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer()) + bDevice := b.(core.DeviceSlice) + bDevice.CheckDevice() + cB = (*C.scalar_t)(bDevice.AsPointer()) } else { cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0])) } if out.IsOnDevice() { - cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer()) + outDevice := out.(core.DeviceSlice) + outDevice.CheckDevice() + cOut = (*C.scalar_t)(outDevice.AsPointer()) } else { cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0])) }