Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
remove RawConst and add test (#1939)
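
Summary (as read from the diff below): the RawConst raw-buffer type is deleted. Instead of realizing LoadOps.CONST into a RawConst and special-casing that buffer everywhere, unrealized consts on compiled backends (except LLVM, where per the code comment consts are broken with NaN/inf) are skipped by the scheduler and embedded into the consuming kernel's AST as BufferOps.CONST. The checks previously inlined at each call site move into a new LazyBuffer.is_unrealized_const() helper, and a new test, test_constants_are_embedded, verifies that a constant operand adds no schedule item. check_schedule grows a filter_loadops flag so the test can count LoadOps too.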
@@ -5,13 +5,13 @@
 import unittest
 from typing import List, Optional
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, Device, Compiled
 from tinygrad.helpers import DEBUG, dtypes
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.graph import log_schedule_item
 from tinygrad import nn
 
-def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
+def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -20,7 +20,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
         seen.add(s[1])
   sched = t.lazydata.schedule(seen)
   for s in sched: log_schedule_item(*s)
-  sched = [s for s in sched if s[0].op not in LoadOps]
+  if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     from extra.utils import print_tree
@@ -28,8 +28,9 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
       print("op", i)
       print_tree(s[0])
   assert len(sched) == allowed
-  # test the ops linearize
+  # test the (non loadops) ops linearize
   for s in sched:
+    if s[0].op in LoadOps: continue
     l = Linearizer(s[0])
     l.hand_coded_optimizations()
     l.linearize()
@@ -76,6 +77,11 @@ class TestSchedule(unittest.TestCase):
     d = (a+b).permute(1,0)+c
     check_schedule(d, 1)
 
+  @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", "only test for compiled backends")
+  def test_constants_are_embedded(self):
+    a = Tensor.empty(3,3) * 2
+    check_schedule(a, 2, filter_loadops=False)
+
   def test_binop_elu_fusion(self):
     a = Tensor.empty(10)
     b = a.elu()
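
For reference, a minimal sketch of what the new test asserts, using the same schedule API check_schedule calls; the exact makeup of the two items (the EMPTY load for `a` plus the multiply kernel) is my reading of the diff, not something the diff states:

  # sketch: with the const embedded, nothing is scheduled for the constant 2
  from tinygrad.tensor import Tensor

  a = Tensor.empty(3, 3) * 2
  sched = a.lazydata.schedule(set())  # same call check_schedule makes
  assert len(sched) == 2              # LoadOps.EMPTY for `a` + the mul kernel
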
@@ -10,7 +10,7 @@ from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, Redu
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
 
-from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
+from tinygrad.runtime.lib import RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 from tinygrad.runtime.ops_disk import RawDiskBuffer
 
@@ -66,13 +66,11 @@ def _ast_binaryops(op:LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
 
 def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
   replacements:Dict[LazyBuffer, LazyOp] = {}
-  base_bufs = dedup([x.base for x in op.buffers if (x.realized and not isinstance(x.realized, RawConst)) or not isinstance(Device[x.device], Compiled) or x.device == "LLVM" or (not x.realized and x.base.op.op != LoadOps.CONST)])
+  base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
   for x in op.buffers:
     st = x.st.simplify()
     if x.base in base_bufs:
       replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
-    elif x.realized and isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.dtype, st))
     elif not x.realized and x.base.op.op == LoadOps.CONST:
       replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
     else:
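
The one-line base_bufs filter works because every input now falls into one of two classes: real buffers, which get 1-indexed MemBuffer slots (slot 0 is reserved for the output), and unrealized consts, whose value travels inside the AST as a ConstBuffer. A toy model of that numbering in plain Python (not the tinygrad API):

  # toy model: real inputs get numbered memory slots, consts are embedded by value
  def replace_bufferops(buffers):
    base_bufs = [b for b in buffers if not b.get("const")]  # stand-in for dedup(...)
    ops = []
    for b in buffers:
      if b in base_bufs:
        ops.append(("MEM", base_bufs.index(b) + 1))         # +1: slot 0 is the output
      else:
        ops.append(("CONST", float(b["value"])))            # no buffer allocated
    return ops

  print(replace_bufferops([{"id": "x"}, {"const": True, "value": 2.0}]))
  # -> [('MEM', 1), ('CONST', 2.0)]
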
@@ -128,6 +126,10 @@ class LazyBuffer:
   @property
   def base(self): return self._base if self._base is not None else self
 
+  def is_unrealized_const(self):
+    # consts are broken in LLVM in NaN/inf
+    return not self.realized and (self.base.op.op == LoadOps.CONST and isinstance(Device[self.device], Compiled) and self.device != "LLVM")
+
   @property
   def realized(self): return self.base._realized
   @realized.setter
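
This helper consolidates conditions the removed code repeated inline (is it a LoadOps.CONST, is the backend Compiled, is it not LLVM) and is used at all three former RawConst call sites in this file: _replace_bufferops above, plus schedule and the LoadOps.CONTIGUOUS shortcut below.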
@@ -164,7 +166,7 @@ class LazyBuffer:
 
   def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
     if seen is None: seen = set()
-    if self in seen or self.realized: return []
+    if self in seen or self.realized or self.is_unrealized_const(): return []
     seen.add(self)
     if self.optype is MovementOps: return self.base.schedule(seen)
 
@@ -183,7 +185,7 @@ class LazyBuffer:
     # contiguous can be a copy. must do this after the image hack
     if self.op.op == LoadOps.CONTIGUOUS:
       src = cast(LazyBuffer, self.op.src[0])
-      if src.st.contiguous and src.st.size() == src.base.st.size() and (src.realized or not src.base.op.op == LoadOps.CONST) and (not src.realized or not isinstance(src.realized, RawConst)):
+      if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
         return src.schedule(seen) + [(self.op, self, ())]
 
     # realize the past and exec the AST
@@ -394,7 +396,7 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
     out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
     del out.op
     for v in out.views: del v.op
-    assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
+    assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
     assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
 
 def _realize_contiguous(buffer: LazyBuffer) -> None:
@@ -430,10 +432,7 @@ def _realize_rand(buffer: LazyBuffer) -> None:
   buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args()) # type: ignore
 
 def _realize_const(buffer: LazyBuffer) -> None:
-  if isinstance(Device[buffer.device], Compiled) and buffer.device not in ["LLVM"]: # consts are broken in LLVM in NaN/inf
-    buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
-  else:
-    buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
+  buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
 
 LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
   LoadOps.CONTIGUOUS: _realize_contiguous,
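
Note on _realize_const: since schedule() now returns early for is_unrealized_const() buffers, this dispatcher entry only runs where that predicate is false (interpreted backends and LLVM), so the RawConst branch was dead and the fromCPU fallback becomes the whole function.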
@@ -157,7 +157,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
 
 # **************** for Compiled Buffers ****************
 
-from tinygrad.runtime.lib import RawBuffer, RawConst
+from tinygrad.runtime.lib import RawBuffer
 from tinygrad.shape.symbolic import Variable, sym_infer
 
 class BasicBatchExecutor:
@@ -221,7 +221,6 @@ class Compiled:
       # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
       output.realized = output.output_buffer
       if output.realized:
-        if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
         for i,a in enumerate(inputs):
           # TODO: if this is contiguous it's fine
           if a == output.realized:
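
With nothing ever realized to a RawConst anymore, output.realized in Compiled.exec_ast can no longer hold one, so the "can't assign to RawConst" reset is dropped outright.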