mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
only check it there
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from tinygrad.helpers import QUANTIZE, DEVECTORIZE, TRANSCENDENTAL, SPEC
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec, validate_pyrender
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
@@ -20,6 +20,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
if ren is None: ren = Renderer()
|
||||
|
||||
if SPEC: type_verify(list(sink.toposort()), kernel_spec)
|
||||
if SPEC > 2: validate_pyrender(sink)
|
||||
|
||||
# first we optimize
|
||||
if optimize:
|
||||
@@ -105,4 +106,5 @@ def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
|
||||
assert len(full_sink.ranges) == 0, "all ranges must end by the sink"
|
||||
lst = linearize(full_sink)
|
||||
if SPEC: type_verify(lst, program_spec)
|
||||
if SPEC > 2: validate_pyrender(sink)
|
||||
return lst
|
||||
|
||||
@@ -11,7 +11,7 @@ from tinygrad.helpers import suppress_finalizing
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.uop.mathtraits import MathTrait
|
||||
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec, validate_pyrender
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
@@ -230,6 +230,7 @@ class Tensor(MathTrait):
|
||||
|
||||
# verify Tensors match the spec
|
||||
if SPEC: type_verify(list(big_sink.toposort()), tensor_spec)
|
||||
if SPEC > 2: validate_pyrender(big_sink)
|
||||
|
||||
if any(isinstance(x._device, tuple) for x in big_sink.toposort()):
|
||||
_apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
|
||||
|
||||
@@ -65,12 +65,6 @@ class UOpMetaClass(type):
|
||||
assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
|
||||
buffers[created] = _buffer
|
||||
if SPEC > 1:
|
||||
if SPEC > 2:
|
||||
with Context(SPEC=0):
|
||||
code = '\n'.join(pyrender(created))
|
||||
lcls:dict[str, UOp] = {}
|
||||
exec(code, None, lcls)
|
||||
if lcls['ast'] is not created: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{created}\nPRODUCED:\n{lcls['ast']}")
|
||||
from tinygrad.uop.spec import full_spec
|
||||
with Context(IGNORE_OOB=1): ret = full_spec.rewrite(created)
|
||||
if cast(bool|None, ret) is not True: raise RuntimeError(f"SPEC ISSUE {ret}: {created}")
|
||||
@@ -124,7 +118,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else repr(self.arg)
|
||||
def tagstr(self): return f", tag={self.tag}" if self.tag is not None else ""
|
||||
|
||||
def f(self, op, **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,), **kwargs)
|
||||
def f(self, op, src=(), **kwargs): return UOp(op, dtype=kwargs.pop("dtype", self.dtype), src=(self,)+src, **kwargs)
|
||||
|
||||
@functools.cached_property
|
||||
def backward_slice(self:UOp) -> dict[UOp, None]:
|
||||
@@ -382,10 +376,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(end:sint, *arg, dtype=dtypes.index):
|
||||
def range(end:sint, *arg, dtype=dtypes.index, **kwargs):
|
||||
if len(arg) == 0: raise RuntimeError("range needs an arg")
|
||||
if len(arg) == 1: arg = arg+(AxisType.LOOP,)
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg)
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=arg, **kwargs)
|
||||
@staticmethod
|
||||
def special(end:sint, name:str, dtype=dtypes.index): return UOp(Ops.SPECIAL, dtype=dtype, src=(sint_to_uop(end, dtype),), arg=name)
|
||||
def r(self, op:Ops, axis:tuple[int, ...]):
|
||||
@@ -510,6 +504,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
usrcs = []
|
||||
for arg in src_args:
|
||||
if len(arg) == 0: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(0)))
|
||||
elif len(arg) == 1 and isinstance(arg[0], UOp): usrcs.append(arg[0])
|
||||
elif all(isinstance(x, int) for x in arg): usrcs.append(UOp.const(dtypes.index.vec(len(arg)), arg))
|
||||
else: usrcs.append(UOp(Ops.VECTORIZE, dtypes.index.vec(len(arg)), tuple(UOp.const(dtypes.index, x) if isinstance(x, int) else x for x in arg)))
|
||||
ret = UOp(op, self.dtype, (self,)+tuple(usrcs), arg if len(usrcs) == 0 else None)
|
||||
@@ -533,12 +528,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# TODO: use this in Buffer
|
||||
unique_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def unique(): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num))
|
||||
def unique(num:int|None=None): return UOp(Ops.UNIQUE, arg=next(UOp.unique_num) if num is None else num)
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
@staticmethod
|
||||
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp.unique(), UOp(Ops.DEVICE, arg=device)), size)
|
||||
def new_buffer(device:str|tuple[str, ...], size:int, dtype:DType, num=None):
|
||||
return UOp(Ops.BUFFER, dtype, (UOp.unique(num), UOp(Ops.DEVICE, arg=device)), size)
|
||||
@property
|
||||
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
|
||||
@recursive_property
|
||||
@@ -1237,7 +1233,7 @@ renderer_infer = PatternMatcher([
|
||||
])
|
||||
|
||||
sugar = { Ops.SINK: "sink", Ops.STORE: "store", Ops.LOAD: "load", Ops.SQRT: "sqrt", Ops.INDEX: "index", Ops.REDUCE: "reduce",
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin"}
|
||||
Ops.WHERE: "where", Ops.RECIP: "reciprocal", Ops.EXP2: "exp2", Ops.LOG2: "log2", Ops.SIN: "sin", Ops.CONTIGUOUS: "contiguous"}
|
||||
pm_pyrender = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg}, src={x.src[0].arg})")),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=f"UOp.const({x.dtype}, {x.arg})")),
|
||||
@@ -1247,7 +1243,9 @@ pm_pyrender = PatternMatcher([
|
||||
(UPat({Ops.MAX, Ops.THREEFRY, Ops.CMPLT, Ops.CMPNE, Ops.POW}, src=UPat(Ops.NOOP), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.alu({x.op}, {x.src[1].arg})")),
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
f"UOp.range({x.src[0].arg}, {str(x.arg[0])}, {str(x.arg[1])}"+\
|
||||
(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+\
|
||||
(', tag='+str(x.tag) if x.tag is not None else '')+")")),
|
||||
(UPat(Ops.SPECIAL, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg= f"UOp.special({x.src[0].arg}, \"{x.arg}\", dtype={x.dtype})")),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.index else ''})")),
|
||||
@@ -1255,6 +1253,12 @@ pm_pyrender = PatternMatcher([
|
||||
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
|
||||
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
|
||||
(UPat(GroupOp.Movement, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=
|
||||
f"{x.src[0].arg}.f({x.op}, src=({', '.join([y.arg for y in x.src[1:]])},))")),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"), lambda x,u,d: UOp(Ops.NOOP, arg=
|
||||
f"UOp.new_buffer(\"{d.arg}\", {x.size}, {x.dtype}, {u.arg})")),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.NOOP, name="x"), UPat(Ops.DEVICE, name="d"))), lambda x,d: UOp(Ops.NOOP, arg=
|
||||
f"{x.arg}.copy_to_device(\"{d.arg}\")")),
|
||||
])
|
||||
|
||||
@Context(SPEC=0)
|
||||
@@ -1263,7 +1267,7 @@ def pyrender(ast:UOp) -> list[str]:
|
||||
to_render = set({ast})
|
||||
for u in ast.toposort():
|
||||
if u.op is Ops.STORE: to_render.add(u.src[1])
|
||||
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD} or u.op in {Ops.CONST}: continue
|
||||
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD, Ops.BUFFER, Ops.COPY} or u.op in {Ops.CONST, Ops.DEVICE}: continue
|
||||
if u.op in {Ops.SINK}:
|
||||
for s in u.src: to_render.add(s)
|
||||
to_render.add(u)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import cast
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, AxisType, pyrender
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
|
||||
from tinygrad.helpers import DEBUG, Context, prod
|
||||
from tinygrad.uop.validate import validate_index
|
||||
@@ -239,3 +239,10 @@ def type_verify(uops:list[UOp], check_spec:PatternMatcher):
|
||||
if cast(bool|None, ret) is not True:
|
||||
if DEBUG >= 3: print_uops(uops)
|
||||
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[(x.op, x.dtype, x.arg) for x in u.src]} {u.arg}")
|
||||
|
||||
@Context(SPEC=0)
|
||||
def validate_pyrender(test_ast:UOp):
|
||||
code = '\n'.join(pyrender(test_ast))
|
||||
lcls:dict[str, UOp] = {}
|
||||
exec(code, None, lcls)
|
||||
if lcls['ast'] is not test_ast: raise RuntimeError(f"PYRENDER ISSUE:\nCODE:\n{code}\nUOP:\n{test_ast}\nPRODUCED:\n{lcls['ast']}")
|
||||
|
||||
Reference in New Issue
Block a user