Mirror of https://github.com/tinygrad/tinygrad.git
add UOP_IS_SYMBOLIC [run_process_replay] [no_assert] (#5386)
* cleanup a few things in uops [run_process_replay] [no_assert]
* add optional UOP_IS_SYMBOLIC
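Usage note (not part of the commit itself): judging by the getenv("UOP_IS_SYMBOLIC") check added to the lowerer below, the new path is opt-in via an environment variable in tinygrad's usual convention. A minimal sketch, assuming the flag has to be set before the codegen modules are imported, since the branch happens at module level:

# sketch only: enable the new ShapeTracker -> UOp rendering path
# (assumes the standard tinygrad env-var convention via tinygrad.helpers.getenv)
import os
os.environ["UOP_IS_SYMBOLIC"] = "1"  # set before importing tinygrad.codegen modules

from tinygrad.helpers import getenv
assert getenv("UOP_IS_SYMBOLIC") == 1  # the lowerer branches on this at import time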
@@ -329,7 +329,7 @@ class TestMultiTensor(unittest.TestCase):
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV"), "slow")
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
  def test_data_parallel_resnet_train_step(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import List, Tuple, cast, Optional, Any, Dict
import functools
from tinygrad.codegen.kernel import Kernel
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.dtype import dtypes, PtrDType, ImageDType, DType
from tinygrad.ops import BufferOps, LazyOp, TernaryOps, ReduceOps, UnaryOps, get_lazyop_info
from tinygrad.codegen.uops import UOp, flops_mem, UOps
@@ -10,12 +10,9 @@ from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.renderer import Program
from tinygrad.helpers import to_function_name, DEBUG, getenv, prod, diskcache_put, ContextVar

# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker
def variable_to_uop(x, ctx=None) -> UOp:
  if isinstance(x, int): return UOp.const(dtypes.int32, x)
  return x.render(render_ops, ctx)

# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int32, x) if isinstance(x, int) else x.render(render_ops, ctx)
render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
                    MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
                    DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
@@ -25,12 +22,38 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b
                    SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                    AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

# TODO: change this once UOps is ready to replace symbolic
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
  fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
  idx, valid = st.expr_idxs(fake_idxs)
  ctx = dict(zip(fake_idxs, idxs))
  return idx.render(render_ops, ctx), valid.render(render_ops, ctx).cast(dtypes.bool)
if getenv("UOP_IS_SYMBOLIC"):
  # TODO: change this once UOps is ready to replace symbolic. note: this doesn't work for variable shapetrackers now
  def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
    # TODO: dtypes.realint
    iexpr = variable_to_uop(view.offset)
    for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
      if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
      if m is not None:
        if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
        if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
    return iexpr, vexpr

  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
    idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
    for view in reversed(st.views[0:-1]):
      view = view.minify()
      acc, idxs = 1, []
      for _d in reversed(view.shape):
        d = variable_to_uop(_d)
        idxs.append((idx//acc)%d)
        acc *= d
      idx, valid = _uop_view(view, idxs[::-1], valid)
    return idx, valid
else:
  def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
    fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
    idx, valid = st.expr_idxs(fake_idxs)
    ctx = dict(zip(fake_idxs, idxs))
    uidx, uvalid = idx.render(render_ops, ctx), valid.render(render_ops, ctx)
    if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
    assert uvalid.dtype == dtypes.bool
    return uidx, uvalid

def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0) -> Tuple[List[UOp], List[UOp]]:
  local_idxs = loop_local_idxs = [UOp(UOps.SPECIAL, dtypes.int32, (), (i, f"{prefix}{start_dim+i}", s)) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
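To make the new path above easier to follow, here is a toy walkthrough in plain Python (no tinygrad types; names are illustrative, not from the commit) of the two steps st_to_uops performs: rendering a single View into a flat index plus a validity predicate, and un-flattening that index across earlier views.

# toy version of _uop_view: flat index = offset + sum(idx*stride), validity
# built from the optional per-dimension (lo, hi) mask
def toy_uop_view(offset, shape, strides, mask, idxs):
  iexpr, vexpr = offset, True
  for idx, sh, st, m in zip(idxs, shape, strides, mask or [None]*len(shape)):
    if sh != 1 and st != 0: iexpr += idx*st
    if m is not None:
      if m[0] != 0: vexpr = vexpr and (idx >= m[0])
      if m[1] != sh: vexpr = vexpr and (idx < m[1])
  return iexpr, vexpr

# toy version of the multi-view loop: un-flatten a flat index into per-dim
# indices for an earlier view via repeated (idx // acc) % dim, innermost first
def toy_unflatten(idx, shape):
  acc, idxs = 1, []
  for d in reversed(shape):
    idxs.append((idx//acc) % d)
    acc *= d
  return idxs[::-1]

print(toy_uop_view(0, (2,3), (3,1), [(1,2),(0,3)], [0,2]))  # -> (2, False): row 0 is masked out
print(toy_unflatten(7, (2,3,2)))  # -> [1, 0, 1], since 7 == 1*6 + 0*2 + 1*1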
@@ -88,7 +111,7 @@ class Lowerer(Kernel):

  def linearize(self) -> Lowerer:
    modified_ast, ki = self.get_optimized_ast()
    if DEBUG >= 4:
    if DEBUG >= 3:
      from tinygrad.engine.graph import print_tree
      for mast in modified_ast: print_tree(mast)

@@ -383,7 +383,6 @@ constant_folder = PatternMatcher([
  # ** self folding **
  (-(-UOp.var('x')), lambda x: x), # -(-x) -> x
  (UOp.var('x') + 0, lambda x: x), # x+0 -> x
  (UOp.var('x') - 0, lambda x: x), # x-0 -> x
  (UOp.var('x') * 1, lambda x: x), # x*1 -> x
  (UOp.var('x') * -1, lambda x: -x), # x*-1 -> -x
  (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
@@ -406,6 +405,8 @@ constant_folder = PatternMatcher([
  # *** rules from symbolic ***
  # two stage mul, (x*c1)*c2 = x*(c1*c2)
  ((UOp.var("x") * UOp.cvar("c1")) * UOp.cvar("c2"), lambda x,c1,c2: x*UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c1.arg, c2.arg]))),
  # -(x+y) -> -x + -y
  #(-(UOp.var("x") + UOp.var("y")), lambda x,y: (-x)+(-y)),
  # x%1 -> 0
  (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
  # (x*c0)+(x*c1) -> x*(c0+c1)

@@ -13,9 +13,6 @@ from tinygrad.engine.schedule import ScheduleItem

logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
def get_linearizer(renderer:Renderer, ast:Tuple[LazyOp, ...]) -> Linearizer:
  if DEBUG >= 3:
    from tinygrad.engine.graph import print_tree
    for op in ast: print_tree(op)
  k = Linearizer(*ast, opts=renderer)
  k.required_optimizations()
  if not NOOPT: