mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use simple gidxs on GPU
This commit is contained in:
@@ -3,7 +3,6 @@ import unittest
|
||||
from dataclasses import replace
|
||||
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps, KernelOptError, Kernel, AxisType
|
||||
from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, KernelInfo
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -467,77 +466,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
end_range = [i for i, x in enumerate(uops) if x.op is Ops.ENDRANGE][0]
|
||||
assert end_range < uops.index(u)
|
||||
|
||||
def test_grouped_dims(self):
|
||||
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes, assert_same_length = True):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort() if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
|
||||
sizes = [x.arg[1] for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
if assert_same_length:
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
# TODO: add these back after uop symbolic
|
||||
# for i in range(len(dims)):
|
||||
# assert idxs[i].max+1 == dims[i], f"idxs[{i}] should have max {dims[i]-1}"
|
||||
# for i in range(len(loop_idxs)):
|
||||
# assert loop_idxs[i].expr.startswith(prefix), f"loop_idxs[{i}] must start with {prefix}"
|
||||
# assert loop_idxs[i].max+1 == sizes[i], f"loop_idxs[{i}] should have max {sizes[i]-1}"
|
||||
|
||||
# no-op
|
||||
_assert_grouped_dims("gidx", (2,), (16,16,16), False, [2])
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), False, [2,3])
|
||||
|
||||
# check reverse dims
|
||||
_assert_grouped_dims("gidx", (2,3), (16,16,16), True, [3,2])
|
||||
_assert_grouped_dims("gidx", (2,3,4), (16,16,16), False, [2,3,4])
|
||||
|
||||
# test splitting globals: len(dims) == len(max)
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), False, [16,12,4])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,4,16), False, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (64,3,4), (16,16,16), True, [16,3,16])
|
||||
_assert_grouped_dims("gidx", (128,3,4), (16,4,256), False, [16,3,32])
|
||||
_assert_grouped_dims("gidx", (4,4,512), (16,4,256), False, [8,4,256])
|
||||
|
||||
# prefer group_dim strategy when possible
|
||||
_assert_grouped_dims("gidx", (512,4,2), (8192,2,2), False, [2048,2])
|
||||
|
||||
# test splitting globals: len(dims) < len(max)
|
||||
# len(dim) -> len(limited)
|
||||
# 1 -> 2
|
||||
_assert_grouped_dims("gidx", (128,), (16,16,256), False, [16,8], False)
|
||||
# 1 -> 3
|
||||
_assert_grouped_dims("gidx", (65536,), (16,16,256), False, [16,16,256], False)
|
||||
# 2 -> 3
|
||||
_assert_grouped_dims("gidx", (128,128), (16,16,256), False, [16,16,64], False)
|
||||
# test when the only divisor is the square root of dim
|
||||
_assert_grouped_dims("gidx", (121,), (12,12,12), False, [11,11], False)
|
||||
|
||||
# collapse on onto the left most axis
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), False, [6,4,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (32,16,16), True, [20,3,2])
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (32,16,16), True, [20,3,Variable("start_pos",1,2)])
|
||||
|
||||
# collapse on left-most available axis (the left most is too small)
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (4,16,16), False, [2,12,5])
|
||||
_assert_grouped_dims("gidx", (2,3,4,5), (16,16,16), True, [5,12,2])
|
||||
|
||||
# _assert_grouped_dims("gidx", (Variable("start_pos",1,2),3,4,5), (16,16,16), False, [Variable("start_pos",1,2)*3,4,5])
|
||||
|
||||
# dim too large and not factorable
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (23,), (16,16,16), False,)
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (128,3,4), (16,2,2), False,)
|
||||
|
||||
# too large for sizes
|
||||
with self.assertRaises(RuntimeError):
|
||||
get_grouped_dims("gidx", (2,3,4,5,6), (16,16,16))
|
||||
|
||||
# # variable too large
|
||||
# with self.assertRaises(AssertionError):
|
||||
# get_grouped_dims("gidx", (Variable("start_pos",0,16),3,4), (16,16,16), False,)
|
||||
|
||||
@unittest.skip("only one global now")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
def test_default_global_reversed(self):
|
||||
# shrink so that the dims do not collapse
|
||||
|
||||
Reference in New Issue
Block a user