mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
devectorize prereqs [pr] (#9404)
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import List
|
||||
import unittest, time, pytest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG, AMX
|
||||
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher, track_rewrites
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
from tinygrad.codegen.devectorizer import full_graph_rewrite, graph_rewrite, sym
|
||||
@@ -502,7 +502,9 @@ class TestUOpGraph(unittest.TestCase):
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
@track_rewrites()
|
||||
def expander_rewrite(sink): return graph_rewrite(sink, sym + expander)
|
||||
@track_rewrites()
|
||||
def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
@@ -652,72 +654,6 @@ class TestExpander(unittest.TestCase):
|
||||
sink = expander_rewrite(sink)
|
||||
print(sink)
|
||||
|
||||
class TestLoadStoreFolder(unittest.TestCase):
|
||||
def test_simple_load_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_two_load_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 2
|
||||
|
||||
def test_simple_load_fold_gated(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp(Ops.DEFINE_VAR, dtypes.bool)
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate),)) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
|
||||
single_load = [x for x in sink.toposort if x.op is Ops.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2),
|
||||
UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 3
|
||||
|
||||
def test_simple_store_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i)), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 1
|
||||
one_store = [x for x in sink.toposort if x.op is Ops.STORE][0]
|
||||
assert len(one_store.src) == 3
|
||||
_if_node = one_store.src[2]
|
||||
assert _if_node.op == Ops.IF and _if_node.src[0] == gate
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf.index(UOp.const(dtypes.int, i), gate if i == 0 else gate2),
|
||||
UOp.const(dtypes.float, i))) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.toposort if x.op is Ops.STORE]) == 3
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
|
||||
@@ -54,6 +54,8 @@ class PtrDType(DType):
|
||||
def vec(self, sz:int) -> DType:
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
if isinstance(self, ImageDType):
|
||||
return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size, self.shape)
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
|
||||
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
@property
|
||||
|
||||
@@ -688,7 +688,7 @@ def get_location() -> tuple[str, int]:
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
|
||||
"symbolic.py", "expander.py", "lowerer.py", "cstyle.py",
|
||||
"linearize.py"}:
|
||||
"linearize.py", "devectorizer.py"}:
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
|
||||
Reference in New Issue
Block a user