diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 5bf1f133a1..4d829082a3 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -48,8 +48,8 @@ class Linearizer(Kernel): def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op): return UOp.alu(op, a, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx)) # NOTE: the consts have to be cached for deduping of downstream uops to work - def const(self, b:ConstType, dtype:DType=dtypes.int32) -> UOp: - return self.uops.add(UOps.DEFINE_VAR, dtype, (), b.unbind()[0]) if isinstance(b, Variable) else UOp.const(dtype, b) + def const(self, b:ConstType|Variable, dtype:DType=dtypes.int32) -> UOp: + return self.uops.add(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b) def get_reduce_acc(self, reduceop:LazyOp): if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0 diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3a20f2fb09..8aae1261bd 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,12 +1,12 @@ import sys, pickle, atexit from collections import defaultdict, deque from dataclasses import dataclass -from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union +from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args from tinygrad.ops import LoadOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv from tinygrad.shape.symbolic import Variable -from tinygrad.dtype import ImageDType, dtypes, DType +from tinygrad.dtype import ConstType, ImageDType, dtypes, DType from tinygrad.lazy import LazyBuffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer @@ -56,8 +56,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz if buf.op is LoadOps.CONST: unbound_st, st_var_vals = st.simplify().unbind() var_vals.update(st_var_vals) - if isinstance(buf.arg, Variable): var_vals.__setitem__(*buf.arg.unbind()) - return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st)) + if isinstance(buf.arg, Variable): + val, var_val = buf.arg.unbind() + var_vals.__setitem__(val, var_val) + else: + assert isinstance(buf.arg, get_args(ConstType)), f"cannot create ConstBuffer with value {buf.arg}" + val = buf.arg + return LazyOp(BufferOps.CONST, (), ConstBuffer(val, buf.dtype, unbound_st)) # if we aren't fusing it, it's a load and we add it to the inputs if buf.realized is not None or (buf in realizes and buf not in outputs): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index fc818cc939..dc9c53f987 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -41,7 +41,7 @@ class MemBuffer: @dataclass(frozen=True) class ConstBuffer: - val: ConstType + val: ConstType | Variable dtype: DType st: ShapeTracker @@ -74,7 +74,7 @@ class LazyOp: def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src for item in x.lazyops]) def vars(self) -> List[Variable]: extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps] - const_vars = [x.arg.val.unbind()[0] for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)] + const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)] return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr)) # **************** independent FlopCounter ****************