remove RawConst and add test (#1939)

commit d52df788d3 (parent 22b8576887)
Author: George Hotz
Date: 2023-09-29 01:21:51 -07:00 (committed by GitHub)
3 changed files with 21 additions and 17 deletions
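
Summary: RawConst was a pseudo-buffer used to mark a constant as "realized" on compiled backends, and it needed special cases everywhere a real buffer was expected. This commit removes it: a constant now simply stays unrealized (see the new LazyBuffer.is_unrealized_const() below), contributes nothing to the schedule, and is embedded into the kernel as a BufferOps.CONST at lowering time. LLVM keeps the old realize-to-buffer path because consts are broken there for NaN/inf.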

test/test_schedule.py

@@ -5,13 +5,13 @@
 import unittest
 from typing import List, Optional
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, Device, Compiled
 from tinygrad.helpers import DEBUG, dtypes
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.graph import log_schedule_item
 from tinygrad import nn

-def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
+def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -20,7 +20,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
         seen.add(s[1])
   sched = t.lazydata.schedule(seen)
   for s in sched: log_schedule_item(*s)
-  sched = [s for s in sched if s[0].op not in LoadOps]
+  if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     from extra.utils import print_tree
@@ -28,8 +28,9 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
       print("op", i)
       print_tree(s[0])
   assert len(sched) == allowed
-  # test the ops linearize
+  # test the (non loadops) ops linearize
   for s in sched:
+    if s[0].op in LoadOps: continue
     l = Linearizer(s[0])
     l.hand_coded_optimizations()
     l.linearize()
@@ -76,6 +77,11 @@ class TestSchedule(unittest.TestCase):
     d = (a+b).permute(1,0)+c
     check_schedule(d, 1)

+  @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM", "only test for compiled backends")
+  def test_constants_are_embedded(self):
+    a = Tensor.empty(3,3) * 2
+    check_schedule(a, 2, filter_loadops=False)
+
   def test_binop_elu_fusion(self):
     a = Tensor.empty(10)
     b = a.elu()
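
The new test leans on schedule items being (op, out, inputs) tuples and on a scalar becoming a LoadOps.CONST that never gets its own schedule item on a compiled backend. A standalone sketch of the same check (assuming a compiled, non-LLVM backend such as CLANG or METAL):

from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps

# The scalar 2 becomes a LoadOps.CONST LazyBuffer. With loadop filtering off,
# the full schedule holds exactly two items: the EMPTY loadop backing `a` and
# one compute kernel -- the constant is never scheduled on its own.
a = Tensor.empty(3, 3) * 2
sched = a.lazydata.schedule(set())
assert len(sched) == 2
assert all(s[0].op != LoadOps.CONST for s in sched)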

tinygrad/lazy.py

@@ -10,7 +10,7 @@ from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, Redu
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
-from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer
+from tinygrad.runtime.lib import RawBuffer, RawBufferMapped, RawBufferTransfer
 from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 from tinygrad.runtime.ops_disk import RawDiskBuffer
@@ -66,13 +66,11 @@ def _ast_binaryops(op:LazyOp, shape: Tuple[sint, ...]) -> LazyOp:

 def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
   replacements:Dict[LazyBuffer, LazyOp] = {}
-  base_bufs = dedup([x.base for x in op.buffers if (x.realized and not isinstance(x.realized, RawConst)) or not isinstance(Device[x.device], Compiled) or x.device == "LLVM" or (not x.realized and x.base.op.op != LoadOps.CONST)])
+  base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
   for x in op.buffers:
     st = x.st.simplify()
     if x.base in base_bufs:
       replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
-    elif x.realized and isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.dtype, st))
     elif not x.realized and x.base.op.op == LoadOps.CONST:
       replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
     else:
@@ -128,6 +126,10 @@ class LazyBuffer:
   @property
   def base(self): return self._base if self._base is not None else self

+  def is_unrealized_const(self):
+    # consts are broken in LLVM in NaN/inf
+    return not self.realized and (self.base.op.op == LoadOps.CONST and isinstance(Device[self.device], Compiled) and self.device != "LLVM")
+
   @property
   def realized(self): return self.base._realized
   @realized.setter
@@ -164,7 +166,7 @@

   def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
     if seen is None: seen = set()
-    if self in seen or self.realized: return []
+    if self in seen or self.realized or self.is_unrealized_const(): return []
     seen.add(self)

     if self.optype is MovementOps: return self.base.schedule(seen)
@@ -183,7 +185,7 @@
     # contiguous can be a copy. must do this after the image hack
     if self.op.op == LoadOps.CONTIGUOUS:
       src = cast(LazyBuffer, self.op.src[0])
-      if src.st.contiguous and src.st.size() == src.base.st.size() and (src.realized or not src.base.op.op == LoadOps.CONST) and (not src.realized or not isinstance(src.realized, RawConst)):
+      if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
         return src.schedule(seen) + [(self.op, self, ())]

     # realize the past and exec the AST
@@ -394,7 +396,7 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
      out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
      del out.op
      for v in out.views: del v.op
-    assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
+    assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
     assert out.realized.dtype == out.dtype, "realized dtype is incorrect"

 def _realize_contiguous(buffer: LazyBuffer) -> None:
@@ -430,10 +432,7 @@ def _realize_rand(buffer: LazyBuffer) -> None:
   buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args()) # type: ignore

 def _realize_const(buffer: LazyBuffer) -> None:
-  if isinstance(Device[buffer.device], Compiled) and buffer.device not in ["LLVM"]: # consts are broken in LLVM in NaN/inf
-    buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
-  else:
-    buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
+  buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())

 LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
   LoadOps.CONTIGUOUS: _realize_contiguous,
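
Taken together: schedule() skips unrealized consts, _replace_bufferops() folds their reads inline, and _realize_const() now only runs for interpreted backends and LLVM. Conceptually, the kernel AST for Tensor.empty(3,3) * 2 on a compiled backend ends up shaped like this (a hedged paraphrase built from the constructors above, not verbatim output; st stands for the buffer's ShapeTracker):

LazyOp(BinaryOps.MUL, (
  LazyOp(BufferOps.MEM,   (), MemBuffer(1, dtypes.float32, st)),     # the EMPTY input, bound as buffer argument 1
  LazyOp(BufferOps.CONST, (), ConstBuffer(2.0, dtypes.float32, st)), # the 2, embedded by _replace_bufferops
))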

tinygrad/ops.py

@@ -157,7 +157,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex

 # **************** for Compiled Buffers ****************

-from tinygrad.runtime.lib import RawBuffer, RawConst
+from tinygrad.runtime.lib import RawBuffer
 from tinygrad.shape.symbolic import Variable, sym_infer

 class BasicBatchExecutor:
@@ -221,7 +221,6 @@ class Compiled:
     # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
     output.realized = output.output_buffer
     if output.realized:
-      if output.realized.__class__ is RawConst: output.realized = None # can't assign to RawConst
       for i,a in enumerate(inputs):
         # TODO: if this is contiguous it's fine
         if a == output.realized:
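
With RawConst gone, anything realized is a genuine device buffer, so the output-buffer reuse path needs no escape hatch for an unassignable pseudo-buffer, and run_schedule's tightened assert states that invariant directly. One hedged way to see the end result (the filename is illustrative; CLANG and DEBUG are standard tinygrad env vars, though the exact dump format varies by backend and version):

# DEBUG=4 CLANG=1 python const_embed.py
# prints the generated C kernel; the constant appears as an immediate
# (e.g. 2.0f) in the kernel body rather than as an extra buffer argument.
from tinygrad.tensor import Tensor
(Tensor.empty(3, 3) * 2).realize()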