Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-26 23:38:58 -05:00)
add UOp.define_global [run_process_replay] (#6787)
* add UOp.define_global [run_process_replay]
* no src
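In short, this commit moves DEFINE_GLOBAL construction out of Buffer.to_uop (device.py) and into a UOp.define_global staticmethod: the dtype is wrapped in PtrDType unless it is already an ImageDType, the src tuple is empty per the "no src" follow-up, and the backing object rides along as arg. A minimal sketch of what the new helper builds, assuming the tinygrad tree at this commit; the float32 dtype and the None arg are illustrative placeholders, not values from the diff:

from tinygrad.ops import UOp, UOps
from tinygrad.dtype import dtypes, PtrDType

# plain dtypes are wrapped in PtrDType; an ImageDType would pass through as-is
g = UOp.define_global(dtypes.float32, None)  # arg is normally the backing Buffer
assert g.op is UOps.DEFINE_GLOBAL
assert isinstance(g.dtype, PtrDType) and g.src == ()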
@@ -5,8 +5,7 @@ from collections import defaultdict
 from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
 import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
 from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
-from tinygrad.dtype import DType, ImageDType, PtrDType
-from tinygrad.ops import UOp, UOps
+from tinygrad.dtype import DType, ImageDType
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
@@ -132,7 +131,6 @@ class Buffer:
     assert offset < self.nbytes, "offset must be less than nbytes"
     if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
     return Buffer(self.device, size, dtype, base=self, offset=offset)
-  def to_uop(self) -> UOp: return UOp(UOps.DEFINE_GLOBAL, self.dtype if isinstance(self.dtype, ImageDType) else PtrDType(self.dtype), (), self)
 
 # TODO: size, dest, src are the same type. can we enforce this?
 class Allocator:
@@ -154,7 +154,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
           raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                              +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-    return UOp(UOps.LOAD, dtype, (buf.buffer.to_uop(), unbound_st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (UOp.define_global(buf.dtype, buf.buffer), unbound_st.to_uop()))
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
@@ -188,7 +188,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
       output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
-    ast.append(UOp(UOps.STORE, dtypes.void, (out.buffer.to_uop(), output_st.to_uop(), src)))
+    ast.append(UOp(UOps.STORE, dtypes.void, (UOp.define_global(out.dtype, out.buffer), output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(x.buffer for x in outs+inputs)))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
 
@@ -217,6 +217,8 @@ class UOp(MathTrait):
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  @staticmethod
+  def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
   @staticmethod
   def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
     return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
   def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
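For code outside this diff that built ASTs by hand, the two call sites above show the migration pattern: wherever a Buffer was lowered with buf.buffer.to_uop(), pass the dtype and the buffer to UOp.define_global instead. A hedged sketch of the new-style LOAD construction, assuming the tinygrad tree at this commit; the CLANG device, the size, and the shape are illustrative placeholders, and buf/st stand in for the schedule's buf.buffer and unbound_st:

from tinygrad.device import Buffer
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker

# hypothetical stand-ins for the schedule's buffer and shapetracker
buf = Buffer("CLANG", 16, dtypes.float32)
st = ShapeTracker.from_shape((4, 4))

# after this commit: the UOp class owns DEFINE_GLOBAL construction, so a
# LOAD is built from the buffer's dtype plus the buffer itself as arg
load = UOp(UOps.LOAD, dtypes.float32, (UOp.define_global(buf.dtype, buf), st.to_uop()))
assert load.src[0].op is UOps.DEFINE_GLOBAL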