mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
unbind Variable pre LazyOp (#4873)
* early unbind * assert ConstType is correct
This commit is contained in:
@@ -48,8 +48,8 @@ class Linearizer(Kernel):
|
||||
def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
|
||||
|
||||
# NOTE: the consts have to be cached for deduping of downstream uops to work
|
||||
def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp:
|
||||
return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b)
|
||||
def const(self, b:ConstType|Variable, dtype:DType=dtypes.int32) -> UOp:
|
||||
return self.uops.add(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b)
|
||||
|
||||
def get_reduce_acc(self, reduceop:LazyOp):
|
||||
if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import sys, pickle, atexit
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union
|
||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
|
||||
from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
|
||||
from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType
|
||||
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -56,8 +56,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
|
||||
if buf.op is LoadOps.CONST:
|
||||
unbound_st, st_var_vals = st.simplify().unbind()
|
||||
var_vals.update(st_var_vals)
|
||||
if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind())
|
||||
return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
|
||||
if isinstance(buf.arg, Variable):
|
||||
val, var_val = buf.arg.unbind()
|
||||
var_vals.__setitem__(val, var_val)
|
||||
else:
|
||||
assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}"
|
||||
val = buf.arg
|
||||
return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st))
|
||||
|
||||
# if we aren't fusing it, it's a load and we add it to the inputs
|
||||
if buf.realized is not None or (buf in realizes and buf not in outputs):
|
||||
|
||||
@@ -41,7 +41,7 @@ class MemBuffer:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConstBuffer:
|
||||
val: ConstType
|
||||
val: ConstType | Variable
|
||||
dtype: DType
|
||||
st: ShapeTracker
|
||||
|
||||
@@ -74,7 +74,7 @@ class LazyOp:
|
||||
def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops])
|
||||
def vars(self) -> List[Variable]:
|
||||
extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
|
||||
const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
||||
const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
||||
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
Reference in New Issue
Block a user