mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 15:37:58 -05:00
Add vector operations for golang bindings (#399)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
type IcicleErrorCode int
|
||||
@@ -16,13 +16,13 @@ const (
|
||||
|
||||
type IcicleError struct {
|
||||
IcicleErrorCode IcicleErrorCode
|
||||
CudaErrorCode cuda_runtime.CudaError
|
||||
CudaErrorCode cr.CudaError
|
||||
reason string
|
||||
}
|
||||
|
||||
func FromCudaError(error cuda_runtime.CudaError) (err IcicleError) {
|
||||
func FromCudaError(error cr.CudaError) (err IcicleError) {
|
||||
switch error {
|
||||
case cuda_runtime.CudaSuccess:
|
||||
case cr.CudaSuccess:
|
||||
err.IcicleErrorCode = IcicleSuccess
|
||||
default:
|
||||
err.IcicleErrorCode = InternalCudaError
|
||||
@@ -38,6 +38,6 @@ func FromCodeAndReason(code IcicleErrorCode, reason string) IcicleError {
|
||||
return IcicleError{
|
||||
IcicleErrorCode: code,
|
||||
reason: reason,
|
||||
CudaErrorCode: cuda_runtime.CudaErrorUnknown,
|
||||
CudaErrorCode: cr.CudaErrorUnknown,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@ package core
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
type MSMConfig struct {
|
||||
/// Details related to the device such as its id and stream.
|
||||
Ctx cuda_runtime.DeviceContext
|
||||
Ctx cr.DeviceContext
|
||||
|
||||
pointsSize int32
|
||||
|
||||
@@ -55,13 +55,8 @@ type MSMConfig struct {
|
||||
IsAsync bool
|
||||
}
|
||||
|
||||
// type MSM interface {
|
||||
// Msm(scalars, points *cuda_runtime.HostOrDeviceSlice, cfg *MSMConfig, results *cuda_runtime.HostOrDeviceSlice) cuda_runtime.CudaError
|
||||
// GetDefaultMSMConfig() MSMConfig
|
||||
// }
|
||||
|
||||
func GetDefaultMSMConfig() MSMConfig {
|
||||
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
return MSMConfig{
|
||||
ctx, // Ctx
|
||||
0, // pointsSize
|
||||
|
||||
@@ -4,13 +4,13 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core/internal"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMSMDefaultConfig(t *testing.T) {
|
||||
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
expected := MSMConfig{
|
||||
ctx, // Ctx
|
||||
0, // pointsSize
|
||||
|
||||
@@ -3,7 +3,7 @@ package core
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
type NTTDir int8
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
|
||||
type NTTConfig[T any] struct {
|
||||
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
|
||||
Ctx cuda_runtime.DeviceContext
|
||||
Ctx cr.DeviceContext
|
||||
/// Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()` (corresponding to no coset being used).
|
||||
CosetGen T
|
||||
/// The number of NTTs to compute. Default value: 1.
|
||||
@@ -41,7 +41,7 @@ type NTTConfig[T any] struct {
|
||||
}
|
||||
|
||||
func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
|
||||
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
return NTTConfig[T]{
|
||||
ctx, // Ctx
|
||||
cosetGen, // CosetGen
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core/internal"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestNTTDefaultConfig(t *testing.T) {
|
||||
cosetGenField.One()
|
||||
var cosetGen [1]uint32
|
||||
copy(cosetGen[:], cosetGenField.GetLimbs())
|
||||
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
expected := NTTConfig[[1]uint32]{
|
||||
ctx, // Ctx
|
||||
cosetGen, // CosetGen
|
||||
|
||||
@@ -3,7 +3,7 @@ package core
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
type HostOrDeviceSlice interface {
|
||||
@@ -45,25 +45,25 @@ func (d DeviceSlice) IsOnDevice() bool {
|
||||
|
||||
// TODO: change signature to be Malloc(element, numElements)
|
||||
// calc size internally
|
||||
func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cuda_runtime.CudaError) {
|
||||
dp, err := cuda_runtime.Malloc(uint(size))
|
||||
func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) {
|
||||
dp, err := cr.Malloc(uint(size))
|
||||
d.inner = dp
|
||||
d.capacity = size
|
||||
d.length = size / sizeOfElement
|
||||
return *d, err
|
||||
}
|
||||
|
||||
func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cuda_runtime.CudaStream) (DeviceSlice, cuda_runtime.CudaError) {
|
||||
dp, err := cuda_runtime.MallocAsync(uint(size), stream)
|
||||
func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cr.CudaStream) (DeviceSlice, cr.CudaError) {
|
||||
dp, err := cr.MallocAsync(uint(size), stream)
|
||||
d.inner = dp
|
||||
d.capacity = size
|
||||
d.length = size / sizeOfElement
|
||||
return *d, err
|
||||
}
|
||||
|
||||
func (d *DeviceSlice) Free() cuda_runtime.CudaError {
|
||||
err := cuda_runtime.Free(d.inner)
|
||||
if err == cuda_runtime.CudaSuccess {
|
||||
func (d *DeviceSlice) Free() cr.CudaError {
|
||||
err := cr.Free(d.inner)
|
||||
if err == cr.CudaSuccess {
|
||||
d.length, d.capacity = 0, 0
|
||||
d.inner = nil
|
||||
}
|
||||
@@ -123,12 +123,12 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic
|
||||
|
||||
// hostSrc := unsafe.Pointer(h.AsPointer())
|
||||
hostSrc := unsafe.Pointer(&h[0])
|
||||
cuda_runtime.CopyToDevice(dst.inner, hostSrc, uint(size))
|
||||
cr.CopyToDevice(dst.inner, hostSrc, uint(size))
|
||||
dst.length = h.Len()
|
||||
return dst
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cuda_runtime.CudaStream, shouldAllocate bool) *DeviceSlice {
|
||||
func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream, shouldAllocate bool) *DeviceSlice {
|
||||
size := h.Len() * h.SizeOfElement()
|
||||
if shouldAllocate {
|
||||
dst.MallocAsync(size, h.SizeOfElement(), stream)
|
||||
@@ -138,7 +138,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cuda_runtime.Cu
|
||||
}
|
||||
|
||||
hostSrc := unsafe.Pointer(&h[0])
|
||||
cuda_runtime.CopyToDeviceAsync(dst.inner, hostSrc, uint(size), stream)
|
||||
cr.CopyToDeviceAsync(dst.inner, hostSrc, uint(size), stream)
|
||||
dst.length = h.Len()
|
||||
return dst
|
||||
}
|
||||
@@ -148,13 +148,13 @@ func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
bytesSize := src.Len() * h.SizeOfElement()
|
||||
cuda_runtime.CopyFromDevice(unsafe.Pointer(&h[0]), src.inner, uint(bytesSize))
|
||||
cr.CopyFromDevice(unsafe.Pointer(&h[0]), src.inner, uint(bytesSize))
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cuda_runtime.Stream) {
|
||||
func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cr.Stream) {
|
||||
if h.Len() != src.Len() {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
bytesSize := src.Len() * h.SizeOfElement()
|
||||
cuda_runtime.CopyFromDeviceAsync(unsafe.Pointer(&h[0]), src.inner, uint(bytesSize), stream)
|
||||
cr.CopyFromDeviceAsync(unsafe.Pointer(&h[0]), src.inner, uint(bytesSize), stream)
|
||||
}
|
||||
|
||||
74
wrappers/golang/core/vec_ops.go
Normal file
74
wrappers/golang/core/vec_ops.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
type VecOps int
|
||||
|
||||
const (
|
||||
Sub VecOps = iota
|
||||
Add
|
||||
Mul
|
||||
)
|
||||
|
||||
type VecOpsConfig struct {
|
||||
/*Details related to the device such as its id and stream. */
|
||||
Ctx cr.DeviceContext
|
||||
/* True if `a` is on device and false if it is not. Default value: false. */
|
||||
isAOnDevice bool
|
||||
/* True if `b` is on device and false if it is not. Default value: false. */
|
||||
isBOnDevice bool
|
||||
/* If true, output is preserved on device, otherwise on host. Default value: false. */
|
||||
isResultOnDevice bool
|
||||
/* True if `result` vector should be in Montgomery form and false otherwise. Default value: false. */
|
||||
IsResultMontgomeryForm bool
|
||||
/* Whether to run the vector operations asynchronously. If set to `true`, the function will be
|
||||
* non-blocking and you'll need to synchronize it explicitly by calling
|
||||
* `SynchronizeStream`. If set to false, the function will block the current CPU thread. */
|
||||
IsAsync bool
|
||||
}
|
||||
|
||||
/**
|
||||
* A function that returns the default value of [VecOpsConfig](@ref VecOpsConfig).
|
||||
* @return Default value of [VecOpsConfig](@ref VecOpsConfig).
|
||||
*/
|
||||
func DefaultVecOpsConfig() VecOpsConfig {
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
config := VecOpsConfig{
|
||||
ctx, // ctx
|
||||
false, // isAOnDevice
|
||||
false, // isBOnDevice
|
||||
false, // isResultOnDevice
|
||||
false, // IsResultMontgomeryForm
|
||||
false, // IsAsync
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func VecOpCheck(a, b, out HostOrDeviceSlice, cfg *VecOpsConfig) {
|
||||
aLen, bLen, outLen := a.Len(), b.Len(), out.Len()
|
||||
if aLen != bLen {
|
||||
errorString := fmt.Sprintf(
|
||||
"a and b vector lengths %d; %d are not equal",
|
||||
aLen,
|
||||
bLen,
|
||||
)
|
||||
panic(errorString)
|
||||
}
|
||||
if aLen != outLen {
|
||||
errorString := fmt.Sprintf(
|
||||
"a and out vector lengths %d; %d are not equal",
|
||||
aLen,
|
||||
outLen,
|
||||
)
|
||||
panic(errorString)
|
||||
}
|
||||
|
||||
cfg.isAOnDevice = a.IsOnDevice()
|
||||
cfg.isBOnDevice = b.IsOnDevice()
|
||||
cfg.isResultOnDevice = out.IsOnDevice()
|
||||
}
|
||||
23
wrappers/golang/core/vec_ops_test.go
Normal file
23
wrappers/golang/core/vec_ops_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVecOpsDefaultConfig(t *testing.T) {
|
||||
ctx, _ := cr.GetDefaultDeviceContext()
|
||||
expected := VecOpsConfig{
|
||||
ctx, // Ctx
|
||||
false, // isAOnDevice
|
||||
false, // isBOnDevice
|
||||
false, // isResultOnDevice
|
||||
false, // IsResultMontgomeryForm
|
||||
false, // IsAsync
|
||||
}
|
||||
|
||||
actual := DefaultVecOpsConfig()
|
||||
|
||||
assert.Equal(t, expected, actual)
|
||||
}
|
||||
39
wrappers/golang/curves/bls12377/include/vec_ops.h
Normal file
39
wrappers/golang/curves/bls12377/include/vec_ops.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
|
||||
#ifndef _BLS12_377_VEC_OPS_H
|
||||
#define _BLS12_377_VEC_OPS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_377MulCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bls12_377AddCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bls12_377SubCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,7 @@ import "C"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
core "github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -85,12 +85,7 @@ func (f ScalarField) ToBytesLittleEndian() []byte {
|
||||
}
|
||||
|
||||
func GenerateScalars(size int) core.HostSlice[ScalarField] {
|
||||
scalars := make([]ScalarField, size)
|
||||
for i := range scalars {
|
||||
scalars[i] = ScalarField{}
|
||||
}
|
||||
|
||||
scalarSlice := core.HostSliceFromElements[ScalarField](scalars)
|
||||
scalarSlice := make(core.HostSlice[ScalarField], size)
|
||||
|
||||
cScalars := (*C.scalar_t)(unsafe.Pointer(&scalarSlice[0]))
|
||||
cSize := (C.int)(size)
|
||||
|
||||
49
wrappers/golang/curves/bls12377/vec_ops.go
Normal file
49
wrappers/golang/curves/bls12377/vec_ops.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package bls12377
|
||||
|
||||
// #cgo CFLAGS: -I./include/
|
||||
// #include "vec_ops.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.VecOps) (ret cr.CudaError) {
|
||||
core.VecOpCheck(a, b, out, &config)
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
cConfig := (*C.VecOpsConfig)(unsafe.Pointer(&config))
|
||||
cSize := (C.int)(a.Len())
|
||||
|
||||
switch op {
|
||||
case core.Sub:
|
||||
ret = (cr.CudaError)(C.bls12_377SubCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Add:
|
||||
ret = (cr.CudaError)(C.bls12_377AddCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Mul:
|
||||
ret = (cr.CudaError)(C.bls12_377MulCuda(cA, cB, cSize, cConfig, cOut))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
33
wrappers/golang/curves/bls12377/vec_ops_test.go
Normal file
33
wrappers/golang/curves/bls12377/vec_ops_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVecOps(t *testing.T) {
|
||||
testSize := 1 << 14
|
||||
|
||||
a := GenerateScalars(testSize)
|
||||
b := GenerateScalars(testSize)
|
||||
var scalar ScalarField
|
||||
scalar.One()
|
||||
ones := core.HostSliceWithValue(scalar, testSize)
|
||||
|
||||
out := make(core.HostSlice[ScalarField], testSize)
|
||||
out2 := make(core.HostSlice[ScalarField], testSize)
|
||||
out3 := make(core.HostSlice[ScalarField], testSize)
|
||||
|
||||
cfg := core.DefaultVecOpsConfig()
|
||||
|
||||
VecOp(a, b, out, cfg, core.Add)
|
||||
VecOp(out, b, out2, cfg, core.Sub)
|
||||
|
||||
assert.Equal(t, a, out2)
|
||||
|
||||
VecOp(a, ones, out3, cfg, core.Mul)
|
||||
|
||||
assert.Equal(t, a, out3)
|
||||
}
|
||||
39
wrappers/golang/curves/bls12381/include/vec_ops.h
Normal file
39
wrappers/golang/curves/bls12381/include/vec_ops.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
|
||||
#ifndef _BLS12_381_VEC_OPS_H
|
||||
#define _BLS12_381_VEC_OPS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_381MulCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bls12_381AddCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bls12_381SubCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,7 @@ import "C"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
core "github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -85,12 +85,7 @@ func (f ScalarField) ToBytesLittleEndian() []byte {
|
||||
}
|
||||
|
||||
func GenerateScalars(size int) core.HostSlice[ScalarField] {
|
||||
scalars := make([]ScalarField, size)
|
||||
for i := range scalars {
|
||||
scalars[i] = ScalarField{}
|
||||
}
|
||||
|
||||
scalarSlice := core.HostSliceFromElements[ScalarField](scalars)
|
||||
scalarSlice := make(core.HostSlice[ScalarField], size)
|
||||
|
||||
cScalars := (*C.scalar_t)(unsafe.Pointer(&scalarSlice[0]))
|
||||
cSize := (C.int)(size)
|
||||
|
||||
49
wrappers/golang/curves/bls12381/vec_ops.go
Normal file
49
wrappers/golang/curves/bls12381/vec_ops.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package bls12381
|
||||
|
||||
// #cgo CFLAGS: -I./include/
|
||||
// #include "vec_ops.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.VecOps) (ret cr.CudaError) {
|
||||
core.VecOpCheck(a, b, out, &config)
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
cConfig := (*C.VecOpsConfig)(unsafe.Pointer(&config))
|
||||
cSize := (C.int)(a.Len())
|
||||
|
||||
switch op {
|
||||
case core.Sub:
|
||||
ret = (cr.CudaError)(C.bls12_381SubCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Add:
|
||||
ret = (cr.CudaError)(C.bls12_381AddCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Mul:
|
||||
ret = (cr.CudaError)(C.bls12_381MulCuda(cA, cB, cSize, cConfig, cOut))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
33
wrappers/golang/curves/bls12381/vec_ops_test.go
Normal file
33
wrappers/golang/curves/bls12381/vec_ops_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package bls12381
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVecOps(t *testing.T) {
|
||||
testSize := 1 << 14
|
||||
|
||||
a := GenerateScalars(testSize)
|
||||
b := GenerateScalars(testSize)
|
||||
var scalar ScalarField
|
||||
scalar.One()
|
||||
ones := core.HostSliceWithValue(scalar, testSize)
|
||||
|
||||
out := make(core.HostSlice[ScalarField], testSize)
|
||||
out2 := make(core.HostSlice[ScalarField], testSize)
|
||||
out3 := make(core.HostSlice[ScalarField], testSize)
|
||||
|
||||
cfg := core.DefaultVecOpsConfig()
|
||||
|
||||
VecOp(a, b, out, cfg, core.Add)
|
||||
VecOp(out, b, out2, cfg, core.Sub)
|
||||
|
||||
assert.Equal(t, a, out2)
|
||||
|
||||
VecOp(a, ones, out3, cfg, core.Mul)
|
||||
|
||||
assert.Equal(t, a, out3)
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
|
||||
// Copyright 2023 Ingonyama
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Code generated by Ingonyama DO NOT EDIT
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdbool.h>
|
||||
// ve_mod_mult.h
|
||||
|
||||
#ifndef _BN254_VEC_MULT_H
|
||||
#define _BN254_VEC_MULT_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
cudaStream_t stream; /**< Stream to use. Default value: 0. */
|
||||
int device_id; /**< Index of the currently used GPU. Default value: 0. */
|
||||
cudaMemPool_t mempool; /**< Mempool to use. Default value: 0. */
|
||||
} DeviceContext;
|
||||
|
||||
typedef struct BN254_scalar_t BN254_scalar_t;
|
||||
|
||||
int bn254AddCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
);
|
||||
|
||||
int bn254SubCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
);
|
||||
|
||||
int bn254MulCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
bool is_montgomery,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* _BN254_VEC_MULT_H */
|
||||
39
wrappers/golang/curves/bn254/include/vec_ops.h
Normal file
39
wrappers/golang/curves/bn254/include/vec_ops.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
|
||||
#ifndef _BN254_VEC_OPS_H
|
||||
#define _BN254_VEC_OPS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bn254MulCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bn254AddCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bn254SubCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,7 @@ import "C"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
core "github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -85,12 +85,7 @@ func (f ScalarField) ToBytesLittleEndian() []byte {
|
||||
}
|
||||
|
||||
func GenerateScalars(size int) core.HostSlice[ScalarField] {
|
||||
scalars := make([]ScalarField, size)
|
||||
for i := range scalars {
|
||||
scalars[i] = ScalarField{}
|
||||
}
|
||||
|
||||
scalarSlice := core.HostSliceFromElements[ScalarField](scalars)
|
||||
scalarSlice := make(core.HostSlice[ScalarField], size)
|
||||
|
||||
cScalars := (*C.scalar_t)(unsafe.Pointer(&scalarSlice[0]))
|
||||
cSize := (C.int)(size)
|
||||
|
||||
49
wrappers/golang/curves/bn254/vec_ops.go
Normal file
49
wrappers/golang/curves/bn254/vec_ops.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package bn254
|
||||
|
||||
// #cgo CFLAGS: -I./include/
|
||||
// #include "vec_ops.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.VecOps) (ret cr.CudaError) {
|
||||
core.VecOpCheck(a, b, out, &config)
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
cConfig := (*C.VecOpsConfig)(unsafe.Pointer(&config))
|
||||
cSize := (C.int)(a.Len())
|
||||
|
||||
switch op {
|
||||
case core.Sub:
|
||||
ret = (cr.CudaError)(C.bn254SubCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Add:
|
||||
ret = (cr.CudaError)(C.bn254AddCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Mul:
|
||||
ret = (cr.CudaError)(C.bn254MulCuda(cA, cB, cSize, cConfig, cOut))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
33
wrappers/golang/curves/bn254/vec_ops_test.go
Normal file
33
wrappers/golang/curves/bn254/vec_ops_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVecOps(t *testing.T) {
|
||||
testSize := 1 << 14
|
||||
|
||||
a := GenerateScalars(testSize)
|
||||
b := GenerateScalars(testSize)
|
||||
var scalar ScalarField
|
||||
scalar.One()
|
||||
ones := core.HostSliceWithValue(scalar, testSize)
|
||||
|
||||
out := make(core.HostSlice[ScalarField], testSize)
|
||||
out2 := make(core.HostSlice[ScalarField], testSize)
|
||||
out3 := make(core.HostSlice[ScalarField], testSize)
|
||||
|
||||
cfg := core.DefaultVecOpsConfig()
|
||||
|
||||
VecOp(a, b, out, cfg, core.Add)
|
||||
VecOp(out, b, out2, cfg, core.Sub)
|
||||
|
||||
assert.Equal(t, a, out2)
|
||||
|
||||
VecOp(a, ones, out3, cfg, core.Mul)
|
||||
|
||||
assert.Equal(t, a, out3)
|
||||
}
|
||||
39
wrappers/golang/curves/bw6761/include/vec_ops.h
Normal file
39
wrappers/golang/curves/bw6761/include/vec_ops.h
Normal file
@@ -0,0 +1,39 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
|
||||
#ifndef _BW6_761_VEC_OPS_H
|
||||
#define _BW6_761_VEC_OPS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bw6_761MulCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bw6_761AddCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
cudaError_t bw6_761SubCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,7 @@ import "C"
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
core "github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"unsafe"
|
||||
)
|
||||
@@ -85,12 +85,7 @@ func (f ScalarField) ToBytesLittleEndian() []byte {
|
||||
}
|
||||
|
||||
func GenerateScalars(size int) core.HostSlice[ScalarField] {
|
||||
scalars := make([]ScalarField, size)
|
||||
for i := range scalars {
|
||||
scalars[i] = ScalarField{}
|
||||
}
|
||||
|
||||
scalarSlice := core.HostSliceFromElements[ScalarField](scalars)
|
||||
scalarSlice := make(core.HostSlice[ScalarField], size)
|
||||
|
||||
cScalars := (*C.scalar_t)(unsafe.Pointer(&scalarSlice[0]))
|
||||
cSize := (C.int)(size)
|
||||
|
||||
49
wrappers/golang/curves/bw6761/vec_ops.go
Normal file
49
wrappers/golang/curves/bw6761/vec_ops.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package bw6761
|
||||
|
||||
// #cgo CFLAGS: -I./include/
|
||||
// #include "vec_ops.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.VecOps) (ret cr.CudaError) {
|
||||
core.VecOpCheck(a, b, out, &config)
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
cConfig := (*C.VecOpsConfig)(unsafe.Pointer(&config))
|
||||
cSize := (C.int)(a.Len())
|
||||
|
||||
switch op {
|
||||
case core.Sub:
|
||||
ret = (cr.CudaError)(C.bw6_761SubCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Add:
|
||||
ret = (cr.CudaError)(C.bw6_761AddCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Mul:
|
||||
ret = (cr.CudaError)(C.bw6_761MulCuda(cA, cB, cSize, cConfig, cOut))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
33
wrappers/golang/curves/bw6761/vec_ops_test.go
Normal file
33
wrappers/golang/curves/bw6761/vec_ops_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVecOps(t *testing.T) {
|
||||
testSize := 1 << 14
|
||||
|
||||
a := GenerateScalars(testSize)
|
||||
b := GenerateScalars(testSize)
|
||||
var scalar ScalarField
|
||||
scalar.One()
|
||||
ones := core.HostSliceWithValue(scalar, testSize)
|
||||
|
||||
out := make(core.HostSlice[ScalarField], testSize)
|
||||
out2 := make(core.HostSlice[ScalarField], testSize)
|
||||
out3 := make(core.HostSlice[ScalarField], testSize)
|
||||
|
||||
cfg := core.DefaultVecOpsConfig()
|
||||
|
||||
VecOp(a, b, out, cfg, core.Add)
|
||||
VecOp(out, b, out2, cfg, core.Sub)
|
||||
|
||||
assert.Equal(t, a, out2)
|
||||
|
||||
VecOp(a, ones, out3, cfg, core.Mul)
|
||||
|
||||
assert.Equal(t, a, out3)
|
||||
}
|
||||
@@ -25,6 +25,7 @@ typedef struct g2_affine_t g2_affine_t;
|
||||
|
||||
typedef struct MSMConfig MSMConfig;
|
||||
typedef struct NTTConfig NTTConfig;
|
||||
typedef struct VecOpsConfig VecOpsConfig;
|
||||
typedef struct DeviceContext DeviceContext;
|
||||
|
||||
typedef cudaError_t cudaError_t;
|
||||
|
||||
@@ -104,7 +104,8 @@ func generateFiles() {
|
||||
"ntt_test.go.tmpl",
|
||||
"curve_test.go.tmpl",
|
||||
"curve.go.tmpl",
|
||||
/* "vec_ops.h.tmpl,"*/
|
||||
"vec_ops_test.go.tmpl",
|
||||
"vec_ops.go.tmpl",
|
||||
"helpers_test.go.tmpl",
|
||||
}
|
||||
|
||||
@@ -171,7 +172,7 @@ func generateFiles() {
|
||||
"msm.h.tmpl",
|
||||
"g2_msm.h.tmpl",
|
||||
"ntt.h.tmpl",
|
||||
/*"vec_ops.h.tmpl",*/
|
||||
"vec_ops.h.tmpl",
|
||||
}
|
||||
|
||||
for _, includeFile := range templateIncludeFiles {
|
||||
|
||||
@@ -1,49 +1,35 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdbool.h>
|
||||
// ve_mod_mult.h
|
||||
#include "../../include/types.h"
|
||||
|
||||
#ifndef _BN254_VEC_MULT_H
|
||||
#define _BN254_VEC_MULT_H
|
||||
#ifndef _{{toUpper .Curve}}_VEC_OPS_H
|
||||
#define _{{toUpper .Curve}}_VEC_OPS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
cudaStream_t stream; /**< Stream to use. Default value: 0. */
|
||||
int device_id; /**< Index of the currently used GPU. Default value: 0. */
|
||||
cudaMemPool_t mempool; /**< Mempool to use. Default value: 0. */
|
||||
} DeviceContext;
|
||||
|
||||
typedef struct BN254_scalar_t BN254_scalar_t;
|
||||
|
||||
int bn254AddCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
cudaError_t {{.Curve}}MulCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
int bn254SubCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
cudaError_t {{.Curve}}AddCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
int bn254MulCuda(
|
||||
BN254_scalar_t* vec_a,
|
||||
BN254_scalar_t* vec_b,
|
||||
cudaError_t {{.Curve}}SubCuda(
|
||||
scalar_t* vec_a,
|
||||
scalar_t* vec_b,
|
||||
int n,
|
||||
bool is_on_device,
|
||||
bool is_montgomery,
|
||||
DeviceContext ctx,
|
||||
BN254_scalar_t* result
|
||||
VecOpsConfig* config,
|
||||
scalar_t* result
|
||||
);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -5,19 +5,14 @@ import "C"
|
||||
{{- end }}
|
||||
|
||||
{{- define "scalar_field_go_imports" }}
|
||||
core "github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
"unsafe"
|
||||
{{- end }}
|
||||
|
||||
{{- define "scalar_field_funcs" }}
|
||||
func GenerateScalars(size int) core.HostSlice[ScalarField] {
|
||||
scalars := make([]ScalarField, size)
|
||||
for i := range scalars {
|
||||
scalars[i] = ScalarField{}
|
||||
}
|
||||
|
||||
scalarSlice := core.HostSliceFromElements[ScalarField](scalars)
|
||||
scalarSlice := make(core.HostSlice[ScalarField], size)
|
||||
|
||||
cScalars := (*C.scalar_t)(unsafe.Pointer(&scalarSlice[0]))
|
||||
cSize := (C.int)(size)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package {{.PackageName}}
|
||||
|
||||
// #cgo CFLAGS: -I./include/
|
||||
// #include "vec_ops.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
)
|
||||
|
||||
func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.VecOps) (ret cr.CudaError) {
|
||||
core.VecOpCheck(a, b, out, &config)
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
cConfig := (*C.VecOpsConfig)(unsafe.Pointer(&config))
|
||||
cSize := (C.int)(a.Len())
|
||||
|
||||
switch op {
|
||||
case core.Sub:
|
||||
ret = (cr.CudaError)(C.{{.Curve}}SubCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Add:
|
||||
ret = (cr.CudaError)(C.{{.Curve}}AddCuda(cA, cB, cSize, cConfig, cOut))
|
||||
case core.Mul:
|
||||
ret = (cr.CudaError)(C.{{.Curve}}MulCuda(cA, cB, cSize, cConfig, cOut))
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVecOps(t *testing.T) {
|
||||
testSize := 1 << 14
|
||||
|
||||
a := GenerateScalars(testSize)
|
||||
b := GenerateScalars(testSize)
|
||||
var scalar ScalarField
|
||||
scalar.One()
|
||||
ones := core.HostSliceWithValue(scalar, testSize)
|
||||
|
||||
out := make(core.HostSlice[ScalarField], testSize)
|
||||
out2 := make(core.HostSlice[ScalarField], testSize)
|
||||
out3 := make(core.HostSlice[ScalarField], testSize)
|
||||
|
||||
cfg := core.DefaultVecOpsConfig()
|
||||
|
||||
VecOp(a, b, out, cfg, core.Add)
|
||||
VecOp(out, b, out2, cfg, core.Sub)
|
||||
|
||||
assert.Equal(t, a, out2)
|
||||
|
||||
VecOp(a, ones, out3, cfg, core.Mul)
|
||||
|
||||
assert.Equal(t, a, out3)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user