// Mirror of https://github.com/pseXperiments/icicle.git
// (synced 2026-01-07 22:53:56 -05:00; 69 lines, 1.8 KiB, Go).
package core
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
|
)
|
|
|
|
// NTTDir selects the direction of an NTT: forward or inverse.
type NTTDir int8

// Supported NTT directions. KForward has value 0 and KInverse has value 1.
const (
	// KForward selects the forward transform.
	KForward NTTDir = iota
	// KInverse selects the inverse transform.
	KInverse
)
|
|
|
|
// Ordering selects the element order of an NTT's inputs and outputs. Each
// constant is named K<input order><output order>.
//
// NOTE(review): the letters presumably stand for N = natural order,
// R = bit-reversed order and M = mixed order; confirm against the native
// icicle NTT documentation. The numeric values are presumably shared with the
// native library, so they must not be renumbered.
type Ordering uint32

// Supported orderings, numbered consecutively from 0 (KNN) to 5 (KMN).
const (
	KNN Ordering = iota
	KNR
	KRN
	KRR
	KNM
	KMN
)
|
|
|
|
type NTTConfig[T any] struct {
|
|
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
|
|
Ctx cuda_runtime.DeviceContext
|
|
/// Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()` (corresponding to no coset being used).
|
|
CosetGen T
|
|
/// The number of NTTs to compute. Default value: 1.
|
|
BatchSize int32
|
|
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
|
|
Ordering Ordering
|
|
areInputsOnDevice bool
|
|
areOutputsOnDevice bool
|
|
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
|
|
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
|
|
IsAsync bool
|
|
}
|
|
|
|
func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
|
|
ctx, _ := cuda_runtime.GetDefaultDeviceContext()
|
|
return NTTConfig[T]{
|
|
ctx, // Ctx
|
|
cosetGen, // CosetGen
|
|
1, // BatchSize
|
|
KNN, // Ordering
|
|
false, // areInputsOnDevice
|
|
false, // areOutputsOnDevice
|
|
false, // IsAsync
|
|
}
|
|
}
|
|
|
|
func NttCheck[T any](input HostOrDeviceSlice, cfg *NTTConfig[T], output HostOrDeviceSlice) {
|
|
inputLen, outputLen := input.Len(), output.Len()
|
|
if inputLen != outputLen {
|
|
errorString := fmt.Sprintf(
|
|
"input and output capacities %d; %d are not equal",
|
|
inputLen,
|
|
outputLen,
|
|
)
|
|
panic(errorString)
|
|
}
|
|
cfg.areInputsOnDevice = input.IsOnDevice()
|
|
cfg.areOutputsOnDevice = output.IsOnDevice()
|
|
}
|