devectorize prereqs [pr] (#9404)

This commit is contained in:
George Hotz
2025-03-11 12:33:29 +08:00
committed by GitHub
parent beed00eabe
commit 2780e2027e
3 changed files with 7 additions and 69 deletions

View File

@@ -1,8 +1,8 @@
from typing import List
import unittest, time, pytest
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG, AMX
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher
from tinygrad.helpers import DEBUG
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
@@ -502,7 +502,9 @@ class TestUOpGraph(unittest.TestCase):
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
@track_rewrites()
def expander_rewrite(sink):
  """Run graph_rewrite on `sink` using the `sym` rules combined with the `expander` rules."""
  combined_rules = sym + expander
  return graph_rewrite(sink, combined_rules)
@track_rewrites()
def float4_rewrite(sink):
  """Run full_graph_rewrite on `sink` with a default-constructed Renderer."""
  default_renderer = Renderer()
  return full_graph_rewrite(sink, default_renderer)
class TestExpander(unittest.TestCase):
@@ -652,72 +654,6 @@ class TestExpander(unittest.TestCase):
sink = expander_rewrite(sink)
print(sink)
class TestLoadStoreFolder(unittest.TestCase):
  """Counts the LOAD/STORE uops that remain after float4_rewrite for scalar accesses at constant indices."""

  @staticmethod
  def _with_op(sink, op):
    # every uop in the rewritten graph's toposort whose op is `op`
    return [u for u in sink.toposort if u.op is op]

  def test_simple_load_fold(self):
    # four ungated loads at indices 0..3 are expected to leave a single LOAD
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    loads = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(4)]
    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(len(loads)), tuple(loads))
    out = float4_rewrite(vec.sink())
    assert len(self._with_op(out, Ops.LOAD)) == 1

  @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
  def test_two_load_fold(self):
    # eight ungated loads at indices 0..7 are expected to leave two LOADs
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    loads = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)]
    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(len(loads)), tuple(loads))
    out = float4_rewrite(vec.sink())
    assert len(self._with_op(out, Ops.LOAD)) == 2

  def test_simple_load_fold_gated(self):
    # all four loads share one gate: a single LOAD is expected, whose src[1] is a VECTORIZE
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    gate = UOp(Ops.DEFINE_VAR, dtypes.bool)
    loads = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)]
    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(len(loads)), tuple(loads))
    out = float4_rewrite(vec.sink())
    assert len(self._with_op(out, Ops.LOAD)) == 1
    single_load = self._with_op(out, Ops.LOAD)[0]
    self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)

  def test_simple_load_dont_fold_different_gated(self):
    # index 0 uses g1 and indices 1..3 use g2: three LOADs are expected to remain
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    gate = UOp.variable("g1", False, True, dtypes.bool)
    gate2 = UOp.variable("g2", False, True, dtypes.bool)
    loads = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2),
                                          UOp.const(dtypes.float, 0))) for i in range(4)]
    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(len(loads)), tuple(loads))
    out = float4_rewrite(vec.sink())
    assert len(self._with_op(out, Ops.LOAD)) == 3

  def test_simple_store_fold(self):
    # four ungated stores of constant 0 at indices 0..3 are expected to leave a single STORE
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    stores = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)]
    out = float4_rewrite(UOp(Ops.SINK, dtypes.void, tuple(stores)))
    assert len(self._with_op(out, Ops.STORE)) == 1

  def test_simple_store_fold_gate(self):
    # the gate is passed as the third src of each STORE; one STORE is expected, with an IF on that gate as src[2]
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    gate = UOp.variable("g1", False, True, dtypes.bool)
    stores = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
    out = float4_rewrite(UOp(Ops.SINK, dtypes.void, tuple(stores)))
    assert len(self._with_op(out, Ops.STORE)) == 1
    one_store = self._with_op(out, Ops.STORE)[0]
    assert len(one_store.src) == 3
    _if_node = one_store.src[2]
    assert _if_node.op == Ops.IF and _if_node.src[0] == gate

  def test_simple_store_dont_fold(self):
    # index 0 is gated by g1 and indices 1..3 by g2, each storing its own index: three STOREs are expected
    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
    gate = UOp.variable("g1", False, True, dtypes.bool)
    gate2 = UOp.variable("g2", False, True, dtypes.bool)
    stores = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2),
                                            UOp.const(dtypes.float, i))) for i in range(4)]
    out = float4_rewrite(UOp(Ops.SINK, dtypes.void, tuple(stores)))
    assert len(self._with_op(out, Ops.STORE)) == 3
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)

View File

@@ -54,6 +54,8 @@ class PtrDType(DType):
def vec(self, sz:int) -> DType:
  """Return this pointer dtype with its vector size set to `sz`.

  Asserts that the pointer is not already vectorized (`self.v == 1`). `sz == 1` returns `self` unchanged.
  An ImageDType is rebuilt as an ImageDType with its `shape` appended; any other pointer type is rebuilt
  via `type(self)`.
  """
  assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
  # sz=1 is a scalar
  if sz == 1:
    return self
  # constructor arguments shared by both branches; `self` and `sz` are passed positionally among them
  common = (self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
  if isinstance(self, ImageDType):
    # images take one extra trailing argument, the shape
    return ImageDType(*common, self.shape)
  return type(self)(*common)
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property

View File

@@ -688,7 +688,7 @@ def get_location() -> tuple[str, int]:
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
"symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
"linearize.py"}:
"linearize.py", "devectorizer.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@functools.lru_cache(None)