MSM - supporting all window sizes (#534)

This PR enables using MSM with any value of c.

Note: default c isn't necessarily optimal, the user is expected to
choose c and the precomputation factor that give the best results for
the relevant case.

---------

Co-authored-by: Jeremy Felder <jeremy.felder1@gmail.com>
This commit is contained in:
HadarIngonyama
2024-06-17 15:57:24 +03:00
committed by GitHub
parent af9ec76506
commit 8936d9c800
46 changed files with 688 additions and 282 deletions

View File

@@ -116,13 +116,13 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi
return scalars.AsUnsafePointer(), points.AsUnsafePointer(), results.AsUnsafePointer(), size, unsafe.Pointer(cfg)
}
func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outputBases DeviceSlice) (unsafe.Pointer, unsafe.Pointer) {
func PrecomputePointsCheck(points HostOrDeviceSlice, cfg *MSMConfig, outputBases DeviceSlice) (unsafe.Pointer, unsafe.Pointer) {
outputBasesLength, pointsLength := outputBases.Len(), points.Len()
if outputBasesLength != pointsLength*int(precomputeFactor) {
if outputBasesLength != pointsLength*int(cfg.PrecomputeFactor) {
errorString := fmt.Sprintf(
"Precompute factor is probably incorrect: expected %d but got %d",
outputBasesLength/pointsLength,
precomputeFactor,
cfg.PrecomputeFactor,
)
panic(errorString)
}
@@ -131,5 +131,8 @@ func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outp
points.(DeviceSlice).CheckDevice()
}
cfg.pointsSize = int32(pointsLength)
cfg.arePointsOnDevice = points.IsOnDevice()
return points.AsUnsafePointer(), outputBases.AsUnsafePointer()
}