add UOp.define_global [run_process_replay] (#6787)

* add UOp.define_global [run_process_replay]

* no src

Author: qazal
Date: 2024-09-27 19:24:03 +08:00
Committed by: GitHub
Parent: b95f47784a
Commit: 568c97f7a2
3 changed files with 5 additions and 5 deletions
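
For orientation (not part of the commit itself): the change at call sites amounts to roughly the sketch below, where `buf` stands in for a LazyBuffer as in the schedule hunks further down.

    # before: the Buffer knew how to wrap itself in a DEFINE_GLOBAL UOp
    base = buf.buffer.to_uop()
    # after: UOp owns the constructor; the caller passes the dtype and the Buffer as arg
    base = UOp.define_global(buf.dtype, buf.buffer)
    # both forms build UOp(UOps.DEFINE_GLOBAL, ImageDType or PtrDType(dtype), (), buffer)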


@@ -5,8 +5,7 @@ from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
 import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
 from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
-from tinygrad.dtype import DType, ImageDType, PtrDType
-from tinygrad.ops import UOp, UOps
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
 # **************** Device ****************
@@ -132,7 +131,6 @@ class Buffer:
 assert offset < self.nbytes, "offset must be less than nbytes"
 if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
 return Buffer(self.device, size, dtype, base=self, offset=offset)
-def to_uop(self) -> UOp: return UOp(UOps.DEFINE_GLOBAL, self.dtype if isinstance(self.dtype, ImageDType) else PtrDType(self.dtype), (), self)
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:


@@ -154,7 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
 raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
 +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
 if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-return UOp(UOps.LOAD, dtype, (buf.buffer.to_uop(), unbound_st.to_uop()))
+return UOp(UOps.LOAD, dtype, (UOp.define_global(buf.dtype, buf.buffer), unbound_st.to_uop()))
 # reduce ops change ShapeTracker
 if buf.op in ReduceOps:
@@ -188,7 +188,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
 output_st = out.arg[0]
 output_st, vv = output_st.simplify().unbind()
 var_vals.update(vv)
-ast.append(UOp(UOps.STORE, dtypes.void, (out.buffer.to_uop(), output_st.to_uop(), src)))
+ast.append(UOp(UOps.STORE, dtypes.void, (UOp.define_global(out.dtype, out.buffer), output_st.to_uop(), src)))
 sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
 return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals


@@ -217,6 +217,8 @@ class UOp(MathTrait):
 @staticmethod
 def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+@staticmethod
+def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
 @staticmethod
 def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
 return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
 def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
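
A minimal usage sketch of the new helper, assuming the tinygrad tree at this commit; the device name "CLANG" and the Buffer constructor call are illustrative assumptions, not part of the diff:

    from tinygrad.ops import UOp, UOps
    from tinygrad.dtype import dtypes, PtrDType
    from tinygrad.device import Buffer

    buf = Buffer("CLANG", 16, dtypes.float32)   # an (unallocated) Buffer is enough here
    g = UOp.define_global(buf.dtype, buf)       # replaces the removed buf.to_uop()
    assert g.op is UOps.DEFINE_GLOBAL           # scalar dtypes get wrapped in PtrDType
    assert isinstance(g.dtype, PtrDType) and g.arg is buf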