diff --git a/tinygrad/device.py b/tinygrad/device.py
index 52e31015ac..92f5732886 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -5,7 +5,8 @@ from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
 import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
 from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
-from tinygrad.dtype import DType, ImageDType
+from tinygrad.dtype import DType, ImageDType, PtrDType
+from tinygrad.ops import UOp, UOps
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
@@ -131,6 +132,7 @@ class Buffer:
     assert offset < self.nbytes, "offset must be less than nbytes"
     if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
     return Buffer(self.device, size, dtype, base=self, offset=offset)
+  def to_uop(self) -> UOp: return UOp(UOps.DEFINE_GLOBAL, self.dtype if isinstance(self.dtype, ImageDType) else PtrDType(self.dtype), (), self)
 
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 13c3e977aa..5675736f59 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -6,7 +6,7 @@ from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, UnaryOp
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
   GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
+from tinygrad.dtype import ConstType, ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer, Device
@@ -154,8 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-    ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), buf.buffer)
-    return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (buf.buffer.to_uop(), unbound_st.to_uop()))
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
@@ -189,8 +188,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
     output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
-    ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), out.buffer)
-    ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
+    ast.append(UOp(UOps.STORE, dtypes.void, (out.buffer.to_uop(), output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
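
Note (not part of the diff): a minimal sketch of what the new `Buffer.to_uop` helper is expected to produce, since it is the single construction site that replaces the two hand-rolled `DEFINE_GLOBAL` UOps in schedule.py. It assumes `Device.DEFAULT` resolves to a usable backend on your machine; everything else comes straight from the diff above.

```python
from tinygrad.device import Buffer, Device
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import UOps

buf = Buffer(Device.DEFAULT, 16, dtypes.float32)  # a small float32 buffer
u = buf.to_uop()                                  # wraps the buffer in a DEFINE_GLOBAL UOp

assert u.op is UOps.DEFINE_GLOBAL
assert isinstance(u.dtype, PtrDType)  # non-image dtypes are wrapped in PtrDType; an ImageDType would pass through as-is
assert u.arg is buf                   # the UOp carries the Buffer itself as its arg
```

Centralizing the ImageDType-vs-PtrDType choice in one method keeps the LOAD path in `_recursive_uop` and the STORE path in `_lower_lazybuffer` in sync by construction.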