Files
icicle/goicicle/templates/msm/msm.go.tmpl
Jeremy Felder fb52650bbc CI: Additional checks (#155)
adds CI checks for building and testing Golang bindings
adds CI checks for formatting Rust and Golang files
Fixes Golang tests for BN254
Splits Actions checks for PR against main into multiple files

Resolves #108
Resolves #107
Resolves #138
2023-08-31 09:04:53 +03:00

192 lines
6.0 KiB
Cheetah

import (
"errors"
"fmt"
"unsafe"
)
// #cgo CFLAGS: -I./include/
// #cgo CFLAGS: -I/usr/local/cuda/include
// #cgo LDFLAGS: -L${SRCDIR}/../../ {{.SharedLib}}
// #include "msm.h"
import "C"
// Msm runs a single multi-scalar multiplication over G1 by calling
// msm_cuda_{{.CurveNameLowerCase}} and stores the result in out.
//
// points and scalars must be non-empty and of equal length; out must be
// non-nil. device_id is forwarded to the C routine as a size_t.
//
// On success the same out pointer is returned. On any validation failure or
// non-zero C return code, nil and a descriptive error are returned.
func Msm(out *G1ProjectivePoint, points []G1PointAffine, scalars []G1ScalarField, device_id int) (*G1ProjectivePoint, error) {
	if len(points) != len(scalars) {
		return nil, errors.New("error on: len(points) != len(scalars)")
	}
	// Taking &points[0] on an empty slice would panic, so reject it up front.
	if len(points) == 0 {
		return nil, errors.New("points or scalars is empty")
	}
	// A nil out would be handed to C as a NULL destination pointer.
	if out == nil {
		return nil, errors.New("out is nil")
	}

	pointsC := (*C.{{.CurveNameUpperCase}}_affine_t)(unsafe.Pointer(&points[0]))
	scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(unsafe.Pointer(&scalars[0]))
	outC := (*C.{{.CurveNameUpperCase}}_projective_t)(unsafe.Pointer(out))

	ret := C.msm_cuda_{{.CurveNameLowerCase}}(outC, pointsC, scalarsC, C.size_t(len(points)), C.size_t(device_id))
	if ret != 0 {
		return nil, fmt.Errorf("msm_cuda_{{.CurveNameLowerCase}} returned error code: %d", ret)
	}

	return out, nil
}
// MsmG2 runs a single multi-scalar multiplication over G2 by calling
// msm_g2_cuda_{{.CurveNameLowerCase}} and stores the result in out.
//
// points and scalars must be non-empty and of equal length; out must be
// non-nil. device_id is forwarded to the C routine as a size_t.
//
// On success the same out pointer is returned. On any validation failure or
// non-zero C return code, nil and a descriptive error are returned.
func MsmG2(out *G2Point, points []G2PointAffine, scalars []G1ScalarField, device_id int) (*G2Point, error) {
	if len(points) != len(scalars) {
		return nil, errors.New("error on: len(points) != len(scalars)")
	}
	// Taking &points[0] on an empty slice would panic, so reject it up front.
	if len(points) == 0 {
		return nil, errors.New("points or scalars is empty")
	}
	// A nil out would be handed to C as a NULL destination pointer.
	if out == nil {
		return nil, errors.New("out is nil")
	}

	pointsC := (*C.{{.CurveNameUpperCase}}_g2_affine_t)(unsafe.Pointer(&points[0]))
	scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(unsafe.Pointer(&scalars[0]))
	outC := (*C.{{.CurveNameUpperCase}}_g2_projective_t)(unsafe.Pointer(out))

	ret := C.msm_g2_cuda_{{.CurveNameLowerCase}}(outC, pointsC, scalarsC, C.size_t(len(points)), C.size_t(device_id))
	if ret != 0 {
		return nil, fmt.Errorf("msm_g2_cuda_{{.CurveNameLowerCase}} returned error code: %d", ret)
	}

	return out, nil
}
// MsmBatch runs batchSize G1 multi-scalar multiplications in one call to
// msm_batch_cuda_{{.CurveNameLowerCase}}.
//
// points and scalars must be non-nil pointers to non-empty slices of equal
// length, and that length must be an exact multiple of batchSize. The C
// routine is told that each MSM has len(*points)/batchSize elements.
// deviceId is forwarded to the C routine as a size_t.
//
// It returns a freshly allocated slice of batchSize results, or nil and an
// error if validation fails or the C routine returns a non-zero code.
func MsmBatch(points *[]G1PointAffine, scalars *[]G1ScalarField, batchSize, deviceId int) ([]G1ProjectivePoint, error) {
	// Check for nil pointers
	if points == nil || scalars == nil {
		return nil, errors.New("points or scalars is nil")
	}

	if len(*points) != len(*scalars) {
		return nil, errors.New("error on: len(points) != len(scalars)")
	}

	// Check for empty slices
	if len(*points) == 0 {
		return nil, errors.New("points or scalars is empty")
	}

	// Check for zero batchSize
	if batchSize <= 0 {
		return nil, errors.New("error on: batchSize must be greater than zero")
	}

	// Integer division below would otherwise silently drop trailing elements,
	// or produce a zero-size MSM when batchSize exceeds the input length.
	if len(*points)%batchSize != 0 {
		return nil, errors.New("error on: len(points) must be a multiple of batchSize")
	}

	out := make([]G1ProjectivePoint, batchSize)
	for i := range out {
		out[i].SetZero()
	}

	outC := (*C.{{.CurveNameUpperCase}}_projective_t)(unsafe.Pointer(&out[0]))
	pointsC := (*C.{{.CurveNameUpperCase}}_affine_t)(unsafe.Pointer(&(*points)[0]))
	scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(unsafe.Pointer(&(*scalars)[0]))
	msmSizeC := C.size_t(len(*points) / batchSize)
	deviceIdC := C.size_t(deviceId)
	batchSizeC := C.size_t(batchSize)

	ret := C.msm_batch_cuda_{{.CurveNameLowerCase}}(outC, pointsC, scalarsC, batchSizeC, msmSizeC, deviceIdC)
	if ret != 0 {
		return nil, fmt.Errorf("msm_batch_cuda_{{.CurveNameLowerCase}} returned error code: %d", ret)
	}

	return out, nil
}
// MsmG2Batch runs batchSize G2 multi-scalar multiplications in one call to
// msm_batch_g2_cuda_{{.CurveNameLowerCase}}.
//
// points and scalars must be non-nil pointers to non-empty slices of equal
// length, and that length must be an exact multiple of batchSize. The C
// routine is told that each MSM has len(*points)/batchSize elements.
// deviceId is forwarded to the C routine as a size_t.
//
// It returns a freshly allocated slice of batchSize results, or nil and an
// error if validation fails or the C routine returns a non-zero code.
func MsmG2Batch(points *[]G2PointAffine, scalars *[]G1ScalarField, batchSize, deviceId int) ([]G2Point, error) {
	// Check for nil pointers
	if points == nil || scalars == nil {
		return nil, errors.New("points or scalars is nil")
	}

	if len(*points) != len(*scalars) {
		return nil, errors.New("error on: len(points) != len(scalars)")
	}

	// Check for empty slices
	if len(*points) == 0 {
		return nil, errors.New("points or scalars is empty")
	}

	// Check for zero batchSize
	if batchSize <= 0 {
		return nil, errors.New("error on: batchSize must be greater than zero")
	}

	// Integer division below would otherwise silently drop trailing elements,
	// or produce a zero-size MSM when batchSize exceeds the input length.
	if len(*points)%batchSize != 0 {
		return nil, errors.New("error on: len(points) must be a multiple of batchSize")
	}

	out := make([]G2Point, batchSize)

	outC := (*C.{{.CurveNameUpperCase}}_g2_projective_t)(unsafe.Pointer(&out[0]))
	pointsC := (*C.{{.CurveNameUpperCase}}_g2_affine_t)(unsafe.Pointer(&(*points)[0]))
	scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(unsafe.Pointer(&(*scalars)[0]))
	msmSizeC := C.size_t(len(*points) / batchSize)
	deviceIdC := C.size_t(deviceId)
	batchSizeC := C.size_t(batchSize)

	ret := C.msm_batch_g2_cuda_{{.CurveNameLowerCase}}(outC, pointsC, scalarsC, batchSizeC, msmSizeC, deviceIdC)
	if ret != 0 {
		// Name the function that was actually called so failures are attributable.
		return nil, fmt.Errorf("msm_batch_g2_cuda_{{.CurveNameLowerCase}} returned error code: %d", ret)
	}

	return out, nil
}
func Commit(d_out, d_scalars, d_points unsafe.Pointer, count, bucketFactor int) int {
d_outC := (*C.{{.CurveNameUpperCase}}_projective_t)(d_out)
scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(d_scalars)
pointsC := (*C.{{.CurveNameUpperCase}}_affine_t)(d_points)
countC := (C.size_t)(count)
largeBucketFactorC := C.uint(bucketFactor)
ret := C.commit_cuda_{{.CurveNameLowerCase}}(d_outC, scalarsC, pointsC, countC, largeBucketFactorC, 0)
if ret != 0 {
return -1
}
return 0
}
func CommitG2(d_out, d_scalars, d_points unsafe.Pointer, count, bucketFactor int) int {
d_outC := (*C.{{.CurveNameUpperCase}}_g2_projective_t)(d_out)
scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(d_scalars)
pointsC := (*C.{{.CurveNameUpperCase}}_g2_affine_t)(d_points)
countC := (C.size_t)(count)
largeBucketFactorC := C.uint(bucketFactor)
ret := C.commit_g2_cuda_{{.CurveNameLowerCase}}(d_outC, scalarsC, pointsC, countC, largeBucketFactorC, 0)
if ret != 0 {
return -1
}
return 0
}
func CommitBatch(d_out, d_scalars, d_points unsafe.Pointer, count, batch_size int) int {
d_outC := (*C.{{.CurveNameUpperCase}}_projective_t)(d_out)
scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(d_scalars)
pointsC := (*C.{{.CurveNameUpperCase}}_affine_t)(d_points)
countC := (C.size_t)(count)
batch_sizeC := (C.size_t)(batch_size)
ret := C.commit_batch_cuda_{{.CurveNameLowerCase}}(d_outC, scalarsC, pointsC, countC, batch_sizeC, 0)
if ret != 0 {
return -1
}
return 0
}
func CommitG2Batch(d_out, d_scalars, d_points unsafe.Pointer, count, batch_size int) int {
d_outC := (*C.{{.CurveNameUpperCase}}_g2_projective_t)(d_out)
scalarsC := (*C.{{.CurveNameUpperCase}}_scalar_t)(d_scalars)
pointsC := (*C.{{.CurveNameUpperCase}}_g2_affine_t)(d_points)
countC := (C.size_t)(count)
batch_sizeC := (C.size_t)(batch_size)
ret := C.msm_batch_g2_cuda_{{.CurveNameLowerCase}}(d_outC, pointsC, scalarsC, countC, batch_sizeC, 0)
if ret != 0 {
return -1
}
return 0
}