From 95dda8dadf2970888fc8f494b83a0124eb614aa5 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 28 Jul 2024 15:08:54 +0800 Subject: [PATCH] more unmatching vectorize/gep asserts [run_process_replay] (#5760) * merge vectorize/gep rules [run_process_replay] * assert dtypes * src= * float2=(float4.x,float4.y) --- test/test_uop_graph.py | 14 +++++++++++--- tinygrad/codegen/uops.py | 3 ++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 2617e82bda..e9314c94e0 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,6 +2,7 @@ import unittest from test.helpers import TestUOps from tinygrad import dtypes, Variable from tinygrad.dtype import PtrDType +from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, ReduceOps from tinygrad.codegen.uops import UOps, UOp, NOp, PatternMatcher from tinygrad.codegen.uopgraph import UOpGraph, graph_rewrite @@ -144,10 +145,14 @@ class TestUOpGraph(TestUOps): d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, False)) d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (2, False)) idx = UOp.const(dtypes.int, 0) - def _test_vec(geps): - vec = UOp(UOps.VECTORIZE, dtypes.float.vec(4), geps) + def _test_vec(geps, count=4): + vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps) out = UOp(UOps.STORE, None, (d0, idx, vec)) - return UOpGraph([out]).uops[-1].src[-1] + g = UOpGraph([out]) + if DEBUG >= 4: + from tinygrad import Device + print(Device[Device.DEFAULT].renderer.render("test", g)) + return g.uops[-1].src[-1] # possible val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) @@ -163,6 +168,9 @@ class TestUOpGraph(TestUOps): val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx)) xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2)) self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE) + val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) + xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), i) for i in range(2)) + self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE) # different vals val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx)) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 17cbb8aadd..ccdba34a3b 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -200,8 +200,9 @@ def type_verify(uops): if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1 if uop is UOps.VECTORIZE: assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}" - assert dtype == src[0].dtype.vec(len(src)), f"{dtype=} must be {src[0].dtype.vec(len(src))}" + assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}" if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype + if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}" if uop is UOps.STORE: assert dtype is None, f"{uop} dtype must be None, got {dtype}" if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"