add UOP_IS_SYMBOLIC [run_process_replay] [no_assert] (#5386)

* cleanup a few things in uops [run_process_replay] [no_assert]

* add optional UOP_IS_SYMBOLIC
George Hotz authored on 2024-07-11 10:48:45 -07:00 (committed by GitHub)
parent b3790b759b
commit 3e40211e45
4 changed files with 39 additions and 18 deletions
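
The new lowering path is opt-in: the lowerer picks one of two st_to_uops implementations at import time, gated on the UOP_IS_SYMBOLIC environment variable (read through tinygrad's getenv helper). A minimal stand-alone sketch of that gating pattern, using a simplified getenv stand-in rather than the real module:

```python
# Sketch of the import-time gating this commit adds (stand-in code, not tinygrad's module).
import os

def getenv(key: str, default=0) -> int:
  # simplified stand-in for tinygrad.helpers.getenv: environment variables read as ints
  return int(os.environ.get(key, default))

if getenv("UOP_IS_SYMBOLIC"):
  def st_to_uops_mode() -> str: return "render ShapeTracker views directly to UOps"
else:
  def st_to_uops_mode() -> str: return "render ShapeTracker views through symbolic nodes"

print(st_to_uops_mode())  # run with UOP_IS_SYMBOLIC=1 to exercise the new path
```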


@@ -329,7 +329,7 @@ class TestMultiTensor(unittest.TestCase):
     shard_output_np = shard_output.numpy()
     np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
-  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV"), "slow")
+  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
   def test_data_parallel_resnet_train_step(self):
     import sys, pathlib
     sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())


@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import List, Tuple, cast, Optional, Any, Dict
 import functools
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
 from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, get_lazyop_info
 from tinygrad.codegen.uops import UOp, flops_mem, UOps
@@ -10,12 +10,9 @@ from tinygrad.codegen.uopgraph import UOpGraph
 from tinygrad.renderer import Program
 from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar
-# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker
-def variable_to_uop(x, ctx=None) -> UOp:
-  if isinstance(x, int): return UOp.const(dtypes.int32, x)
-  return x.render(render_ops, ctx)
+# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
+from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int32, x) if isinstance(x, int) else x.render(render_ops, ctx)
 render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
                     MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
                     DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
@@ -25,12 +22,38 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b
   SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
   AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
-# TODO: change this once UOps is ready to replace symbolic
-def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
-  fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
-  idx, valid = st.expr_idxs(fake_idxs)
-  ctx = dict(zip(fake_idxs, idxs))
-  return idx.render(render_ops, ctx), valid.render(render_ops, ctx).cast(dtypes.bool)
+if getenv("UOP_IS_SYMBOLIC"):
+  # TODO: change this once UOps is ready to replace symbolic. note: this doesn't work for variable shapetrackers now
+  def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
+    # TODO: dtypes.realint
+    iexpr = variable_to_uop(view.offset)
+    for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
+      if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
+      if m is not None:
+        if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
+        if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
+    return iexpr, vexpr
+  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
+    idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
+    for view in reversed(st.views[0:-1]):
+      view = view.minify()
+      acc, idxs = 1, []
+      for _d in reversed(view.shape):
+        d = variable_to_uop(_d)
+        idxs.append((idx//acc)%d)
+        acc *= d
+      idx, valid = _uop_view(view, idxs[::-1], valid)
+    return idx, valid
+else:
+  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
+    fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
+    idx, valid = st.expr_idxs(fake_idxs)
+    ctx = dict(zip(fake_idxs, idxs))
+    uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx)
+    if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
+    assert uvalid.dtype == dtypes.bool
+    return uidx, uvalid
 def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0) -> Tuple[List[UOp], List[UOp]]:
   local_idxs = loop_local_idxs = [UOp(UOps.SPECIAL, dtypes.int32, (), (i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
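
For context, the arithmetic _uop_view builds per View is: index = offset + sum(idx_i * stride_i) over dims that aren't broadcast (shape 1 or stride 0), and valid is the AND of the per-dimension mask bounds. A plain-integer sketch of the same math (illustrative only; the shape/strides/mask values below are made up):

```python
# Plain-int version of the per-View index/valid computation (not tinygrad code).
from typing import Optional, Sequence, Tuple

def view_index(idxs: Sequence[int], shape: Sequence[int], strides: Sequence[int],
               offset: int, mask: Optional[Sequence[Tuple[int, int]]]) -> Tuple[int, bool]:
  iexpr, vexpr = offset, True
  for idx, sh, st, m in zip(idxs, shape, strides, mask if mask is not None else [None]*len(shape)):
    if sh != 1 and st != 0: iexpr += idx*st           # size-1 / stride-0 dims add nothing to the index
    if m is not None:
      if m[0] != 0: vexpr = vexpr and (idx >= m[0])   # lower mask bound
      if m[1] != sh: vexpr = vexpr and (idx < m[1])   # upper mask bound
  return iexpr, vexpr

# e.g. a 2x3 view where only columns 1..2 are unmasked
print(view_index((1, 0), shape=(2, 3), strides=(3, 1), offset=0, mask=((0, 2), (1, 3))))  # -> (3, False)
```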
@@ -88,7 +111,7 @@ class Lowerer(Kernel):
   def linearize(self) -> Lowerer:
     modified_ast, ki = self.get_optimized_ast()
-    if DEBUG >= 4:
+    if DEBUG >= 3:
       from tinygrad.engine.graph import print_tree
       for mast in modified_ast: print_tree(mast)


@@ -383,7 +383,6 @@ constant_folder = PatternMatcher([
   # ** self folding **
   (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
   (UOp.var('x') + 0, lambda x: x), # x+0 -> x
-  (UOp.var('x') - 0, lambda x: x), # x-0 -> x
   (UOp.var('x') * 1, lambda x: x), # x*1 -> x
   (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
   (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
@@ -406,6 +405,8 @@ constant_folder = PatternMatcher([
   # *** rules from symbolic ***
   # two stage mul, (x*c1)*c2 = x*(c1*c2)
   ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
+  # -(x+y) -> -x + -y
+  #(-(UOp.var("x") + UOp.var("y")), lambda x,y: (-x)+(-y)),
   # x%1 -> 0
   (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
   # (x*c0)+(x*c1) -> x*(c0+c1)
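
For readers unfamiliar with the PatternMatcher, the two-stage mul rule shown in this hunk, (x*c1)*c2 -> x*(c1*c2), simply merges the two constant multipliers. A toy stand-alone version on nested tuples (the tuple encoding is made up for the example; it is not tinygrad's UOp representation):

```python
# Toy illustration of the (x*c1)*c2 -> x*(c1*c2) fold, on tuples instead of UOps.
# Expressions: ("mul", lhs, rhs), ("var", name), ("const", value).
def fold_two_stage_mul(e):
  if e[0] == "mul" and e[2][0] == "const" and e[1][0] == "mul" and e[1][2][0] == "const":
    x, c1, c2 = e[1][1], e[1][2][1], e[2][1]
    return ("mul", x, ("const", c1*c2))   # merge the constants
  return e

print(fold_two_stage_mul(("mul", ("mul", ("var", "x"), ("const", 3)), ("const", 4))))
# -> ('mul', ('var', 'x'), ('const', 12))
```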


@@ -13,9 +13,6 @@ from tinygrad.engine.schedule import ScheduleItem
 logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
 def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
-  if DEBUG >= 3:
-    from tinygrad.engine.graph import print_tree
-    for op in ast: print_tree(op)
   k = Linearizer(*ast, opts=renderer)
   k.required_optimizations()
   if not NOOPT: