more asserts for non-matching vectorize/gep [run_process_replay] (#5760)

* merge vectorize/gep rules [run_process_replay]

* assert dtypes

* src=

* float2=(float4.x,float4.y)
This commit is contained in:
qazal
2024-07-28 15:08:54 +08:00
committed by GitHub
parent bfbd7c5461
commit 95dda8dadf
2 changed files with 13 additions and 4 deletions

View File

@@ -2,6 +2,7 @@ import unittest
from test.helpers import TestUOps
from tinygrad import dtypes, Variable
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps
from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher
from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite
@@ -144,10 +145,14 @@ class TestUOpGraph(TestUOps):
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, False))
d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (2, False))
idx = UOp.const(dtypes.int, 0)
def _test_vec(geps):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(4), geps)
def _test_vec(geps, count=4):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
out = UOp(UOps.STORE, None, (d0, idx, vec))
return UOpGraph([out]).uops[-1].src[-1]
g = UOpGraph([out])
if DEBUG >= 4:
from tinygrad import Device
print(Device[Device.DEFAULT].renderer.render("test", g))
return g.uops[-1].src[-1]
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
@@ -163,6 +168,9 @@ class TestUOpGraph(TestUOps):
val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2))
self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
# different vals
val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))

View File

@@ -200,8 +200,9 @@ def type_verify(uops):
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
if uop is UOps.VECTORIZE:
assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"