mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
@@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
|
||||
::: tinygrad.Tensor.hardsigmoid
|
||||
::: tinygrad.Tensor.elu
|
||||
::: tinygrad.Tensor.celu
|
||||
::: tinygrad.Tensor.selu
|
||||
::: tinygrad.Tensor.swish
|
||||
::: tinygrad.Tensor.silu
|
||||
::: tinygrad.Tensor.relu6
|
||||
|
||||
@@ -189,7 +189,7 @@ class StableDiffusion:
|
||||
# make image correct size and scale
|
||||
x = (x + 1.0) / 2.0
|
||||
x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
|
||||
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
|
||||
return x.cast(dtypes.uint8)
|
||||
|
||||
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
|
||||
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
|
||||
@@ -280,7 +280,7 @@ if __name__ == "__main__":
|
||||
print(x.shape)
|
||||
|
||||
# save image
|
||||
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
|
||||
im = Image.fromarray(x.numpy())
|
||||
print(f"saving {args.out}")
|
||||
im.save(args.out)
|
||||
# Open image.
|
||||
|
||||
@@ -8,7 +8,7 @@ import numpy as np
|
||||
|
||||
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
|
||||
"Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
|
||||
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"}
|
||||
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
|
||||
|
||||
# **************** Free Ops ****************
|
||||
|
||||
@@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
|
||||
|
||||
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
|
||||
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
|
||||
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
|
||||
def PRelu(X:Tensor, slope:Tensor):
|
||||
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
|
||||
return (X > 0).where(X, X * slope)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
indent-width = 2
|
||||
preview = true
|
||||
target-version = "py38"
|
||||
target-version = "py310"
|
||||
|
||||
lint.select = [
|
||||
"F", # Pyflakes
|
||||
|
||||
2
setup.py
2
setup.py
@@ -59,7 +59,7 @@ setup(name='tinygrad',
|
||||
"bottle",
|
||||
"ggml-python"
|
||||
],
|
||||
'webgpu': ["wgpu>=v0.19.0"],
|
||||
'webgpu': ["wgpu==v0.18.1"],
|
||||
'docs': [
|
||||
"mkdocs",
|
||||
"mkdocs-material",
|
||||
|
||||
46
test/external/external_debug_metal_sd_conv.py
vendored
Normal file
46
test/external/external_debug_metal_sd_conv.py
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
# ruff: noqa: E501
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.engine.search import bufs_from_lin
|
||||
from tinygrad.helpers import Timing
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(81920, 0, 64, 8, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 2560, 4, 10, 4, 10), strides=(0, 163840, 0, 64, 0, 8, 0, 1), offset=-9, mask=((0, 1), (0, 2), (0, 1), (0, 2560), (0, 4), (1, 9), (0, 4), (1, 9)), contiguous=False), View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(4096000, 0, 0, 40, 1, 1600, 440, 11), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(0, 0, 23040, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
|
||||
x17:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()),
|
||||
x17,)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)]
|
||||
|
||||
k = Kernel(ast)
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
bufs = bufs_from_lin(k)
|
||||
|
||||
prg = CompiledRunner(k.to_program())
|
||||
|
||||
with Timing("run "):
|
||||
prg(bufs, var_vals={}, wait=True)
|
||||
|
||||
# on M1 Max
|
||||
# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229
|
||||
# 15ms after block cbcc1c20eb09a1342f6581cfbb99632bade982a8
|
||||
8
test/external/speed_v_theoretical.py
vendored
8
test/external/speed_v_theoretical.py
vendored
@@ -88,13 +88,13 @@ class TestKernelSpeed(unittest.TestCase):
|
||||
# def test_gemm_1024(self): self._test_matmul(1024, nv_tflops=8, amd_tflops=7)
|
||||
# def test_gemm_2048(self): self._test_matmul(2048, nv_tflops=50, amd_tflops=30)
|
||||
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=130, amd_tflops=70)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
|
||||
|
||||
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400)
|
||||
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=400)
|
||||
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
|
||||
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
|
||||
|
||||
# TODO: tiny7 is slower than tiny12
|
||||
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -17,7 +17,7 @@ class TestConvShapetracker(unittest.TestCase):
|
||||
# run it again to get the kernels
|
||||
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
|
||||
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
|
||||
for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]:
|
||||
for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]:
|
||||
assert len(st.views) == 1
|
||||
|
||||
def test_conv_2x2_backward_one_view(self):
|
||||
@@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase):
|
||||
conv(X).mean().backward()
|
||||
si = X.grad.schedule()[-1]
|
||||
print(si)
|
||||
ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0]
|
||||
ldb = [x for x in si.ast.toposort if x.op is Ops.LOAD][0]
|
||||
st: ShapeTracker = ldb.st_arg.simplify()
|
||||
# NOTE: st.real_size() is broken
|
||||
print(si.inputs[0].size)
|
||||
|
||||
@@ -120,7 +120,9 @@ class TestDType(unittest.TestCase):
|
||||
data = [1., 2., 0., 0.5, -1.5, 5.25]
|
||||
for dt in dtypes:
|
||||
arr = np.asarray(data).astype(dt)
|
||||
tin = Tensor(arr).numpy()
|
||||
tensor = Tensor(arr)
|
||||
if not is_dtype_supported(tensor.dtype): continue
|
||||
tin = tensor.numpy()
|
||||
tor = torch.as_tensor(arr).detach().numpy()
|
||||
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
||||
|
||||
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
|
||||
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
|
||||
else: np.testing.assert_equal(tensor_value, numpy_value)
|
||||
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
|
||||
op = [x for x in ast.parents if x.op in GroupOp.Unary][0]
|
||||
op = [x for x in ast.toposort if x.op in GroupOp.Unary][0]
|
||||
assert op.dtype == dtype
|
||||
|
||||
def universal_test_cast(a, in_dtype, dtype):
|
||||
|
||||
@@ -96,10 +96,18 @@ class TestLinearizer(unittest.TestCase):
|
||||
lin = helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
|
||||
|
||||
stores = [u for u in lin.uops if u.op is Ops.STORE]
|
||||
mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
|
||||
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
|
||||
assert len(mutable_bufs) == len(stores) == 2
|
||||
assert [u.arg for u in mutable_bufs] == [0, 1]
|
||||
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_ACC])
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
|
||||
for i,u in enumerate(ranges):
|
||||
if skip and i in skip: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
@@ -130,11 +138,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i == 0: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -194,11 +198,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i == 0: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0])
|
||||
|
||||
def test_triple_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
@@ -218,11 +218,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
sink = UOp(Ops.SINK, src=(store,))
|
||||
wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
|
||||
lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i == 0: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -270,11 +266,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
|
||||
]
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i < 2: continue
|
||||
assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0, 1])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -301,11 +293,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
|
||||
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i == 0: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@@ -339,11 +327,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
]
|
||||
wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
|
||||
lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
|
||||
for l in lins:
|
||||
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
|
||||
for i,u in enumerate(ranges):
|
||||
if i == 0: continue
|
||||
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
|
||||
self._test_no_nested_ranges(lins, [0])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
def test_multiout_multireduce(self):
|
||||
@@ -988,10 +972,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
# the first store is to lds and can be upcasted
|
||||
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
|
||||
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].sparents)
|
||||
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort)
|
||||
# the second store is to gds with no upcasts
|
||||
assert stores[1].src[-1].dtype == dtypes.float
|
||||
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].sparents)
|
||||
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort)
|
||||
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
@@ -1155,7 +1139,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_grouped_dims(self):
|
||||
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
|
||||
loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = dedup(flatten([[y for y in x.toposort if y.op is Ops.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
|
||||
sizes = [x.arg[1] for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
print(prg.src)
|
||||
if_uops = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertIn(len(if_uops), {1,2,3})
|
||||
conditions = if_uops[0].src[0].sparents
|
||||
conditions = if_uops[0].src[0].toposort
|
||||
self.assertLessEqual(len(conditions), 9)
|
||||
|
||||
# this was a bug in embedding, someday we should fold this anyway
|
||||
|
||||
@@ -1045,7 +1045,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
ifs = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertEqual(len(ifs), 3)
|
||||
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
|
||||
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)
|
||||
self.assertLessEqual(len(ifs[0].src[0].toposort), 17)
|
||||
|
||||
def test_failure_45(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
|
||||
@@ -697,6 +697,9 @@ class TestOps(unittest.TestCase):
|
||||
for val in range(1, 5):
|
||||
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
|
||||
def test_selu(self):
|
||||
helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu)
|
||||
helper_test_op([()], torch.nn.functional.selu, Tensor.selu)
|
||||
|
||||
def test_abs(self):
|
||||
helper_test_op([(45,65)], torch.abs, Tensor.abs)
|
||||
|
||||
@@ -199,7 +199,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r1 = (x - r0).sum(axis=0).div(2)
|
||||
out = r0 + r1
|
||||
schedule = check_schedule(out, 2)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_cache_reduce_multiple_children(self):
|
||||
@@ -210,7 +210,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out0 = r0 + y
|
||||
out1 = r1 + y
|
||||
schedule = check_schedule([out0, out1], 4)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
@@ -1673,7 +1673,7 @@ class TestIndexing(unittest.TestCase):
|
||||
@track_rewrites(named=True)
|
||||
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
|
||||
|
||||
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.sparents if x.op is Ops.VIEW and len(x.src) != 0])
|
||||
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
|
||||
|
||||
class TestSwizzle(unittest.TestCase):
|
||||
def test_swizzle_simple(self):
|
||||
@@ -1755,7 +1755,7 @@ class TestSwizzle(unittest.TestCase):
|
||||
# EXPAND is rewritten
|
||||
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
|
||||
# and pushed to the LOAD
|
||||
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
|
||||
new_load_st = unwrap([x for x in ret.toposort if x.op is Ops.VIEW][0].st)
|
||||
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
||||
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
def test_reasonable_time(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
|
||||
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
|
||||
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD}
|
||||
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD}
|
||||
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
|
||||
tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
|
||||
assert tm > 0 and tm != float('inf')
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# basic self-contained tests of the external functionality of tinygrad
|
||||
import unittest
|
||||
import unittest, random
|
||||
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
|
||||
from tinygrad.helpers import IMAGE
|
||||
|
||||
@@ -41,13 +41,22 @@ class TestTiny(unittest.TestCase):
|
||||
|
||||
def test_jit(self):
|
||||
cnt = 0
|
||||
random.seed(0)
|
||||
def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)]
|
||||
|
||||
@TinyJit
|
||||
def fxn(a,b):
|
||||
def fxn(a,b) -> Tensor:
|
||||
nonlocal cnt
|
||||
cnt += 1
|
||||
return a+b
|
||||
fa,fb = Tensor([1.,2,3]), Tensor([4.,5,6])
|
||||
for _ in range(3): fxn(fa, fb)
|
||||
|
||||
for _ in range(3):
|
||||
la,lb = new_rand_list(), new_rand_list()
|
||||
fa,fb = Tensor(la), Tensor(lb)
|
||||
ret = fxn(fa, fb)
|
||||
# math is correct
|
||||
self.assertListEqual(ret.tolist(), [a+b for a,b in zip(la, lb)])
|
||||
|
||||
# function is only called twice
|
||||
self.assertEqual(cnt, 2)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||
new_sink = full_graph_rewrite(lower_sink)
|
||||
et = time.perf_counter() - st
|
||||
UOp.__init__ = old_init
|
||||
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.sparents)} -> {len(new_sink.sparents)}, creating {cnt[0]} uops")
|
||||
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort)} -> {len(new_sink.toposort)}, creating {cnt[0]} uops")
|
||||
|
||||
class TestGraphRewriteConst(unittest.TestCase):
|
||||
def test_gep_const(self):
|
||||
@@ -106,7 +106,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
sink = a1.sink(a2)
|
||||
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR]
|
||||
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort if x.op is Ops.DEFINE_VAR]
|
||||
self.assertEqual(len(define_vars), 1)
|
||||
|
||||
def test_simple(self):
|
||||
@@ -187,7 +187,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
print(sink.render())
|
||||
self.assertEqual(sink.op, Ops.ADD)
|
||||
self.assertEqual(sink.src[1].op, Ops.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
|
||||
self.assertEqual(len([x for x in sink.toposort if x.op is Ops.CONST]), 1)
|
||||
|
||||
class TestUOpGraph(unittest.TestCase):
|
||||
def test_add_constant_fold(self):
|
||||
@@ -433,8 +433,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||
c0 = UOp.const(dtypes.int, 0)
|
||||
c2 = UOp.const(dtypes.int, 2)
|
||||
cf = UOp.const(dtypes.float, 0.0)
|
||||
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
|
||||
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
|
||||
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 0)
|
||||
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 1)
|
||||
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
|
||||
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
|
||||
uops = to_uops_list([store])
|
||||
@@ -600,14 +600,14 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
|
||||
|
||||
def test_two_load_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 2
|
||||
|
||||
def test_simple_load_fold_gated(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
@@ -615,8 +615,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0]
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
|
||||
single_load = [x for x in sink.toposort if x.op is Ops.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
@@ -627,14 +627,14 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 3
|
||||
|
||||
def test_simple_store_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
@@ -642,10 +642,11 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
|
||||
one_store = [x for x in sink.sparents if x.op is Ops.STORE][0]
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
|
||||
one_store = [x for x in sink.toposort if x.op is Ops.STORE][0]
|
||||
assert len(one_store.src) == 3
|
||||
assert str(one_store.src[2]) == str(gate) # huh, why do i need str here?
|
||||
_if_node = one_store.src[2]
|
||||
assert _if_node.op == Ops.IF and _if_node.src[0] == gate
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
@@ -655,7 +656,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
UOp.const(dtypes.float, i))) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 3
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
@@ -671,7 +672,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
@@ -689,7 +690,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
@@ -705,7 +706,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
|
||||
@@ -247,58 +247,59 @@ class TestConstantFolding(unittest.TestCase):
|
||||
assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
|
||||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate))
|
||||
store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
|
||||
uops = to_uops_list([store])
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
endif = next(u for u in uops if u.op is Ops.ENDIF)
|
||||
assert endif.src[0] is if_uop
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
|
||||
gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)])
|
||||
self.assertEqual(len(gated_uops), 1)
|
||||
self.assertIs(gated_uops[-1].op, Ops.STORE)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
|
||||
uops = linearize_uop(stores)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
endif = next(u for u in uops if u.op is Ops.ENDIF)
|
||||
assert endif.src[0] is if_uop
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
|
||||
gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)])
|
||||
self.assertEqual(len(gated_uops), 1)
|
||||
self.assertIs(gated_uops[-1].op, Ops.STORE)
|
||||
|
||||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||
@unittest.expectedFailure
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
|
||||
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
|
||||
uops = linearize_uop(stores)
|
||||
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
|
||||
uops = to_uops_list(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
ifs = [u for u in uops if u.op is Ops.IF]
|
||||
endifs = [u for u in uops if u.op is Ops.ENDIF]
|
||||
self.assertEqual(len(ifs), 1)
|
||||
self.assertEqual(len(endifs), 1)
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
|
||||
gated_uops = tuple(uops[uops.index(ifs[0])+1:uops.index(endifs[0])])
|
||||
self.assertEqual(len(gated_uops), 2)
|
||||
for x in gated_uops: self.assertIs(x.op, Ops.STORE)
|
||||
|
||||
@@ -314,6 +315,16 @@ class TestLocalAccess(unittest.TestCase):
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
|
||||
|
||||
# NOTE: webgpu specific, since only webgpu performs bitpacking for uchar
|
||||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
|
||||
def test_local_packed(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16))
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
|
||||
@@ -24,7 +24,7 @@ class TestWinograd(unittest.TestCase):
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
ops = s.ast.parents
|
||||
ops = s.ast.toposort
|
||||
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
|
||||
l = Kernel(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
|
||||
@@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
|
||||
|
||||
def test_where_combine(self):
|
||||
cond = Variable("x", 0, 3).lt(2)
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
aa = cond.where(a, a.ufix(0))
|
||||
bb = cond.where(b, b.ufix(1))
|
||||
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
|
||||
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
|
||||
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
|
||||
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
|
||||
|
||||
# not combining because it increased total ALU
|
||||
c = Variable("c", 0, 3)
|
||||
cc = cond.where(c, c+1)
|
||||
self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
|
||||
|
||||
# not combining # TODO: can combine if it can further simplify?
|
||||
ab = cond.where(a, b)
|
||||
ba = cond.where(b, a)
|
||||
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
|
||||
|
||||
# not combining # TODO: can combine if one is identity element const
|
||||
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
MIN, MAX = 0, 10
|
||||
|
||||
@@ -68,10 +68,11 @@ class Kernel:
|
||||
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
|
||||
|
||||
self.vars: List[Variable] = self.ast.variables()
|
||||
self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]
|
||||
# NOTE: this requires a specific order with the [::-1], this is likely a bug
|
||||
self.bufs: List[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
|
||||
|
||||
# get earlybufs, before any reduceops
|
||||
earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
|
||||
earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
|
||||
self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
|
||||
# NOTE: full_shape can be wrong if there's a tree of reduces
|
||||
|
||||
@@ -597,7 +598,7 @@ class Kernel:
|
||||
@functools.cached_property
|
||||
def name(self) -> str:
|
||||
# kernel name (before late upcast)
|
||||
kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
|
||||
kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
|
||||
suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
|
||||
name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
|
||||
|
||||
@@ -712,7 +713,7 @@ class Kernel:
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].arg)))
|
||||
return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
@@ -1,92 +1,168 @@
|
||||
from typing import List, Set, Dict, Tuple
|
||||
import functools, heapq
|
||||
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import DEBUG
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import dedup, flatten, partition
|
||||
|
||||
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
|
||||
if u in children: return srcs[u]
|
||||
srcs[u] = {}
|
||||
children[u] = []
|
||||
for x in u.src:
|
||||
srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
|
||||
if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
|
||||
children[x].append(u)
|
||||
in_degree[u] = len(u.src)
|
||||
return srcs[u]
|
||||
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST,
|
||||
Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
|
||||
|
||||
def disp(y:UOp) -> str:
|
||||
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
|
||||
if y.op is Ops.IF: return f'IF{id(y)}'
|
||||
if y.op is Ops.RANGE: return str(y.arg)
|
||||
return "<NONE>"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BasicBlock:
|
||||
ctx: Tuple[UOp, ...]
|
||||
lst: Tuple[UOp, ...]
|
||||
end: Optional[UOp] = None
|
||||
def __repr__(self):
|
||||
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
|
||||
f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
|
||||
|
||||
def append_to_block(ctx, x:UOp):
|
||||
block_ctxs, children = ctx
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
|
||||
in_this_block = set(x.arg.lst)
|
||||
for u in x.src:
|
||||
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
|
||||
# if it's a fork or not placed, we don't place it
|
||||
new_srcs.append(u)
|
||||
elif (block_ctx:=block_ctxs[u]) == x.arg.ctx:
|
||||
# if it's the same context, we place the UOp in this block and append the parents to it's srcs
|
||||
new_srcs += list(u.src)
|
||||
to_append.append(u)
|
||||
else:
|
||||
# otherwise, we create a new block with this UOp
|
||||
new_blocks.setdefault(block_ctx, []).append(u)
|
||||
if len(to_append) == 0 and len(new_blocks) == 0: return None
|
||||
|
||||
for rng,lst in new_blocks.items():
|
||||
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
|
||||
lrng = list(rng)
|
||||
for r in rng[::-1]:
|
||||
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
|
||||
lrng.remove(r)
|
||||
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
|
||||
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
|
||||
new_srcs.append(new_block)
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
|
||||
|
||||
make_basic_blocks = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
|
||||
(UPat(Ops.BLOCK, name="x"), append_to_block),
|
||||
])
|
||||
|
||||
def block_merge(ctx, x:UOp):
|
||||
# ctx is children here
|
||||
if x.op is Ops.BLOCKEND:
|
||||
# if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
|
||||
in_this_block = set(x.arg.lst)
|
||||
if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
|
||||
# find the parent block that has the BLOCKSTART in the ctx
|
||||
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
|
||||
if len(parent_blocks) == 1:
|
||||
parent_block = parent_blocks[0]
|
||||
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
|
||||
early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
|
||||
return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
|
||||
BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
|
||||
assert not len(parent_blocks)
|
||||
|
||||
new_srcs: List[UOp] = []
|
||||
to_append: List[UOp] = []
|
||||
new_ctx = x.arg.ctx
|
||||
placed = set()
|
||||
for u in x.src:
|
||||
if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
|
||||
# NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
|
||||
new_ctx += u.arg.ctx
|
||||
new_srcs += list(u.src)
|
||||
to_append += u.arg.lst
|
||||
elif u.op is Ops.BLOCKFORK and len([y for y in x.src if y is u]) == u.arg: # block fork appears # of times in srcs
|
||||
if u not in placed:
|
||||
new_srcs += list(u.src)
|
||||
placed.add(u)
|
||||
else:
|
||||
# keep it in srcs
|
||||
new_srcs.append(u)
|
||||
if len(to_append) == 0 and len(placed) == 0: return None
|
||||
return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(dedup(new_ctx)), tuple(to_append)+x.arg.lst, x.arg.end))
|
||||
|
||||
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
|
||||
|
||||
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
|
||||
# filter nodes that don't link to a sink
|
||||
# BFS toposort
|
||||
|
||||
# get children and all block contexts
|
||||
temp_block_ctxs: Dict[UOp, List[UOp]] = {}
|
||||
children: Dict[UOp, List[UOp]] = {}
|
||||
range_srcs: Dict[UOp, Dict[UOp, None]] = {}
|
||||
in_degree: Dict[UOp, int] = {}
|
||||
get_children_dfs(sink, children, range_srcs, in_degree)
|
||||
for u in sink.toposort:
|
||||
this_block_ctx: List[UOp] = []
|
||||
for s in u.src:
|
||||
# save children
|
||||
children.setdefault(s, []).append(u)
|
||||
# compute block ctx
|
||||
if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
|
||||
# don't flow (fully) through assign and store
|
||||
elif s.op is Ops.STORE:
|
||||
# ugh, deal with non-reduce locals. probably wrong
|
||||
if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
|
||||
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
|
||||
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
|
||||
elif s.op is Ops.ASSIGN:
|
||||
# flow though assign, but remove the ranges used in the assign
|
||||
assert s.src[0].op is Ops.DEFINE_ACC
|
||||
this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
|
||||
else:
|
||||
# flow though everything else
|
||||
this_block_ctx += temp_block_ctxs[s]
|
||||
temp_block_ctxs[u] = dedup(sorted(this_block_ctx, key=lambda x: x.tuplize))
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
|
||||
if x.op is Ops.SINK: return set()
|
||||
return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
|
||||
# make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
|
||||
block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {}
|
||||
for u in sink.toposort:
|
||||
block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
|
||||
|
||||
# scope children impact the toposort and END* insertion
|
||||
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
|
||||
range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
|
||||
# TODO: there's probably a clever way to remove this while loop
|
||||
while 1:
|
||||
sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
|
||||
|
||||
# assign priorities
|
||||
def get_priority(u:UOp):
|
||||
priority = 0
|
||||
# prefer ranges that depend on the least number of independent ranges
|
||||
if u.op is Ops.RANGE and u.arg[1]:
|
||||
priority += u.arg[0]
|
||||
for p in range_phi[u]:
|
||||
priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
|
||||
elif u.op is Ops.CONST:
|
||||
# place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
|
||||
priority -= 100000000000
|
||||
else:
|
||||
# prefer uops that are loop children
|
||||
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
|
||||
if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
|
||||
return priority
|
||||
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
|
||||
# add BLOCKFORK (slow!)
|
||||
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
|
||||
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
|
||||
forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
|
||||
for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
|
||||
|
||||
# prevent priority inversion
|
||||
@functools.lru_cache(None)
|
||||
def fix_priority(u:UOp, lowest_priority):
|
||||
if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
|
||||
priorities[u] = min(priorities[u], lowest_priority)
|
||||
if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
|
||||
for x in u.src: fix_priority(x, priorities[u])
|
||||
fix_priority(sink, 0)
|
||||
if not len(forks): break
|
||||
sink = sink.substitute(forks)
|
||||
|
||||
# NOTE: the compare should never make it all the way to u
|
||||
queue:List[Tuple[int, Tuple, UOp]] = []
|
||||
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
|
||||
# combine matching BLOCKENDS
|
||||
blockends_to_arg: Dict[UOp, List[UOp]] = {}
|
||||
for be in sink.toposort:
|
||||
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
|
||||
new_forks = {}
|
||||
for k,v in blockends_to_arg.items():
|
||||
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
|
||||
if len(v) > 1:
|
||||
out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
|
||||
arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
|
||||
for u in v: new_forks[u] = out
|
||||
sink = sink.substitute(new_forks)
|
||||
|
||||
for u in children:
|
||||
if in_degree[u] == 0: push(u)
|
||||
# final rewrite to merge all blocks into one
|
||||
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
||||
|
||||
scope_end: Dict[UOp, UOp] = {}
|
||||
_uops: List[UOp] = []
|
||||
while queue:
|
||||
p,_,x = heapq.heappop(queue)
|
||||
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
|
||||
if x in scope_children: scope_end[x] = x
|
||||
if x.op is Ops.DEFINE_ACC:
|
||||
idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
|
||||
_uops.insert(idx, x)
|
||||
else: _uops.append(x)
|
||||
for u, ss in scope_children.items():
|
||||
if x in ss:
|
||||
ss.remove(x)
|
||||
if len(ss) == 0: scope_end[u] = x
|
||||
for u in children[x]:
|
||||
in_degree[u] -= 1
|
||||
if in_degree[u] == 0: push(u)
|
||||
|
||||
# end scopes in toposort order
|
||||
for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
|
||||
# there should just be one block left, with a few parents with 0 srcs
|
||||
assert sink.op is Ops.BLOCK
|
||||
_uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
|
||||
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
|
||||
_uops += sink.arg.lst
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check: type_verify(_uops)
|
||||
|
||||
@@ -55,8 +55,8 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
full_shape = ast.full_shape
|
||||
first_upcasted = len(full_shape)-ki.upcasted
|
||||
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
|
||||
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
|
||||
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
|
||||
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
|
||||
local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
|
||||
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
|
||||
group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
@@ -71,10 +71,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
else:
|
||||
# all loops are RANGES
|
||||
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])]
|
||||
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
|
||||
|
||||
# reduce loops
|
||||
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True))
|
||||
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
|
||||
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
|
||||
|
||||
# upcast loops
|
||||
@@ -85,7 +85,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
# late indexes (group for reduce)
|
||||
ridxs = idxs[:]
|
||||
for a in range(first_reduce, first_reduce+group_for_reduces):
|
||||
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True))
|
||||
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
|
||||
|
||||
return IndexContext(idxs, ridxs)
|
||||
|
||||
|
||||
@@ -221,7 +221,7 @@ def no_vectorized_wmma(wmma:UOp):
|
||||
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
|
||||
|
||||
def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
|
||||
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
|
||||
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
|
||||
if len(reduce_unparented) == 0: return None
|
||||
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
|
||||
ret = new_acc.assign(new_acc.alu(alu.op, ret))
|
||||
@@ -447,7 +447,7 @@ devectorize = PatternMatcher([
|
||||
])
|
||||
|
||||
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
|
||||
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None
|
||||
if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
|
||||
# remove the gate from the index
|
||||
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
|
||||
|
||||
@@ -483,6 +483,9 @@ pm_render = PatternMatcher([
|
||||
# move masks of loads/stores
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
|
||||
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
|
||||
])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
@@ -20,7 +20,7 @@ class _Device:
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
cpn = multiprocessing.current_process().name
|
||||
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
|
||||
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
|
||||
x = ix.split(":")[0].upper()
|
||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
|
||||
if (cname.lower() == x.lower() + "device")][0](ix)
|
||||
|
||||
@@ -54,28 +54,28 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
|
||||
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
|
||||
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
|
||||
# hack the underlying buffer too
|
||||
buf.dtype = buf.buffer.dtype = buf.dtype.base
|
||||
assert not buf.is_realized, "can't fixup allocated buffer"
|
||||
buf.buffer.options = None
|
||||
assert buf.op is not None, f"base must be base itself {buf}"
|
||||
dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
|
||||
# make things that can't be images not images
|
||||
dtype = buf.dtype
|
||||
if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
|
||||
assert buf.realized is None, "can't fixup allocated buffer"
|
||||
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
|
||||
dtype = buf.dtype.base
|
||||
# hack the underlying buffer too
|
||||
buf.buffer.dtype = buf.dtype = dtype
|
||||
buf.buffer.options = None
|
||||
if buf.is_realized:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = None
|
||||
elif buf.op is Ops.ASSIGN:
|
||||
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
|
||||
ctx.assigns.add(ubuf:=target.buf_uop)
|
||||
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
|
||||
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
|
||||
else:
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
|
||||
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
buffers[ubuf] = buf.buffer
|
||||
op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
|
||||
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
|
||||
if op is not None:
|
||||
ctx.lazybufs[ubuf] = buf
|
||||
@@ -126,7 +126,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
|
||||
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
|
||||
|
||||
# push VIEW to stores
|
||||
|
||||
@@ -41,9 +41,7 @@ def fully_flatten(l):
|
||||
return [l]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
||||
def ceildiv(num, amt):
|
||||
ret = -(num//-amt)
|
||||
return ret if not isinstance(ret, float) else int(ret)
|
||||
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
|
||||
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
|
||||
def data64(data:Any) -> Tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
|
||||
def data64_le(data:Any) -> Tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
|
||||
@@ -52,10 +50,9 @@ def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
|
||||
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||
return {k:v for d in ds for k,v in d.items()}
|
||||
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
|
||||
a:List[T] = []
|
||||
b:List[T] = []
|
||||
for s in itr: (a if fxn(s) else b).append(s)
|
||||
return a,b
|
||||
ret:Tuple[List[T], List[T]] = ([], [])
|
||||
for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
|
||||
return ret
|
||||
def unwrap(x:Optional[T]) -> T:
|
||||
assert x is not None
|
||||
return x
|
||||
@@ -268,7 +265,8 @@ def from_mv(mv:memoryview, to_type=ctypes.c_char):
|
||||
return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
|
||||
def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
|
||||
def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
|
||||
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
|
||||
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char):
|
||||
return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
||||
class CStruct(ctypes.Structure):
|
||||
|
||||
@@ -248,10 +248,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
|
||||
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@functools.cached_property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
def toposort(self) -> Dict[UOp, None]:
|
||||
nodes: Dict[UOp, None] = {}
|
||||
# NOTE: this is a lot faster than the comprehension in parents
|
||||
for parent in self.src: nodes.update(parent.toposort)
|
||||
nodes[self] = None
|
||||
return nodes
|
||||
|
||||
@functools.cached_property
|
||||
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
|
||||
@@ -273,7 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
@property
def shape(self) -> Tuple[sint, ...]:
    """The `shape` of `self.st`, after `self.st` has been passed through `unwrap`."""
    tracker = unwrap(self.st)
    return tracker.shape
