unbind Variable pre LazyOp (#4873)

* early unbind

* assert ConstType is correct
This commit is contained in:
qazal
2024-06-08 20:16:38 +08:00
committed by GitHub
parent 9c30889ce9
commit d19f39d4dd
3 changed files with 13 additions and 8 deletions

View File

@@ -48,8 +48,8 @@ class Linearizer(Kernel):
# Build an ALU uop applying `op` to `a` and the rendered form of `b`.
# A `b` that is not already a symbolic Node is wrapped in NumNode first, so both cases go through .render(ops, ctx).
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
# NOTE: the consts have to be cached for deduping of downstream uops to work
# Two variants of `const` appear back to back because this is a diff hunk with its +/- markers stripped.
# NOTE(review): by diff convention the first variant is presumably the removed one and the second the added one -- confirm against the repository.
# Variant 1: accepts a (possibly bound) Variable and unbinds it here, passing b.unbind()[0] to DEFINE_VAR.
def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
# Variant 2: the signature now admits Variable explicitly and the Variable is handed to DEFINE_VAR unchanged (no unbind call here).
# In both variants a non-Variable `b` becomes UOp.const(dtype, b).
def const(self, b:ConstType|Variable, dtype:DType=dtypes.int32) -> UOp:
return self.uops.add(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b)
def get_reduce_acc(self, reduceop:LazyOp):
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0

View File

@@ -1,12 +1,12 @@
import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes, DType
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -56,8 +56,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
if buf.op is LoadOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
if isinstance(buf.arg, Variable):
val, var_val = buf.arg.unbind()
var_vals.__setitem__(val, var_val)
else:
assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
val = buf.arg
return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
if buf.realized is not None or (buf in realizes and buf not in outputs):

View File

@@ -41,7 +41,7 @@ class MemBuffer:
# Frozen dataclass carrying a constant value together with its dtype and ShapeTracker.
# The `val` field is listed twice because this is a diff hunk with its +/- markers stripped:
# the first line annotates it as ConstType only, the second widens the annotation to ConstType | Variable.
# NOTE(review): the second `val` line is presumably the one kept after the change -- confirm against the repository.
@dataclass(frozen=True)
class ConstBuffer:
val: ConstType
val: ConstType | Variable
dtype: DType
st: ShapeTracker
@@ -74,7 +74,7 @@ class LazyOp:
# Return this op followed by the lazyops of every source op, flattened into one list and passed through dedup().
# `x.lazyops` is read without a call, so this is presumably exposed as a property; any decorator lies outside this hunk -- TODO confirm.
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
# Collect every Variable referenced by this op tree and return them sorted by str(x.expr).
def vars(self) -> List[Variable]:
# Variables appearing in the ShapeTracker of each buffer op (one st.vars() result per op).
extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
# Variables used directly as the value of a CONST buffer op.
# Two `const_vars` lines appear because this is a diff hunk with its +/- markers stripped:
# the first unbinds each value and keeps element [0] of the result, the second takes x.arg.val as-is.
# NOTE(review): the second line is presumably the one kept after the change -- confirm against the repository.
const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
# Union the two sources into one set; the sort key makes the returned order deterministic.
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
# **************** independent FlopCounter ****************