only check it there

This commit is contained in:
George Hotz
2025-10-27 11:50:33 +08:00
parent 1eb982e01f
commit 46914e2f40
4 changed files with 31 additions and 17 deletions

View File

@@ -1,6 +1,6 @@
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec, validate_pyrender
from tinygrad.renderer import Renderer
# import all pattern matchers here
@@ -20,6 +20,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
if ren is None: ren = Renderer()
if SPEC: type_verify(list(sink.toposort()), kernel_spec)
if SPEC > 2: validate_pyrender(sink)
# first we optimize
if optimize:
@@ -105,4 +106,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
lst = linearize(full_sink)
if SPEC: type_verify(lst, program_spec)
if SPEC > 2: validate_pyrender(sink)
return lst

View File

@@ -11,7 +11,7 @@ from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.uop.spec import type_verify, tensor_spec, validate_pyrender
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
@@ -230,6 +230,7 @@ class Tensor(MathTrait):
# verify Tensors match the spec
if SPEC: type_verify(list(big_sink.toposort()), tensor_spec)
if SPEC > 2: validate_pyrender(big_sink)
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")

View File

@@ -65,12 +65,6 @@ class UOpMetaClass(type):
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
buffers[created] = _buffer
if SPEC > 1:
if SPEC > 2:
with Context(SPEC=0):
code = '\n'.join(pyrender(created))
lcls:dict[str, UOp] = {}
exec(code, None, lcls)
if lcls['ast'] is not created: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{created}\nPRODUCED:\n{lcls['ast']}")
from tinygrad.uop.spec import full_spec
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
@@ -124,7 +118,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
def tagstr(self): return f", tag={self.tag}" if self.tag is not None else ""
def f(self, op, **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,), **kwargs)
def f(self, op, src=(), **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,)+src, **kwargs)
@functools.cached_property
def backward_slice(self:UOp) -> dict[UOp, None]:
@@ -382,10 +376,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(end:sint, *arg, dtype=dtypes.index):
def range(end:sint, *arg, dtype=dtypes.index, **kwargs):
if len(arg) == 0: raise RuntimeError("range needs an arg")
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg)
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg, **kwargs)
@staticmethod
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
def r(self, op:Ops, axis:tuple[int, ...]):
@@ -510,6 +504,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
usrcs = []
for arg in src_args:
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
elif len(arg) == 1 and isinstance(arg[0], UOp): usrcs.append(arg[0])
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
@@ -533,12 +528,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# TODO: use this in Buffer
unique_num = itertools.count(0)
@staticmethod
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
def unique(num:int|None=None): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num) if num is None else num)
# *** uop Buffer stuff ***
@staticmethod
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size)
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
@property
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@recursive_property
@@ -1237,7 +1233,7 @@ renderer_infer = PatternMatcher([
])
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin", Ops.CONTIGUOUS: "contiguous"}
pm_pyrender = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
@@ -1247,7 +1243,9 @@ pm_pyrender = PatternMatcher([
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}"+\
(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+\
(', tag='+str(x.tag) if x.tag is not None else '')+")")),
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
@@ -1255,6 +1253,12 @@ pm_pyrender = PatternMatcher([
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
(UPat(GroupOp.Movement, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=
f"{x.src[0].arg}.f({x.op}, src=({', '.join([y.arg for y in x.src[1:]])},))")),
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d: UOp(Ops.NOOP, arg=
f"UOp.new_buffer(\"{d.arg}\", {x.size}, {x.dtype}, {u.arg})")),
(UPat(Ops.COPY, src=(UPat(Ops.NOOP, name="x"), UPat(Ops.DEVICE, name="d"))), lambda x,d: UOp(Ops.NOOP, arg=
f"{x.arg}.copy_to_device(\"{d.arg}\")")),
])
@Context(SPEC=0)
@@ -1263,7 +1267,7 @@ def pyrender(ast:UOp) -> list[str]:
to_render = set({ast})
for u in ast.toposort():
if u.op is Ops.STORE: to_render.add(u.src[1])
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD} or u.op in {Ops.CONST}: continue
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.BUFFER, Ops.COPY} or u.op in {Ops.CONST, Ops.DEVICE}: continue
if u.op in {Ops.SINK}:
for s in u.src: to_render.add(s)
to_render.add(u)

View File

@@ -1,5 +1,5 @@
from typing import cast
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, pyrender
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import DEBUG, Context, prod
from tinygrad.uop.validate import validate_index
@@ -239,3 +239,10 @@ def type_verify(uops:list[UOp], check_spec:PatternMatcher):
if cast(bool|None, ret) is not True:
if DEBUG >= 3: print_uops(uops)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
@Context(SPEC=0)
def validate_pyrender(test_ast:UOp):
code = '\n'.join(pyrender(test_ast))
lcls:dict[str, UOp] = {}
exec(code, None, lcls)
if lcls['ast'] is not test_ast: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{test_ast}\nPRODUCED:\n{lcls['ast']}")