|
||||
@property
|
||||
@@ -334,8 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(arg, out_dtype, (self,)+src)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike):
|
||||
@@ -378,14 +381,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
buffer_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
|
||||
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
|
||||
@functools.cached_property
|
||||
def device(self) -> str:
|
||||
match self.op:
|
||||
case Ops.COPY: return self.arg
|
||||
case Ops.BUFFER: return self.arg[1][0]
|
||||
case _: return self.src[0].device
|
||||
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
@@ -412,12 +412,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
def val(self) -> int:
    """The element at index 1 of whatever `self.unbind()` returns."""
    unbound = self.unbind()
    return unbound[1]
|
||||
def vars(self) -> Set[UOp]:
|
||||
bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
|
||||
bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
|
||||
bound_var_base = set(x.src[0] for x in bound_vars)
|
||||
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
|
||||
all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
|
||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
|
||||
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
@@ -474,7 +474,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def _sym_fxn(self):
|
||||
sself = self.simplify()
|
||||
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
|
||||
varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
|
||||
|
||||
@@ -532,10 +532,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
if ignore_indexing:
|
||||
for u in uops:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
|
||||
dont_count = dont_count.union(u.src[0].toposort)
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
dont_count = dont_count.union(u.src[0].toposort)
|
||||
for u in uops:
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
@@ -622,7 +622,7 @@ class UPat(MathTrait):
|
||||
def const_like(self, b:ConstLike):
    """Build a pattern with `UPat.const` from this pattern's dtype and the value `b`."""
    pattern_dtype = self.dtype
    return UPat.const(pattern_dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
|
||||
def printable(self:UPat) -> str:
|
||||
try: return lines(self.location[0])[self.location[1]-1].strip()
|
||||
@@ -802,7 +802,7 @@ spec = PatternMatcher([
|
||||
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
|
||||
(UPat(Ops.SPECIAL, src=()), lambda: True),
|
||||
|
||||
# TODO: confirm the args of both of these are shapetrackers
|
||||
@@ -1046,7 +1046,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
|
||||
# try checking the whole clause
|
||||
if expr in uop.sparents:
|
||||
if expr in uop.toposort:
|
||||
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
|
||||
|
||||
for candidate in candidates:
|
||||
@@ -1061,7 +1061,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
|
||||
def _valid_priority(v: UOp, valids:List[UOp]):
|
||||
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
|
||||
try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
|
||||
try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
|
||||
except ValueError: return 0
|
||||
|
||||
def simplify_valid(valid:UOp) -> Optional[UOp]:
|
||||
@@ -1132,6 +1132,9 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# alu of two where with same conds can combine, only do if true branch or false branch is const
|
||||
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
|
||||
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
|
||||
# ALU min==max -> CONST (slow!)
|
||||
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
|
||||
# max folding
|
||||
|
||||
@@ -44,7 +44,7 @@ class ProgramSpec:
|
||||
for u in self.uops:
|
||||
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
|
||||
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
|
||||
if u.op is Ops.SPECIAL:
|
||||
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
|
||||
if u.arg[0][0] == 'i': self.local_size = None
|
||||
|
||||
@@ -57,9 +57,6 @@ extra_pm = PatternMatcher([
|
||||
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
|
||||
(UPat(Ops.BITCAST, name="x"),
|
||||
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
|
||||
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
])
|
||||
@@ -130,7 +127,7 @@ class CStyleLanguage(Renderer):
|
||||
|
||||
# mark buffers that we store to writable
|
||||
if u.op is Ops.STORE:
|
||||
for up in u.src[0].sparents:
|
||||
for up in u.src[0].toposort:
|
||||
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
|
||||
|
||||
# naming
|
||||
|
||||
@@ -60,13 +60,13 @@ llvm_rewrite = PatternMatcher([
|
||||
|
||||
# range
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
|
||||
f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
|
||||
f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
|
||||
f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
|
||||
f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
|
||||
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
|
||||
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
|
||||
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
|
||||
|
||||
# if
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
|
||||
@@ -85,10 +85,6 @@ class LLVMRenderer(Renderer):
|
||||
(UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
|
||||
# rewrite cast to bool to CMPNE 0
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
|
||||
# *** also in cstyle ***
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
|
||||
# rewrite MAX to CMPLT + WHERE
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
])
|
||||
@@ -135,7 +131,7 @@ class LLVMRenderer(Renderer):
|
||||
for x in acc_to_assign:
|
||||
if u in x.src: # if this range is relevent for this acc
|
||||
vc += 1
|
||||
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
|
||||
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
|
||||
r[x] = f"%acc{vc}"
|
||||
|
||||
# output the function
|
||||
|
||||
@@ -54,7 +54,7 @@ ptx_matcher = PatternMatcher([
|
||||
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
|
||||
])
|
||||
|
||||
def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
|
||||
def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
|
||||
|
||||
def render_wmma(ctx: "PTXRenderer", x: UOp):
|
||||
assert ctx.wmma_r, "registry values for wmma must be populated"
|
||||
@@ -76,8 +76,7 @@ def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.i
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
|
||||
(UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
|
||||
(UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), lambda ctx, x, bidx, var, pred=None:
|
||||
f"{f'@{ctx.r[pred]} ' if pred is not None and pred.op is not Ops.IF else ''}st.{mem_type(bidx)}" + \
|
||||
(UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
|
||||
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
|
||||
f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
|
||||
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),
|
||||
|
||||
@@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
|
||||
if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast<i32>({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{type_map[x.dtype.base]}, {x.arg[1]}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base)}, {x.arg[1]}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
|
||||
@@ -78,7 +78,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
|
||||
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
|
||||
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
|
||||
else f"{ctx[b]} = {ctx[v]};"),
|
||||
# fix nan check: 'a != a -> is_nan()'
|
||||
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
|
||||
@@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage):
|
||||
|
||||
def render_cast(self, dt:DType, val: str) -> str:
    """Render a cast of `val` as `<type>(<val>)`, with the type name taken from `self.type_map[dt]`."""
    target_type = self.type_map[dt]
    return f"{target_type}({val})"
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
    """Always return the keyword `var`; neither `dt` nor `mutable` is consulted."""
    keyword = "var"
    return keyword
|
||||
def render_buf(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
|
||||
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
|
||||
if not local_size: local_size = [1]
|
||||
@@ -99,6 +99,6 @@ class WGSLRenderer(CStyleLanguage):
|
||||
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
|
||||
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
|
||||
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
|
||||
f"{name}:{f'array<{self.render_buf(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
|
||||
f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
|
||||
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
|
||||
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"
|
||||
|
||||
@@ -82,7 +82,12 @@ class AMDComputeQueue(HWQueue):
|
||||
def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
|
||||
self.acquire_mem(gli=0, gl2=0)
|
||||
|
||||
user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else []
|
||||
if prg.enable_private_segment_sgpr:
|
||||
scratch_hilo = data64_le(prg.dev.scratch.va_addr)
|
||||
# sgpr word1 bit31 enables swizzle
|
||||
# sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23
|
||||
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else []
|
||||
else: user_regs = []
|
||||
if prg.enable_dispatch_ptr:
|
||||
dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
|
||||
|
||||
@@ -370,12 +375,16 @@ class AMDDevice(HCQCompiled):
|
||||
max_cu_id = self.properties['simd_count'] // self.properties['simd_per_cu'] - 1
|
||||
max_wave_id = self.properties['max_waves_per_simd'] * self.properties['simd_per_cu'] - 1
|
||||
self.max_private_segment_size = 4096
|
||||
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256
|
||||
# <gfx103 requires alignment of 1024, >=gfx11 requires 256
|
||||
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024)
|
||||
self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len
|
||||
self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
|
||||
self.has_scratch_base_registers = self.target >= 110000
|
||||
engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine']
|
||||
self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines))
|
||||
waves = wave_scratch_len // (256 if self.target >= 110000 else 1024)
|
||||
# >=gfx11 wavesize is per SE
|
||||
wavesize = self.scratch_len // ((wave_scratch_len * engines) if self.target >= 110000 else wave_scratch_len)
|
||||
self.tmpring_size = waves << 12 | wavesize
|
||||
|
||||
# https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391
|
||||
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
|
||||
|
||||
@@ -58,8 +58,8 @@ class WebGpuAllocator(Allocator):
|
||||
|
||||
class WebGpuDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
|
||||
adapter = wgpu.gpu.request_adapter(power_preference="high-performance")
|
||||
timestamp_supported = wgpu.FeatureName.timestamp_query in adapter.features
|
||||
wgpu_device = adapter.request_device_sync(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
|
||||
wgpu_device = adapter.request_device(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
|
||||
super().__init__(device, WebGpuAllocator(wgpu_device), WGSLRenderer(), Compiler(),
|
||||
functools.partial(WebGPUProgram, (wgpu_device, timestamp_supported)))
|
||||
|
||||
@@ -81,17 +81,17 @@ class ShapeTracker:
|
||||
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
|
||||
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
|
||||
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
|
||||
used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
|
||||
used_ranges = [x.arg[0] for x in idx.toposort if x.op is Ops.RANGE]
|
||||
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
|
||||
if not ignore_valid:
|
||||
for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
|
||||
for masked_axis in [x.arg[0] for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
|
||||
return tuple(ret)
|
||||
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]:
    """Return the indices of the axes whose entry in `real_strides` equals 1.

    `ignore_valid` is forwarded unchanged to `self.real_strides`.
    """
    axes: List[int] = []
    for axis, stride in enumerate(self.real_strides(ignore_valid)):
        if stride == 1:
            axes.append(axis)
    return axes
|
||||
|
||||
def axis_is_masked(self, axis:int) -> bool:
|
||||
_, valid = self.to_indexed_uops()
|
||||
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
|
||||
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||
|
||||
@@ -2688,6 +2688,19 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
||||
|
||||
def selu(self, alpha=1.67326, gamma=1.0507):
    """
    Computes the element-wise Scaled Exponential Linear Unit (SELU) of the tensor.

    Elements that are `>= 0` are scaled by `gamma`; all others become
    `gamma * alpha * (exp(x) - 1)`.

    - Described: https://paperswithcode.com/method/selu
    - Paper: https://arxiv.org/abs/1706.02515v5

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
    ```
    """
    # the selection mask is detached before it picks between the two branches
    non_negative = (self >= 0).detach()
    exp_branch = alpha * (self.exp() - 1)
    return gamma * non_negative.where(self, exp_branch)
|
||||
|
||||
def swish(self):
|
||||
"""
|
||||
See `.silu()`
|
||||
|
||||
@@ -60,7 +60,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
for u in x.sparents:
|
||||
for u in x.toposort:
|
||||
if u.op is Ops.CONST: continue
|
||||
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
|
||||
for idx,x in enumerate(u.src):
|
||||
@@ -92,7 +92,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
|
||||
# sanity check
|
||||
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
|
||||
# update ret data
|
||||
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
|
||||
g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST])
|
||||
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
|
||||
g.graphs.append(sink:=new_sink)
|
||||
return g
|
||||
|
||||
Reference in New Issue
Block a user