Merge branch 'master' into retinanet_mlperf

Francis Lata
2024-12-02 18:20:06 -05:00
39 changed files with 432 additions and 257 deletions

View File

@@ -37,6 +37,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
::: tinygrad.Tensor.hardsigmoid
::: tinygrad.Tensor.elu
::: tinygrad.Tensor.celu
::: tinygrad.Tensor.selu
::: tinygrad.Tensor.swish
::: tinygrad.Tensor.silu
::: tinygrad.Tensor.relu6

View File

@@ -189,7 +189,7 @@ class StableDiffusion:
# make image correct size and scale
x = (x + 1.0) / 2.0
x = x.reshape(3,512,512).permute(1,2,0).clip(0,1)*255
return x.cast(dtypes.uint8) if Device.DEFAULT != "WEBGPU" else x
return x.cast(dtypes.uint8)
def __call__(self, unconditional_context, context, latent, timestep, alphas, alphas_prev, guidance):
e_t = self.get_model_output(unconditional_context, context, latent, timestep, guidance)
@@ -280,7 +280,7 @@ if __name__ == "__main__":
print(x.shape)
# save image
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))
im = Image.fromarray(x.numpy())
print(f"saving {args.out}")
im.save(args.out)
# Open image.

View File

@@ -8,7 +8,7 @@ import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
"Sigmoid", "MatMul", "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Xor", "Round", "Erf"}
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
# **************** Free Ops ****************
@@ -44,7 +44,6 @@ def Constant(value:Optional[Tensor]=None, value_float=None, value_floats=None, v
def HardSigmoid(x: Tensor, alpha=0.2, beta=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def Selu(X: Tensor, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): return gamma * (X.relu() - (-alpha*X.exp()+alpha).relu())
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
return (X > 0).where(X, X * slope)
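
Side note on the Selu one-liner a few lines up: it expresses SELU purely with relu, exp and arithmetic. A minimal numpy check (an illustration, not part of the diff) that it matches the usual piecewise definition gamma * (x if x > 0 else alpha*(exp(x)-1)), using the same alpha/gamma constants:

import numpy as np
alpha, gamma = 1.67326319217681884765625, 1.05070102214813232421875
x = np.linspace(-4, 4, 17)
via_relu  = gamma * (np.maximum(x, 0) - np.maximum(-alpha * np.exp(x) + alpha, 0))  # relu form from the diff
piecewise = gamma * np.where(x > 0, x, alpha * (np.exp(x) - 1))                      # textbook SELU
assert np.allclose(via_relu, piecewise)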

View File

@@ -1,6 +1,6 @@
indent-width = 2
preview = true
target-version = "py38"
target-version = "py310"
lint.select = [
"F", # Pyflakes

View File

@@ -59,7 +59,7 @@ setup(name='tinygrad',
"bottle",
"ggml-python"
],
'webgpu': ["wgpu>=v0.19.0"],
'webgpu': ["wgpu==v0.18.1"],
'docs': [
"mkdocs",
"mkdocs-material",

View File

@@ -0,0 +1,46 @@
# ruff: noqa: E501
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.engine.search import bufs_from_lin
from tinygrad.helpers import Timing
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(81920, 0, 64, 8, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.ADD, dtypes.half, arg=None, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.MUL, dtypes.half, arg=None, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 2560, 4, 10, 4, 10), strides=(0, 163840, 0, 64, 0, 8, 0, 1), offset=-9, mask=((0, 1), (0, 2), (0, 1), (0, 2560), (0, 4), (1, 9), (0, 4), (1, 9)), contiguous=False), View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(4096000, 0, 0, 40, 1, 1600, 440, 11), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(0, 0, 23040, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
x17:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()),
x17,)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)]
k = Kernel(ast)
for opt in opts: k.apply_opt(opt)
bufs = bufs_from_lin(k)
prg = CompiledRunner(k.to_program())
with Timing("run "):
prg(bufs, var_vals={}, wait=True)
# on M1 Max
# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229
# 15ms after block cbcc1c20eb09a1342f6581cfbb99632bade982a8

View File

@@ -88,13 +88,13 @@ class TestKernelSpeed(unittest.TestCase):
# def test_gemm_1024(self): self._test_matmul(1024, nv_tflops=8, amd_tflops=7)
# def test_gemm_2048(self): self._test_matmul(2048, nv_tflops=50, amd_tflops=30)
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=95, amd_tflops=70)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=130, amd_tflops=70)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=400)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=400)
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=430, amd_gbs=380) # AMD was flaky at 400
# TODO: tiny7 is slower than tiny12
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=18)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -17,7 +17,7 @@ class TestConvShapetracker(unittest.TestCase):
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]:
for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]:
assert len(st.views) == 1
def test_conv_2x2_backward_one_view(self):
@@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase):
conv(X).mean().backward()
si = X.grad.schedule()[-1]
print(si)
ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0]
ldb = [x for x in si.ast.toposort if x.op is Ops.LOAD][0]
st: ShapeTracker = ldb.st_arg.simplify()
# NOTE: st.real_size() is broken
print(si.inputs[0].size)

View File

@@ -120,7 +120,9 @@ class TestDType(unittest.TestCase):
data = [1., 2., 0., 0.5, -1.5, 5.25]
for dt in dtypes:
arr = np.asarray(data).astype(dt)
tin = Tensor(arr).numpy()
tensor = Tensor(arr)
if not is_dtype_supported(tensor.dtype): continue
tin = tensor.numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)

View File

@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
op = [x for x in ast.parents if x.op in GroupOp.Unary][0]
op = [x for x in ast.toposort if x.op in GroupOp.Unary][0]
assert op.dtype == dtype
def universal_test_cast(a, in_dtype, dtype):

View File

@@ -96,10 +96,18 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
stores = [u for u in lin.uops if u.op is Ops.STORE]
mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg for u in mutable_bufs] == [0, 1]
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_ACC])
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u in range_in_acc) or (u.op is Ops.ENDRANGE and u.src[0] in range_in_acc)]
for i,u in enumerate(ranges):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -130,11 +138,7 @@ class TestLinearizer(unittest.TestCase):
]
wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i == 0: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -194,11 +198,7 @@ class TestLinearizer(unittest.TestCase):
]
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i == 0: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0])
def test_triple_multireduce(self):
Tensor.manual_seed(0)
@@ -218,11 +218,7 @@ class TestLinearizer(unittest.TestCase):
sink = UOp(Ops.SINK, src=(store,))
wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i == 0: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -270,11 +266,7 @@ class TestLinearizer(unittest.TestCase):
Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UNROLL, 1, 2), Opt(OptOps.UNROLL, 2, 2), Opt(OptOps.UNROLL, 3, 2)],
]
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i < 2: continue
assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0, 1])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -301,11 +293,7 @@ class TestLinearizer(unittest.TestCase):
]
wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i == 0: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -339,11 +327,7 @@ class TestLinearizer(unittest.TestCase):
]
wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
for l in lins:
ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
for i,u in enumerate(ranges):
if i == 0: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
self._test_no_nested_ranges(lins, [0])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_multiout_multireduce(self):
@@ -988,10 +972,10 @@ class TestLinearizer(unittest.TestCase):
# the first store is to lds and can be upcasted
assert stores[0].src[-1].dtype == dtypes.float.vec(4)
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].sparents)
assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].toposort)
# the second store is to gds with no upcasts
assert stores[1].src[-1].dtype == dtypes.float
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].sparents)
assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].toposort)
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -1155,7 +1139,7 @@ class TestLinearizer(unittest.TestCase):
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = dedup(flatten([[y for y in x.toposort if y.op is Ops.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
sizes = [x.arg[1] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"

View File

@@ -95,7 +95,7 @@ class TestLinearizerDumb(unittest.TestCase):
print(prg.src)
if_uops = [u for u in k.uops if u.op is Ops.IF]
self.assertIn(len(if_uops), {1,2,3})
conditions = if_uops[0].src[0].sparents
conditions = if_uops[0].src[0].toposort
self.assertLessEqual(len(conditions), 9)
# this was a bug in embedding, someday we should fold this anyway

View File

@@ -1045,7 +1045,7 @@ class TestLinearizerFailures(unittest.TestCase):
ifs = [u for u in k.uops if u.op is Ops.IF]
self.assertEqual(len(ifs), 3)
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)
self.assertLessEqual(len(ifs[0].src[0].toposort), 17)
def test_failure_45(self):
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(

View File

@@ -697,6 +697,9 @@ class TestOps(unittest.TestCase):
for val in range(1, 5):
helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val))
def test_selu(self):
helper_test_op([(45,65)], torch.nn.functional.selu, Tensor.selu)
helper_test_op([()], torch.nn.functional.selu, Tensor.selu)
def test_abs(self):
helper_test_op([(45,65)], torch.abs, Tensor.abs)

View File

@@ -199,7 +199,7 @@ class TestSchedule(unittest.TestCase):
r1 = (x - r0).sum(axis=0).div(2)
out = r0 + r1
schedule = check_schedule(out, 2)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_cache_reduce_multiple_children(self):
@@ -210,7 +210,7 @@ class TestSchedule(unittest.TestCase):
out0 = r0 + y
out1 = r1 + y
schedule = check_schedule([out0, out1], 4)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.toposort if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_fold_double_unary(self):
@@ -1673,7 +1673,7 @@ class TestIndexing(unittest.TestCase):
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.sparents if x.op is Ops.VIEW and len(x.src) != 0])
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
@@ -1755,7 +1755,7 @@ class TestSwizzle(unittest.TestCase):
# EXPAND is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
new_load_st = unwrap([x for x in ret.toposort if x.op is Ops.VIEW][0].st)
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))

View File

@@ -18,7 +18,7 @@ class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD}
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')

View File

@@ -1,5 +1,5 @@
# basic self-contained tests of the external functionality of tinygrad
import unittest
import unittest, random
from tinygrad import Tensor, Context, Variable, TinyJit, dtypes, Device
from tinygrad.helpers import IMAGE
@@ -41,13 +41,22 @@ class TestTiny(unittest.TestCase):
def test_jit(self):
cnt = 0
random.seed(0)
def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)]
@TinyJit
def fxn(a,b):
def fxn(a,b) -> Tensor:
nonlocal cnt
cnt += 1
return a+b
fa,fb = Tensor([1.,2,3]), Tensor([4.,5,6])
for _ in range(3): fxn(fa, fb)
for _ in range(3):
la,lb = new_rand_list(), new_rand_list()
fa,fb = Tensor(la), Tensor(lb)
ret = fxn(fa, fb)
# math is correct
self.assertListEqual(ret.tolist(), [a+b for a,b in zip(la, lb)])
# function is only called twice
self.assertEqual(cnt, 2)
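
A note on the cnt assertion above: the body of a TinyJit-wrapped function only executes for the first two calls (a normal run, then a capturing run), and every later call replays the captured kernels, which is why cnt stays at 2 across three iterations. A standalone sketch of the same behavior, assuming input shapes and dtypes stay fixed across calls:

from tinygrad import Tensor, TinyJit

calls = 0
@TinyJit
def add(a: Tensor, b: Tensor) -> Tensor:
    global calls
    calls += 1            # only incremented while the Python body actually runs
    return a + b

for i in range(5):
    add(Tensor([float(i), 2.0]), Tensor([3.0, 4.0])).realize()
assert calls == 2         # warmup + capture, then replay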

View File

@@ -60,7 +60,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
new_sink = full_graph_rewrite(lower_sink)
et = time.perf_counter() - st
UOp.__init__ = old_init
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.sparents)} -> {len(new_sink.sparents)}, creating {cnt[0]} uops")
print(f"rewrote in {et*1000:.2f} ms, from {len(lower_sink.toposort)} -> {len(new_sink.toposort)}, creating {cnt[0]} uops")
class TestGraphRewriteConst(unittest.TestCase):
def test_gep_const(self):
@@ -106,7 +106,7 @@ class TestGraphRewrite(unittest.TestCase):
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
sink = a1.sink(a2)
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR]
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).toposort if x.op is Ops.DEFINE_VAR]
self.assertEqual(len(define_vars), 1)
def test_simple(self):
@@ -187,7 +187,7 @@ class TestGraphRewrite(unittest.TestCase):
print(sink.render())
self.assertEqual(sink.op, Ops.ADD)
self.assertEqual(sink.src[1].op, Ops.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
self.assertEqual(len([x for x in sink.toposort if x.op is Ops.CONST]), 1)
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
@@ -433,8 +433,8 @@ class TestUOpGraph(unittest.TestCase):
c0 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 2)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 0)
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 1)
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
uops = to_uops_list([store])
@@ -600,14 +600,14 @@ class TestLoadStoreFolder(unittest.TestCase):
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
def test_two_load_fold(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 2
def test_simple_load_fold_gated(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
@@ -615,8 +615,8 @@ class TestLoadStoreFolder(unittest.TestCase):
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0]
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
single_load = [x for x in sink.toposort if x.op is Ops.LOAD][0]
self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)
def test_simple_load_dont_fold_different_gated(self):
@@ -627,14 +627,14 @@ class TestLoadStoreFolder(unittest.TestCase):
UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 3
def test_simple_store_fold(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
def test_simple_store_fold_gate(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
@@ -642,10 +642,11 @@ class TestLoadStoreFolder(unittest.TestCase):
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is Ops.STORE][0]
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
one_store = [x for x in sink.toposort if x.op is Ops.STORE][0]
assert len(one_store.src) == 3
assert str(one_store.src[2]) == str(gate) # huh, why do i need str here?
_if_node = one_store.src[2]
assert _if_node.op == Ops.IF and _if_node.src[0] == gate
def test_simple_store_dont_fold(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
@@ -655,7 +656,7 @@ class TestLoadStoreFolder(unittest.TestCase):
UOp.const(dtypes.float, i))) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 3
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
@@ -671,7 +672,7 @@ class TestIFUOps(unittest.TestCase):
store = UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, 0), gate), lbuf))
sink = UOp(Ops.SINK, dtypes.void, (store,))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is Ops.IF]
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
@@ -689,7 +690,7 @@ class TestIFUOps(unittest.TestCase):
stores = [UOp(Ops.STORE, dtypes.void, (gbuf.index(UOp.const(dtypes.int, i), gate), lbufs[i])) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is Ops.IF]
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
@@ -705,7 +706,7 @@ class TestIFUOps(unittest.TestCase):
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is Ops.IF]
if_uops = [u for u in sink.toposort if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:

View File

@@ -247,58 +247,59 @@ class TestConstantFolding(unittest.TestCase):
assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
class TestGatedStoreRewrite(unittest.TestCase):
@unittest.expectedFailure
def test_tiny_gate_store(self):
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
idx = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem, gidx0 * UOp.const(dtypes.int, 2)))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate))
store = UOp(Ops.STORE, dtypes.void, (idx, val, gate))
uops = to_uops_list([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is Ops.IF)
endif = next(u for u in uops if u.op is Ops.ENDIF)
assert endif.src[0] is if_uop
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, Ops.STORE)
@unittest.expectedFailure
def test_gate_some_stores(self):
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
idx = gidx0 * UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
uops = linearize_uop(stores)
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val)]
uops = to_uops_list(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is Ops.IF)
endif = next(u for u in uops if u.op is Ops.ENDIF)
assert endif.src[0] is if_uop
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
gated_uops = tuple(uops[uops.index(if_uop)+1:uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, Ops.STORE)
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.expectedFailure
def test_merge_ifs_alt(self):
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
idx0 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem0, idx))
idx1 = UOp(Ops.INDEX, dtypes.float.ptr(), (gmem1, idx))
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
uops = linearize_uop(stores)
stores = [UOp.store(idx0, val, gate), UOp.store(idx1, val, gate)]
uops = to_uops_list(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
ifs = [u for u in uops if u.op is Ops.IF]
endifs = [u for u in uops if u.op is Ops.ENDIF]
self.assertEqual(len(ifs), 1)
self.assertEqual(len(endifs), 1)
gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
gated_uops = tuple(uops[uops.index(ifs[0])+1:uops.index(endifs[0])])
self.assertEqual(len(gated_uops), 2)
for x in gated_uops: self.assertIs(x.op, Ops.STORE)
@@ -314,6 +315,16 @@ class TestLocalAccess(unittest.TestCase):
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
# NOTE: webgpu specific, since only webgpu performs bitpacking for uchar
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
def test_local_packed(self):
uops = []
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(local=True), (), ('smem', 16))
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
self.assertEqual(_test_uops_result(dtypes.uint8, uops, sres), 42)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self):
uops = []

View File

@@ -24,7 +24,7 @@ class TestWinograd(unittest.TestCase):
for i,s in enumerate(sched):
if s.ast.op is not Ops.SINK: continue
ops = s.ast.parents
ops = s.ast.toposort
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Kernel(s.ast)
l.hand_coded_optimizations()

View File

@@ -492,6 +492,30 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)")
self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)")
def test_where_combine(self):
cond = Variable("x", 0, 3).lt(2)
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
aa = cond.where(a, a.ufix(0))
bb = cond.where(b, b.ufix(1))
self.helper_test_variable(aa, 0, 3, "(a if (x<2) else 0)")
self.helper_test_variable(bb, 0, 3, "(b if (x<2) else 1)")
self.helper_test_variable(aa+bb, 0, 6, "((a+b) if (x<2) else 1)")
self.helper_test_variable(aa.maximum(bb), 0, 3, "(max(a, b) if (x<2) else 1)")
# not combining because it increased total ALU
c = Variable("c", 0, 3)
cc = cond.where(c, c+1)
self.helper_test_variable(bb+cc, 0, 7, "((b if (x<2) else 1)+(c if (x<2) else (c+1)))")
# not combining # TODO: can combine if it can further simplify?
ab = cond.where(a, b)
ba = cond.where(b, a)
self.helper_test_variable(ab+ba, 0, 6, "((a if (x<2) else b)+(b if (x<2) else a))")
# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
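
The new test_where_combine above exercises a simplification rule: two selects guarded by the same condition can be fused into a single select over the combined branches, but only when doing so does not increase the ALU count. A plain-Python check of the underlying identity (illustration only):

# (a if c else x) + (b if c else y) == ((a + b) if c else (x + y))
for c in (True, False):
    for a in range(4):
        for b in range(4):
            assert (a if c else 0) + (b if c else 1) == ((a + b) if c else 1)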
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

View File

@@ -68,10 +68,11 @@ class Kernel:
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
self.vars: List[Variable] = self.ast.variables()
self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in GroupOp.Buffer]
# NOTE: this requires a specific order with the [::-1], this is likely a bug
self.bufs: List[UOp] = [x for x in self.ast.toposort if x.op in GroupOp.Buffer][::-1]
# get earlybufs, before any reduceops
earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.parents if x.op in GroupOp.Buffer]
earlybufs: List[UOp] = [x for reduceop in self.reduceops for x in reduceop.src[0].toposort if x.op in GroupOp.Buffer]
self.full_buf_index: int = self.bufs.index(earlybufs[0]) if earlybufs else 0
# NOTE: full_shape can be wrong if there's a tree of reduces
@@ -597,7 +598,7 @@ class Kernel:
@functools.cached_property
def name(self) -> str:
# kernel name (before late upcast)
kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op in GroupOp.Buffer for x in self.ast.parents) else "E")
kernel_type = "r" if self.reduceop is not None else ("C" if all(x.op is Ops.SINK or x.op in GroupOp.Buffer for x in self.ast.toposort) else "E")
suffix = colored('_', 'BLACK').join([colored(x.render() if isinstance(x, UOp) else str(x), c) for x,c in zip(self.full_shape, self.colors())])
name = kernel_type + (f"{len(self.ast.src)}" if len(self.ast.src) > 1 else "") + "_" + suffix
@@ -712,7 +713,7 @@ class Kernel:
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
for _, group in itertools.groupby([x for x in self.ast.toposort if x.op in GroupOp.Buffer and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return ProgramSpec(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)

View File

@@ -1,92 +1,168 @@
from typing import List, Set, Dict, Tuple
import functools, heapq
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops, GroupOp
from tinygrad.dtype import dtypes
from tinygrad.helpers import DEBUG
from typing import List, Dict, Tuple, Optional
import collections
from dataclasses import dataclass
from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import dedup, flatten, partition
def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[UOp, None]], in_degree:Dict[UOp, int]):
if u in children: return srcs[u]
srcs[u] = {}
children[u] = []
for x in u.src:
srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
children[x].append(u)
in_degree[u] = len(u.src)
return srcs[u]
DONT_PLACE_IN_BLOCK = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.CONST,
Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKFORK, Ops.BLOCKSTART}
def disp(y:UOp) -> str:
if y.op is Ops.BLOCKSTART: return "w"+disp(y.src[0])
if y.op is Ops.IF: return f'IF{id(y)}'
if y.op is Ops.RANGE: return str(y.arg)
return "<NONE>"
@dataclass(frozen=True)
class BasicBlock:
ctx: Tuple[UOp, ...]
lst: Tuple[UOp, ...]
end: Optional[UOp] = None
def __repr__(self):
return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
f"{[disp(y) for y in self.ctx]} {len(self.lst)}" + "\n" + '\n'.join([str(x.op) for x in self.lst])
def append_to_block(ctx, x:UOp):
block_ctxs, children = ctx
new_srcs: List[UOp] = []
to_append: List[UOp] = []
new_blocks: Dict[Tuple[UOp, ...], List[UOp]] = {}
in_this_block = set(x.arg.lst)
for u in x.src:
if u.op in DONT_PLACE_IN_BLOCK or len([y for y in children[u] if y not in in_this_block]) > 0:
# if it's a fork or not placed, we don't place it
new_srcs.append(u)
elif (block_ctx:=block_ctxs[u]) == x.arg.ctx:
# if it's the same context, we place the UOp in this block and append the parents to it's srcs
new_srcs += list(u.src)
to_append.append(u)
else:
# otherwise, we create a new block with this UOp
new_blocks.setdefault(block_ctx, []).append(u)
if len(to_append) == 0 and len(new_blocks) == 0: return None
for rng,lst in new_blocks.items():
new_block = UOp(Ops.BLOCK, dtypes.void, tuple(dedup(flatten(y.src for y in lst))), BasicBlock(rng, tuple(lst)))
lrng = list(rng)
for r in rng[::-1]:
if r not in x.arg.ctx and r.op is not Ops.BLOCKSTART:
lrng.remove(r)
new_block = UOp(Ops.BLOCKEND, src=(new_block,),
arg=BasicBlock(tuple(lrng), (UOp(Ops.ENDIF if r.op is Ops.IF else Ops.ENDRANGE, src=(r,)),), r))
new_srcs.append(new_block)
return UOp(Ops.BLOCK, dtypes.void, tuple(dedup(new_srcs)), BasicBlock(x.arg.ctx, tuple(to_append)+x.arg.lst))
make_basic_blocks = PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda x: UOp(Ops.BLOCK, src=x.src, arg=BasicBlock((), (x,)))),
(UPat(Ops.BLOCK, name="x"), append_to_block),
])
def block_merge(ctx, x:UOp):
# ctx is children here
if x.op is Ops.BLOCKEND:
# if it's a BLOCKEND, see if we are done with placement. if all the children of the range are in here
in_this_block = set(x.arg.lst)
if len([y for y in ctx[x.arg.end] if y not in in_this_block]) == 0:
# find the parent block that has the BLOCKSTART in the ctx
parent_blocks = [y for y in x.src if y.op is Ops.BLOCK and UOp(Ops.BLOCKSTART, src=(x.arg.end,)) in y.arg.ctx]
if len(parent_blocks) == 1:
parent_block = parent_blocks[0]
# range needs DEFINE_ACC to be before the range (never in DEFINE_ACC for if)
early_ops, late_ops = partition(x.arg.lst, lambda y: y.op is Ops.DEFINE_ACC and x.arg.end in y.src)
return UOp(Ops.BLOCK, dtypes.void, tuple(y for y in x.src if y is not parent_block)+parent_block.src,
BasicBlock(tuple(y for y in x.arg.ctx if y is not x.arg.end), tuple(early_ops)+parent_block.arg.lst+tuple(late_ops)))
assert not len(parent_blocks)
new_srcs: List[UOp] = []
to_append: List[UOp] = []
new_ctx = x.arg.ctx
placed = set()
for u in x.src:
if u.op is Ops.BLOCK and (tuple(u.arg.ctx) == tuple(x.arg.ctx) or (x.arg.end is not None and x.arg.end in u.arg.ctx)):
# NOTE: this can't appear in srcs twice or it would be a BLOCKFORK
new_ctx += u.arg.ctx
new_srcs += list(u.src)
to_append += u.arg.lst
elif u.op is Ops.BLOCKFORK and len([y for y in x.src if y is u]) == u.arg: # block fork appears # of times in srcs
if u not in placed:
new_srcs += list(u.src)
placed.add(u)
else:
# keep it in srcs
new_srcs.append(u)
if len(to_append) == 0 and len(placed) == 0: return None
return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(dedup(new_ctx)), tuple(to_append)+x.arg.lst, x.arg.end))
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
# filter nodes that don't link to a sink
# BFS toposort
# get children and all block contexts
temp_block_ctxs: Dict[UOp, List[UOp]] = {}
children: Dict[UOp, List[UOp]] = {}
range_srcs: Dict[UOp, Dict[UOp, None]] = {}
in_degree: Dict[UOp, int] = {}
get_children_dfs(sink, children, range_srcs, in_degree)
for u in sink.toposort:
this_block_ctx: List[UOp] = []
for s in u.src:
# save children
children.setdefault(s, []).append(u)
# compute block ctx
if s.op in {Ops.RANGE, Ops.IF}: this_block_ctx.append(s)
# don't flow (fully) through assign and store
elif s.op is Ops.STORE:
# ugh, deal with non-reduce locals. probably wrong
if isinstance(s.src[0].dtype, PtrDType) and s.src[0].dtype.local:
idx_context, store_context = temp_block_ctxs[s.src[0]], temp_block_ctxs[s]
this_block_ctx += [x for x in store_context if x not in idx_context and x.op is Ops.RANGE]
elif s.op is Ops.ASSIGN:
# flow though assign, but remove the ranges used in the assign
assert s.src[0].op is Ops.DEFINE_ACC
this_block_ctx += [x for x in temp_block_ctxs[s.src[1]] if x not in s.src[0].src[1:]]
else:
# flow though everything else
this_block_ctx += temp_block_ctxs[s]
temp_block_ctxs[u] = dedup(sorted(this_block_ctx, key=lambda x: x.tuplize))
@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
if x.op is Ops.SINK: return set()
return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
# make final block_ctxs, add BLOCKSTART to block_ctxs for IF and RANGE
block_ctxs: Dict[UOp, Tuple[UOp, ...]] = {}
for u in sink.toposort:
block_ctxs[u] = ((UOp(Ops.BLOCKSTART, src=(u,)),) + tuple(temp_block_ctxs[u])) if u.op in {Ops.IF, Ops.RANGE} else tuple(temp_block_ctxs[u])
# scope children impact the toposort and END* insertion
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
# TODO: there's probably a clever way to remove this while loop
while 1:
sink = graph_rewrite(sink, make_basic_blocks, ctx=(block_ctxs, children))
# assign priorities
def get_priority(u:UOp):
priority = 0
# prefer ranges that depend on the least number of independent ranges
if u.op is Ops.RANGE and u.arg[1]:
priority += u.arg[0]
for p in range_phi[u]:
priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
elif u.op is Ops.CONST:
# place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
priority -= 100000000000
else:
# prefer uops that are loop children
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
return priority
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
# add BLOCKFORK (slow!)
block_parent_count = collections.Counter(flatten([x.src for x in sink.toposort if x.op is Ops.BLOCK]))
non_block_parents = set(flatten([x.src for x in sink.toposort if x.op is not Ops.BLOCK]))
forks = {u:UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCK, src=u.src, arg=BasicBlock(block_ctxs[u], (u,))),), arg=child_count)
for u,child_count in block_parent_count.items() if u.op not in DONT_PLACE_IN_BLOCK and child_count > 1 and u not in non_block_parents}
# prevent priority inversion
@functools.lru_cache(None)
def fix_priority(u:UOp, lowest_priority):
if u.op in {Ops.CAST, Ops.BITCAST, *GroupOp.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
priorities[u] = min(priorities[u], lowest_priority)
if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)
if not len(forks): break
sink = sink.substitute(forks)
# NOTE: the compare should never make it all the way to u
queue:List[Tuple[int, Tuple, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
# combine matching BLOCKENDS
blockends_to_arg: Dict[UOp, List[UOp]] = {}
for be in sink.toposort:
if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
new_forks = {}
for k,v in blockends_to_arg.items():
# NOTE: if any BLOCKEND is the parent of any other with the same arg, this algo fails
if len(v) > 1:
out = UOp(Ops.BLOCKFORK, src=(UOp(Ops.BLOCKEND, src=tuple(flatten(x.src for x in v)),
arg=BasicBlock(tuple(dedup(flatten([y.arg.ctx for y in v]))), v[0].arg.lst, k)),), arg=len(v))
for u in v: new_forks[u] = out
sink = sink.substitute(new_forks)
for u in children:
if in_degree[u] == 0: push(u)
# final rewrite to merge all blocks into one
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
scope_end: Dict[UOp, UOp] = {}
_uops: List[UOp] = []
while queue:
p,_,x = heapq.heappop(queue)
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
if x in scope_children: scope_end[x] = x
if x.op is Ops.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
_uops.insert(idx, x)
else: _uops.append(x)
for u, ss in scope_children.items():
if x in ss:
ss.remove(x)
if len(ss) == 0: scope_end[u] = x
for u in children[x]:
in_degree[u] -= 1
if in_degree[u] == 0: push(u)
# end scopes in toposort order
for u, x in scope_end.items(): _uops.insert(_uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], dtypes.void, (u,)))
# there should just be one block left, with a few parents with 0 srcs
assert sink.op is Ops.BLOCK
_uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
_uops += sink.arg.lst
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
if not skip_check: type_verify(_uops)

View File

@@ -55,8 +55,8 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.toposort if x.op is Ops.REDUCE_AXIS))
local_loads = [x for x in ast.toposort if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(l.st_arg.shape[i]!=ast.src[0].st_arg.shape[i] for l in local_loads) for i in range(first_reduce,first_upcasted)])
global_dims = first_reduce-ki.local_dims
@@ -71,10 +71,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])]
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), (i, True))
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
@@ -85,7 +85,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), (1000+a, True))
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
return IndexContext(idxs, ridxs)

View File

@@ -221,7 +221,7 @@ def no_vectorized_wmma(wmma:UOp):
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.toposort)
if len(reduce_unparented) == 0: return None
new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented))
ret = new_acc.assign(new_acc.alu(alu.op, ret))
@@ -447,7 +447,7 @@ devectorize = PatternMatcher([
])
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None
if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
# remove the gate from the index
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
@@ -483,6 +483,9 @@ pm_render = PatternMatcher([
# move masks of loads/stores
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
])
# *** uop graph ***

View File

@@ -20,7 +20,7 @@ class _Device:
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
cpn = multiprocessing.current_process().name
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
if (cname.lower() == x.lower() + "device")][0](ix)

View File

@@ -54,28 +54,28 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
return ret
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
# hack the underlying buffer too
buf.dtype = buf.buffer.dtype = buf.dtype.base
assert not buf.is_realized, "can't fixup allocated buffer"
buf.buffer.options = None
assert buf.op is not None, f"base must be base itself {buf}"
dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base
# make things that can't be images not images
dtype = buf.dtype
if isinstance(dtype, ImageDType) and (prod(buf.shape) != prod(dtype.shape) or not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
assert buf.realized is None, "can't fixup allocated buffer"
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
dtype = buf.dtype.base
# hack the underlying buffer too
buf.buffer.dtype = buf.dtype = dtype
buf.buffer.options = None
if buf.is_realized:
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.buf_uop)
op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg)
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
else:
ubuf = UOp.new_buffer(buf.device, buf.size, buf.dtype, num=len(buffers))
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = UOp(buf.op, dtype, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
ctx.lazybufs[ubuf] = buf
@@ -126,7 +126,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
# push VIEW to stores

View File

@@ -41,9 +41,7 @@ def fully_flatten(l):
return [l]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def ceildiv(num, amt):
ret = -(num//-amt)
return ret if not isinstance(ret, float) else int(ret)
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
def data64(data:Any) -> Tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
def data64_le(data:Any) -> Tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
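
A quick sanity check of the rewritten ceildiv one-liner above (illustration, not part of the diff): -(num // -amt) rounds the division up, and the int() cast only fires when floor division produced a float:

def ceildiv(num, amt): return int(ret) if isinstance((ret := -(num // -amt)), float) else ret

assert ceildiv(7, 2) == 4 and ceildiv(8, 2) == 4
assert ceildiv(10, 2.5) == 4 and isinstance(ceildiv(10, 2.5), int)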
@@ -52,10 +50,9 @@ def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
a:List[T] = []
b:List[T] = []
for s in itr: (a if fxn(s) else b).append(s)
return a,b
ret:Tuple[List[T], List[T]] = ([], [])
for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
return ret
def unwrap(x:Optional[T]) -> T:
assert x is not None
return x
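
Similarly, a small usage sketch of the tuple-based partition above, which keeps the original (matching, non-matching) return order:

from typing import Callable, Iterable, List, Tuple, TypeVar
T = TypeVar("T")

def partition(itr: Iterable[T], fxn: Callable[[T], bool]) -> Tuple[List[T], List[T]]:
    ret: Tuple[List[T], List[T]] = ([], [])
    for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
    return ret

evens, odds = partition(range(6), lambda n: n % 2 == 0)
assert evens == [0, 2, 4] and odds == [1, 3, 5]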
@@ -268,7 +265,8 @@ def from_mv(mv:memoryview, to_type=ctypes.c_char):
return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char):
return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
@functools.lru_cache(maxsize=None)
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
class CStruct(ctypes.Structure):

View File

@@ -248,10 +248,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@functools.cached_property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
def toposort(self) -> Dict[UOp, None]:
nodes: Dict[UOp, None] = {}
# NOTE: this is a lot faster than the comprehension in parents
for parent in self.src: nodes.update(parent.toposort)
nodes[self] = None
return nodes
@functools.cached_property
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
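
The parents/sparents to toposort change above replaces two cached dict comprehensions with a single post-order walk, so the resulting dict lists every ancestor before the node itself; that is why the many x.sparents call sites in this commit become x.toposort. A toy sketch of the idea (not tinygrad's exact code, which caches the result on the UOp):

from typing import Dict, Optional

class Node:
    def __init__(self, name: str, *src: "Node"):
        self.name, self.src = name, src

def toposort(u: Node, nodes: Optional[Dict[Node, None]] = None) -> Dict[Node, None]:
    if nodes is None: nodes = {}
    for p in u.src: toposort(p, nodes)   # visit all sources first
    nodes[u] = None                      # then the node itself; dict keeps insertion order and dedups
    return nodes

a = Node("a"); b = Node("b", a); c = Node("c", a, b)
assert [n.name for n in toposort(c)] == ["a", "b", "c"]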
@@ -273,7 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
@property
def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
@property
@@ -334,8 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(arg, out_dtype, (self,)+src)
@staticmethod
def const(dtype:DType, b:ConstLike):
@@ -378,14 +381,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop Buffer stuff ***
buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype)))
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
@functools.cached_property
def device(self) -> str:
match self.op:
case Ops.COPY: return self.arg
case Ops.BUFFER: return self.arg[1][0]
case _: return self.src[0].device
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
@@ -412,12 +412,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> Set[UOp]:
bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
bound_vars = set([x for x in self.toposort if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
bound_var_base = set(x.src[0] for x in bound_vars)
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
all_vars = set([x for x in self.toposort if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in GroupOp.Buffer]
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.toposort if x.op in GroupOp.Buffer]
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# *** uop symbolic stuff ***
@@ -474,7 +474,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def _sym_fxn(self):
sself = self.simplify()
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
varnames = tuple(x.arg[0] for x in sself.toposort if x.op is Ops.DEFINE_VAR)
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
@@ -532,10 +532,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
if ignore_indexing:
for u in uops:
if u.op in {Ops.LOAD, Ops.STORE}:
dont_count = dont_count.union(u.src[0].sparents)
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
dont_count = dont_count.union(u.src[0].toposort)
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].toposort)
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].sparents)
dont_count = dont_count.union(u.src[0].toposort)
for u in uops:
if u.op is Ops.RANGE:
mult_stack.append(mults)
@@ -622,7 +622,7 @@ class UPat(MathTrait):
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):
asrc = (self,)+src
return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def printable(self:UPat) -> str:
try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -802,7 +802,7 @@ spec = PatternMatcher([
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
(UPat(Ops.SPECIAL, src=()), lambda: True),
# TODO: confirm the args of both of these are shapetrackers
@@ -1046,7 +1046,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)])
# try checking the whole clause
if expr in uop.sparents:
if expr in uop.toposort:
candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))])
for candidate in candidates:
@@ -1061,7 +1061,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
def _valid_priority(v: UOp, valids:List[UOp]):
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
except ValueError: return 0
def simplify_valid(valid:UOp) -> Optional[UOp]:
@@ -1132,6 +1132,9 @@ symbolic = symbolic_simple+PatternMatcher([
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# alu of two where with same conds can combine, only do if true branch or false branch is const
(UPat(GroupOp.Binary, name="alu", src=(UPat.var("c").where(UPat.var("t"), UPat.var("f")), UPat.var("c").where(UPat.var("tt"), UPat.var("ff")))), \
lambda alu,c,t,tt,f,ff: c.where(t.alu(alu.op, tt), f.alu(alu.op, ff)) if t.op == tt.op == Ops.CONST or f.op == ff.op == Ops.CONST else None),
# ALU min==max -> CONST (slow!)
(UPat(GroupOp.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
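The where/where combining rule added above rests on a simple identity: a binary op applied to two selects with the same condition can be moved inside the select. A small numeric check of that identity with a plain stand-in `select`, not the UOp API:
```python
# select(c, t, f) <op> select(c, tt, ff) == select(c, t <op> tt, f <op> ff)
def select(c: bool, t: int, f: int) -> int: return t if c else f

for c in (True, False):
  assert select(c, 3, 7) + select(c, 10, 20) == select(c, 3 + 10, 7 + 20)
print("combining two selects over the same condition is safe for both branches")
```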

View File

@@ -44,7 +44,7 @@ class ProgramSpec:
for u in self.uops:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0][0] == 'i': self.local_size = None

View File

@@ -57,9 +57,6 @@ extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
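The MAX rewrite in this matcher is the usual comparison/select identity. A quick randomized check of it, using plain Python floats rather than UOps:
```python
import random

def max_via_where(a: float, b: float) -> float:
  # mirrors the rewrite: (a < b) ? b : a
  return b if a < b else a

for _ in range(1000):
  a, b = random.uniform(-10, 10), random.uniform(-10, 10)
  assert max_via_where(a, b) == max(a, b)
print("max(a, b) == where(a < b, b, a) on sampled floats")
```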
@@ -130,7 +127,7 @@ class CStyleLanguage(Renderer):
# mark buffers that we store to writable
if u.op is Ops.STORE:
for up in u.src[0].sparents:
for up in u.src[0].toposort:
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
# naming

View File

@@ -60,13 +60,13 @@ llvm_rewrite = PatternMatcher([
# range
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
# if
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
@@ -85,10 +85,6 @@ class LLVMRenderer(Renderer):
(UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
# *** also in cstyle ***
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
@@ -135,7 +131,7 @@ class LLVMRenderer(Renderer):
for x in acc_to_assign:
if u in x.src: # if this range is relevant for this acc
vc += 1
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg}]")
r[x] = f"%acc{vc}"
# output the function

View File

@@ -54,7 +54,7 @@ ptx_matcher = PatternMatcher([
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None),
])
def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global'
def mem_type(x: UOp): return 'shared' if any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].toposort) else 'global'
def render_wmma(ctx: "PTXRenderer", x: UOp):
assert ctx.wmma_r, "registry values for wmma must be populated"
@@ -76,8 +76,7 @@ def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.i
string_rewrite = PatternMatcher([
(UPat.cvar("x", dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"),
(UPat.cvar("x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"),
(UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), lambda ctx, x, bidx, var, pred=None:
f"{f'@{ctx.r[pred]} ' if pred is not None and pred.op is not Ops.IF else ''}st.{mem_type(bidx)}" + \
(UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx, x, bidx, var: f"st.{mem_type(bidx)}" + \
f"{f'.v{cnt}' if ((cnt:=var.dtype.count)>1) else ''}.{ctx.mem_types[var.dtype.scalar()]} " + \
f"[{ctx.r[bidx]}+0], {('{' + ', '.join(ctx.r[var]) + '}') if var.dtype.count > 1 else ctx.r[var]};"),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"),

View File

@@ -68,7 +68,7 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"), lambda ctx,x: f"bitcast<u32>({x.arg})" \
if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
(UPat(Ops.CONST, dtype=dtypes.int32, name="x"), lambda ctx,x: f"bitcast<i32>({x.arg}u)" if x.arg >= 0x80000000 else f"{x.arg}"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{type_map[x.dtype.base]}, {x.arg[1]}>;"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.render_buf_dt(x.dtype.base)}, {x.arg[1]}>;"),
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
@@ -78,7 +78,7 @@ class WGSLRenderer(CStyleLanguage):
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
# (load & mask) | var -> mask = v.src[0].src[1], var = v.src[1]
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\natomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
f"atomicAnd(&{ctx[b]},{ctx[v.src[0].src[1]]});\n atomicAdd(&{ctx[b]},{ctx[v.src[1]]});" if b.src[0].dtype.itemsize < 4 \
else f"{ctx[b]} = {ctx[v]};"),
# fix nan check: 'a != a -> is_nan()'
(UPat.var("a") != UPat.var("a"), lambda ctx,a: f"is_nan({ctx[a]})"),
@@ -86,7 +86,7 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_buf(self, dt:DType, rw:bool) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and (dt.itemsize < 4) else buffer_map[dt.base]}"
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
if not local_size: local_size = [1]
@@ -99,6 +99,6 @@ class WGSLRenderer(CStyleLanguage):
prg += "@group(0) @binding(0)\nvar<uniform> INFINITY : f32;\n"
prg += "\n".join((external_local_bufs or [])+[f"@group(0) @binding({next(bind_it)+1})" +
f"{'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'}" +
f"{name}:{f'array<{self.render_buf(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
f"{name}:{f'array<{self.render_buf_dt(dtype.base,rw)}>' if isinstance(dtype, PtrDType) else buffer_map[dtype]};" for name,(dtype,rw) in bufs])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>,"
return prg + "@builtin(local_invocation_id) lindex: vec3<u32>) {\n" + "\n".join(kernel) + "\n}"

View File

@@ -82,7 +82,12 @@ class AMDComputeQueue(HWQueue):
def exec(self, prg:AMDProgram, args_state:AMDArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.acquire_mem(gli=0, gl2=0)
user_regs = [*data64_le(prg.dev.scratch.va_addr), 0xffffffff, 0xc00000] if prg.enable_private_segment_sgpr else []
if prg.enable_private_segment_sgpr:
scratch_hilo = data64_le(prg.dev.scratch.va_addr)
# sgpr word1 bit31 enables swizzle
# sgpr word3 = 0x14 << 12 | 2 << 28 | 2 << 21 | 1 << 23
user_regs = [scratch_hilo[0], scratch_hilo[1] | 1 << 31, 0xffffffff, 0x20c14000] if prg.enable_private_segment_sgpr else []
else: user_regs = []
if prg.enable_dispatch_ptr:
dp = hsa.hsa_kernel_dispatch_packet_t.from_address(dp_addr:=args_state.ptr + prg.kernargs_segment_size)
@@ -370,12 +375,16 @@ class AMDDevice(HCQCompiled):
max_cu_id = self.properties['simd_count'] // self.properties['simd_per_cu'] - 1
max_wave_id = self.properties['max_waves_per_simd'] * self.properties['simd_per_cu'] - 1
self.max_private_segment_size = 4096
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256) # gfx11 requires alignment of 256
# <gfx103 requires alignment of 1024, >=gfx11 requires 256
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024)
self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len
self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.has_scratch_base_registers = self.target >= 110000
engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine']
self.tmpring_size = (wave_scratch_len // 256) << 12 | (self.scratch_len // (wave_scratch_len * engines))
waves = wave_scratch_len // (256 if self.target >= 110000 else 1024)
# >=gfx11 wavesize is per SE
wavesize = self.scratch_len // ((wave_scratch_len * engines) if self.target >= 110000 else wave_scratch_len)
self.tmpring_size = waves << 12 | wavesize
# https://gitlab.freedesktop.org/agd5f/linux/-/blob/a1fc9f584c4aaf8bc1ebfa459fc57a3f26a290d8/drivers/gpu/drm/amd/amdkfd/kfd_queue.c#L391
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
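For intuition only, a sketch of the `tmpring_size` bit packing used above, with made-up example values; the field split (waves above bit 12, wave size in the low 12 bits) is read directly off the `waves << 12 | wavesize` expression, not from hardware docs:
```python
def pack_tmpring(waves: int, wavesize: int) -> int:
  # same packing as the assignment above: waves in the upper field, size in the low 12 bits
  return waves << 12 | wavesize

tmpring = pack_tmpring(waves=16, wavesize=0x200)   # example values only
print(hex(tmpring))                                # 0x10200
print(tmpring >> 12, tmpring & 0xfff)              # 16 512, both fields recovered
```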

View File

@@ -58,8 +58,8 @@ class WebGpuAllocator(Allocator):
class WebGpuDevice(Compiled):
def __init__(self, device:str):
adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
adapter = wgpu.gpu.request_adapter(power_preference="high-performance")
timestamp_supported = wgpu.FeatureName.timestamp_query in adapter.features
wgpu_device = adapter.request_device_sync(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
wgpu_device = adapter.request_device(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
super().__init__(device, WebGpuAllocator(wgpu_device), WGSLRenderer(), Compiler(),
functools.partial(WebGPUProgram, (wgpu_device, timestamp_supported)))

View File

@@ -81,17 +81,17 @@ class ShapeTracker:
if c.op is Ops.RANGE: ret[c.arg[0]] = 1
if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg
if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg
used_ranges = [x.arg[0] for x in idx.sparents if x.op is Ops.RANGE]
used_ranges = [x.arg[0] for x in idx.toposort if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:
for masked_axis in [x.arg[0] for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
for masked_axis in [x.arg[0] for x in valid.toposort if x.op is Ops.RANGE]: ret[masked_axis] = None
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
return axis in [x.arg[0] for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:

View File

@@ -2688,6 +2688,19 @@ class Tensor(SimpleMathTrait):
"""
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def selu(self, alpha=1.67326, gamma=1.0507):
"""
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
- Described: https://paperswithcode.com/method/selu
- Paper: https://arxiv.org/abs/1706.02515v5
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).selu().numpy())
```
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
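A small reference check of the piecewise form used in `selu` above, written with plain `math` so it runs without tinygrad; the constants mirror the defaults in the signature:
```python
import math

def selu_ref(x: float, alpha: float = 1.67326, gamma: float = 1.0507) -> float:
  # gamma * x for x >= 0, gamma * alpha * (exp(x) - 1) otherwise
  return gamma * (x if x >= 0 else alpha * (math.exp(x) - 1))

print([round(selu_ref(v), 4) for v in (-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0)])
```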
def swish(self):
"""
See `.silu()`

View File

@@ -60,7 +60,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
for u in x.toposort:
if u.op is Ops.CONST: continue
label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
@@ -92,7 +92,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
# sanity check
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
g.changed_nodes.append([id(x) for x in u1.toposort if x.op is not Ops.CONST])
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
g.graphs.append(sink:=new_sink)
return g