simplify render_ops ctx [run_process_replay] (#5116)

* new ctx

* delete DEFINE_VAR

* lt isn't static
qazal authored 2024-06-23 16:56:32 +03:00, committed by GitHub
parent 28bf8d86d8
commit 64a3b7931e


@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence
+from typing import List, Tuple, Optional, Type, cast, DefaultDict, Dict, Union, Final, Iterator, Sequence, Callable
 import itertools, math, functools
 from collections import defaultdict
@@ -93,9 +93,21 @@ def expand_node(node:Node, idxs:Optional[Tuple[Union[Variable, NumNode], ...]]=N
   if idxs is None: idxs = (expand_idx(node),)
   return [node.substitute({k:v for k,v in zip(idxs, (NumNode(x) for x in rep)) if isinstance(k, Variable)}) for rep in iter_idxs(idxs)]
-class Linearizer(Kernel):
-  def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
+def variable_to_uop(x, ctx=None) -> UOp:
+  if isinstance(x, int): return UOp.const(dtypes.int, x)
+  return x.render(render_ops, ctx)
+
+render_ops: Dict[Type, Callable[..., UOp]] = {
+  NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
+  Variable: lambda self, ops, ctx: ctx[self.expr] if self.expr in ctx else UOp(UOps.DEFINE_VAR, dtypes.int, (), self),
+  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx)*variable_to_uop(self.b, ctx),
+  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
+  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
+  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
+  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*variable_to_uop(b, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
+
+class Linearizer(Kernel):
   def get_reduce_acc(self, reduceop:LazyOp):
     if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
     if reduceop.op is ReduceOps.MAX:
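For readers outside the codebase: the new variable_to_uop helper exists because a node's right operand self.b can be either a plain Python int or another symbolic node. A hedged sketch of that normalization with a toy node type (not the real UOp):

from typing import Union

class ToyNode:
  """Toy stand-in for a symbolic node that knows how to render itself."""
  def __init__(self, name: str): self.name = name
  def render(self, ops, ctx) -> str: return f"uop:{self.name}"

def variable_to_uop(x: Union[int, ToyNode], ctx=None) -> str:
  # mirrors the diff: ints become int constants, nodes render recursively
  if isinstance(x, int): return f"const:{x}"
  return x.render(None, ctx)

print(variable_to_uop(3))             # const:3
print(variable_to_uop(ToyNode("a")))  # uop:a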
@@ -105,16 +117,6 @@ class Linearizer(Kernel):
   # NOTE: once images are loaded, we uop them as their base float
   def get_base_dtype(self, dt:DType) -> DType: return dt.base if isinstance(dt, ImageDType) else dt
-  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: UOp.const(dtypes.int, self.b),
-               MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
-               DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.IDIV),
-               ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
-               LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT),
-    SumNode: lambda self,ops,ctx:
-      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)),
-    AndNode: lambda self,ops,ctx:
-      functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
   def global_load(self, i:int, idxs:List[Node], acc:Optional[LazyOp]=None, barrier:Optional[UOp]=None, loop_ctx:Tuple[UOp, ...]=()) -> List[UOp]:
     buf = self.bufs[i]
     localtype = self.get_base_dtype(buf.dtype if acc is None else acc.dtype)
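The deleted uop_alu_idx wrapped every binary op in an explicit UOp.alu(BinaryOps.*, ...) call; the new table instead leans on UOp's arithmetic dunders and an instance-level lt (the "lt isn't static" bullet in the commit message). A toy sketch of the operator surface this assumes, with strings standing in for uops:

class MiniUOp:
  """Toy UOp exposing the operators the new render_ops relies on."""
  def __init__(self, s: str): self.s = s
  def __mul__(self, other: "MiniUOp") -> "MiniUOp": return MiniUOp(f"({self.s}*{other.s})")
  def __floordiv__(self, other: "MiniUOp") -> "MiniUOp": return MiniUOp(f"({self.s}//{other.s})")
  def __mod__(self, other: "MiniUOp") -> "MiniUOp": return MiniUOp(f"({self.s}%{other.s})")
  def lt(self, other: "MiniUOp") -> "MiniUOp": return MiniUOp(f"({self.s}<{other.s})")

a, b = MiniUOp("a"), MiniUOp("b")
# old style was ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL); new style:
print((a * b).s)  # (a*b)
print(a.lt(b).s)  # (a<b)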
@@ -149,19 +151,19 @@ class Linearizer(Kernel):
         elif this_const is not None:
           self.load_cache[key] = UOp.const(localtype, this_const)
           if valid.min == 0 and valid.max == 1:
-            valid_rendered = valid.render(self.render_ops, self)
+            valid_rendered = valid.render(render_ops, self.loop_uops)
             self.load_cache[key] = UOp.alu(TernaryOps.WHERE, valid_rendered, self.load_cache[key], UOp.const(localtype, invalid_value))
         elif isinstance(buf.dtype, ImageDType):
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
-          rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx))
-          valid_tuple = (valid.render(self.render_ops, self), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
+          rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
+          valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(buf.dtype.base.vec(4), invalid_value)) if valid.min == 0 else tuple()
           self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4),
                                      (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
           if localtype == localtype.scalar():
             idx_small = idx%4
-            res = idx_small.render(self.render_ops, self)
+            res = idx_small.render(render_ops, self.loop_uops)
             out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max)
             for ix in range(idx_small.max, idx_small.min, -1):
               rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1)
@@ -171,8 +173,8 @@ class Linearizer(Kernel):
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
-          rendered_idx = idx.render(self.render_ops, self)
-          valid_tuple = (valid.render(self.render_ops, self), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
+          rendered_idx = idx.render(render_ops, self.loop_uops)
+          valid_tuple = (valid.render(render_ops, self.loop_uops), UOp.const(localtype, invalid_value)) if valid.min == 0 else tuple()
           self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ()))
       ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key])
     return ret
@@ -206,19 +208,19 @@ class Linearizer(Kernel):
       if isinstance(buf.dtype, ImageDType):
         image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
         rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \
-          tuple(x.render(self.render_ops, self) for x in image_idx))
+          tuple(x.render(render_ops, self.loop_uops) for x in image_idx))
       else:
-        rendered_idx = idx.render(self.render_ops, self)
+        rendered_idx = idx.render(render_ops, self.loop_uops)
       # TODO: let UPat check this once it's fast
       if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var)))
-      else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self))))
+      else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(render_ops, self.loop_uops))))
     return stores

   # render loop
   def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]:
     new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, (
-      UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self),
-      UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
+      UOp.const(dtypes.int, x.min) if isinstance(x.min, int) else cast(Node, x.min).render(render_ops, self.loop_uops),
+      UOp.const(dtypes.int, x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(render_ops, self.loop_uops)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None}  # noqa: E501
     self.loop_uops.update(new_loops)
     return tuple(new_loops.values())
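render_loop stores each new RANGE uop in self.loop_uops keyed by the variable's expr, and that same dict is what later gets passed as ctx, which is how a rendered Variable resolves to its loop counter. A hedged sketch of that registration/lookup flow with string stand-ins for uops:

from typing import Dict, Tuple

loop_uops: Dict[str, str] = {}

def render_loop(exprs) -> Tuple[str, ...]:
  # stand-in for building UOp(UOps.RANGE, ...); register under each variable's expr
  new_loops = {e: f"RANGE({e})" for e in exprs}
  loop_uops.update(new_loops)
  return tuple(new_loops.values())

render_loop(["ridx0", "ridx1"])
# a later Variable("ridx0").render(render_ops, loop_uops) hits this entry:
print(loop_uops["ridx0"])  # RANGE(ridx0)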
@@ -307,7 +309,7 @@ class Linearizer(Kernel):
       global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs})
     def gate_acc(r, idxs): return [
-      UOp.alu(TernaryOps.WHERE, valid.render(self.render_ops, self), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
+      UOp.alu(TernaryOps.WHERE, valid.render(render_ops, self.loop_uops), acc, UOp.const(r.dtype, 0)) if valid.min == 0 and valid.max == 1 else acc
       for valid, acc in zip(expand_node(self.sts[self.full_buf_index].expr_idxs(idxs)[1], expand_idxs(idxs)), accs[r])]
     local_accs = {r: gate_acc(r,global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for r in accs}
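gate_acc wraps each accumulator in a WHERE on the rendered validity mask, so lanes whose index can be out of bounds contribute the reduce identity instead of garbage. A hedged, purely illustrative sketch of that gating idea:

def gate_acc(valid: bool, acc: float, identity: float = 0.0) -> float:
  # stands in for UOp.alu(TernaryOps.WHERE, valid_uop, acc, const(identity))
  return acc if valid else identity

accs = [1.5, 2.5, 3.5]
valids = [True, False, True]
print([gate_acc(v, a) for v, a in zip(valids, accs)])  # [1.5, 0.0, 3.5]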
@@ -325,7 +327,7 @@ class Linearizer(Kernel):
     if self.opts.has_local:
       fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
       fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:]
-      if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self)
+      if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(render_ops, self.loop_uops)
       barrier = UOp(UOps.IF, None, (if_cond, barrier))

     # create new late reduce local loops and replace local_idxs that have been used
@@ -395,10 +397,6 @@ class Linearizer(Kernel):
         self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL,
                                buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                                (buf.idx, any(buf.idx == x.idx for x in self.outbufs)))
-    # add var vals
-    for i,var in enumerate(self.vars):
-      assert var.expr is not None
-      self.loop_uops[var.expr] = UOp(UOps.DEFINE_VAR, dtypes.int32, (), var)
     # define local buffers
     for aliases in self.local_alias.values():
       for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype),
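The loop deleted here ("delete DEFINE_VAR" in the commit message) pre-registered a DEFINE_VAR uop in loop_uops for every kernel variable before rendering started. With the new Variable entry in render_ops (see the dict added above), an unbound variable gets its DEFINE_VAR lazily at render time instead. A hedged one-function sketch of the lookup-or-define behavior, with strings standing in for uops:

from typing import Dict

def render_variable(expr: str, loop_uops: Dict[str, str]) -> str:
  # bound variables resolve to their loop/range uop; unbound ones become a
  # DEFINE_VAR on the spot, so no upfront registration pass is needed
  return loop_uops.get(expr, f"DEFINE_VAR({expr})")

print(render_variable("gidx0", {"gidx0": "RANGE(gidx0)"}))  # RANGE(gidx0)
print(render_variable("n", {}))                             # DEFINE_VAR(n)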