diff --git a/tinygrad/device.py b/tinygrad/device.py
index 514cbf0e72..a6bd2ea8c2 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -5,8 +5,7 @@ from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
 import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
 from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
-from tinygrad.dtype import DType, ImageDType, PtrDType
-from tinygrad.ops import UOp, UOps
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
@@ -132,7 +131,6 @@ class Buffer:
     assert offset < self.nbytes, "offset must be less than nbytes"
     if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
     return Buffer(self.device, size, dtype, base=self, offset=offset)
-  def to_uop(self) -> UOp: return UOp(UOps.DEFINE_GLOBAL, self.dtype if isinstance(self.dtype, ImageDType) else PtrDType(self.dtype), (), self)
 
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 5675736f59..61a8da1791 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -154,7 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-    return UOp(UOps.LOAD, dtype, (buf.buffer.to_uop(), unbound_st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (UOp.define_global(buf.dtype, buf.buffer), unbound_st.to_uop()))
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
@@ -188,7 +188,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
       output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
-    ast.append(UOp(UOps.STORE, dtypes.void, (out.buffer.to_uop(), output_st.to_uop(), src)))
+    ast.append(UOp(UOps.STORE, dtypes.void, (UOp.define_global(out.dtype, out.buffer), output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 068d3220e8..f2e6d4ffe6 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -217,6 +217,8 @@ class UOp(MathTrait):
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @staticmethod
+  def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
+  @staticmethod
   def range(dtype:DType, start:ConstType, end:ConstType, idx:int): return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
   def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
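For context, a minimal usage sketch of the new helper (not part of the diff). It assumes this revision's tinygrad imports; the integer arg is a placeholder for what would normally be a Buffer, and the DEFINE_GLOBAL dtype handling shown is taken directly from the helper above.

# sketch only: call-site usage of UOp.define_global, replacing Buffer.to_uop()
from tinygrad.ops import UOp, UOps
from tinygrad.dtype import dtypes, PtrDType

# before: buf.buffer.to_uop() built the DEFINE_GLOBAL from the Buffer's own dtype
# after:  the scheduler passes the LazyBuffer dtype and the Buffer explicitly
g = UOp.define_global(dtypes.float32, 0)   # arg is normally the output Buffer; 0 is a stand-in here
assert isinstance(g.dtype, PtrDType)       # non-image dtypes get wrapped in PtrDType, same as to_uop() did

img = UOp.define_global(dtypes.imagef((8, 8)), 0)   # an ImageDType is passed through unwrapped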