mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
simplify render_ops ctx [run_process_replay] (#5116)
* new ctx * delete DEFINE_VAR * lt isn't static
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
|
||||
from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
|
||||
import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
|
||||
@@ -93,9 +93,21 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
|
||||
if idxs is None: idxs = (expand_idx(node),)
|
||||
return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
|
||||
|
||||
class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op):
  """Combine UOp `a` with symbolic operand `b` under binary `op`.

  `b` may be a plain int, in which case it is first wrapped in a NumNode so it
  can be rendered through the same ops table as any other symbolic node.
  """
  rhs = b if isinstance(b, Node) else NumNode(b)
  return UOp.alu(op, a, rhs.render(ops, ctx))
|
||||
def variable_to_uop(x, ctx=None) -> UOp:
  """Turn `x` into a UOp: plain ints become int constants, symbolic nodes are
  rendered through the module-level render_ops table with the given ctx."""
  if not isinstance(x, int):
    return x.render(render_ops, ctx)
  return UOp.const(dtypes.int, x)
|
||||
|
||||
# Dispatch table mapping each symbolic Node class to a renderer that builds the
# equivalent UOp tree.  Renderers are invoked as node.render(render_ops, ctx),
# where ctx maps variable-name strings to already-created UOps.
render_ops: Dict[Type, Callable[..., UOp]] = {
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
  # reuse the UOp already bound to this variable name, otherwise declare it
  Variable: lambda self, ops, ctx: UOp(UOps.DEFINE_VAR, dtypes.int, (), self) if self.expr not in ctx else ctx[self.expr],
  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx) * variable_to_uop(self.b, ctx),
  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx) // variable_to_uop(self.b, ctx),
  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx) % variable_to_uop(self.b, ctx),
  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  # n-ary nodes fold their children left-to-right, seeded with the first child
  SumNode: lambda self, ops, ctx: functools.reduce(lambda acc, node: acc + variable_to_uop(node, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)),
  AndNode: lambda self, ops, ctx: functools.reduce(lambda acc, node: acc * variable_to_uop(node, ctx), self.nodes[1:], self.nodes[0].render(ops, ctx)) }
|
||||
|
||||
class Linearizer(Kernel):
|
||||
def get_reduce_acc(self, reduceop:LazyOp):
|
||||
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
|
||||
if reduceop.op is ReduceOps.MAX:
|
||||
@@ -105,16 +117,6 @@ class Linearizer(Kernel):
|
||||
# NOTE: once images are loaded, we uop them as their base float
|
||||
def get_base_dtype(self, dt:DType) -> DType:
  """Return the underlying base dtype for an ImageDType; other dtypes pass through unchanged."""
  if isinstance(dt, ImageDType): return dt.base
  return dt
|
||||
|
||||
# Per-Node-class renderers that lower symbolic index nodes to UOps.  Invoked as
# node.render(render_ops, ctx) where ctx is the Linearizer: variables resolve
# through ctx.loop_uops, binary nodes go through ctx.uop_alu_idx.
render_ops: Any = {
  Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr],
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
  MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
  DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.IDIV),
  ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
  LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
  # n-ary nodes: fold the tail of children onto the rendered first child
  SumNode: lambda self, ops, ctx: functools.reduce(lambda acc, node: ctx.uop_alu_idx(acc, node, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops, ctx)),
  AndNode: lambda self, ops, ctx: functools.reduce(lambda acc, node: ctx.uop_alu_idx(acc, node, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops, ctx)) }
|
||||
|
||||
def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
|
||||
buf = self.bufs[i]
|
||||
localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
|
||||
@@ -149,19 +151,19 @@ class Linearizer(Kernel):
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = UOp.const(localtype, this_const)
|
||||
if valid.min == 0 and valid.max == 1:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
valid_rendered = valid.render(render_ops, self.loop_uops)
|
||||
self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
|
||||
(buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
if localtype == localtype.scalar():
|
||||
idx_small = idx%4
|
||||
res = idx_small.render(self.render_ops, self)
|
||||
res = idx_small.render(render_ops, self.loop_uops)
|
||||
out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
|
||||
for ix in range(idx_small.max, idx_small.min, -1):
|
||||
rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
|
||||
@@ -171,8 +173,8 @@ class Linearizer(Kernel):
|
||||
else:
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
rendered_idx = idx.render(render_ops, self.loop_uops)
|
||||
valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
|
||||
self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
|
||||
ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
|
||||
return ret
|
||||
@@ -206,19 +208,19 @@ class Linearizer(Kernel):
|
||||
if isinstance(buf.dtype, ImageDType):
|
||||
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
|
||||
rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
|
||||
tuple(x.render(self.render_ops, self) for x in image_idx))
|
||||
tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
rendered_idx = idx.render(render_ops, self.loop_uops)
|
||||
# TODO: let UPat check this once it's fast
|
||||
if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
|
||||
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
|
||||
else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
|
||||
return stores
|
||||
|
||||
# render loop
|
||||
def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
|
||||
new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
|
||||
UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
|
||||
UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
|
||||
UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501
|
||||
self.loop_uops.update(new_loops)
|
||||
return tuple(new_loops.values())
|
||||
|
||||
@@ -307,7 +309,7 @@ class Linearizer(Kernel):
|
||||
global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
|
||||
|
||||
def gate_acc(r, idxs): return [
|
||||
UOp.alu(TernaryOps.WHERE, valid.render(self.render_ops, self), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
|
||||
UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
|
||||
for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
|
||||
local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
|
||||
|
||||
@@ -325,7 +327,7 @@ class Linearizer(Kernel):
|
||||
if self.opts.has_local:
|
||||
fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
|
||||
fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
|
||||
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
|
||||
if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
|
||||
barrier = UOp(UOps.IF, None, (if_cond, barrier))
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
@@ -395,10 +397,6 @@ class Linearizer(Kernel):
|
||||
self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
|
||||
buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||
(buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
|
||||
# add var vals
|
||||
for i,var in enumerate(self.vars):
|
||||
assert var.expr is not None
|
||||
self.loop_uops[var.expr] = UOp(UOps.DEFINE_VAR, dtypes.int32, (), var)
|
||||
# define local buffers
|
||||
for aliases in self.local_alias.values():
|
||||
for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
|
||||
|
||||
Reference in New Issue
Block a user