mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
more unmatching vectorize/gep asserts [run_process_replay] (#5760)
* merge vectorize/gep rules [run_process_replay] * assert dtypes * src= * float2=(float4.x,float4.y)
This commit is contained in:
@@ -2,6 +2,7 @@ import unittest
|
||||
from test.helpers import TestUOps
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
|
||||
from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
|
||||
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite
|
||||
@@ -144,10 +145,14 @@ class TestUOpGraph(TestUOps):
|
||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, False))
|
||||
d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (2, False))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
def _test_vec(geps):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(4), geps)
|
||||
def _test_vec(geps, count=4):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
||||
out = UOp(UOps.STORE, None, (d0, idx, vec))
|
||||
return UOpGraph([out]).uops[-1].src[-1]
|
||||
g = UOpGraph([out])
|
||||
if DEBUG >= 4:
|
||||
from tinygrad import Device
|
||||
print(Device[Device.DEFAULT].renderer.render("test", g))
|
||||
return g.uops[-1].src[-1]
|
||||
|
||||
# possible
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
@@ -163,6 +168,9 @@ class TestUOpGraph(TestUOps):
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
|
||||
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
|
||||
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
|
||||
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
|
||||
self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
|
||||
|
||||
# different vals
|
||||
val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
|
||||
|
||||
@@ -200,8 +200,9 @@ def type_verify(uops):
|
||||
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
|
||||
if uop is UOps.VECTORIZE:
|
||||
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
|
||||
assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
|
||||
Reference in New Issue
Block a user