corealize + remove realize from lazybuffer (#1968)

* corealize + remove realize from lazybuffer

* fix multigpu

* fix graph
Author: George Hotz
Date: 2023-10-04 10:59:31 -07:00
Committed by: GitHub
Parent: 88b6ed6945
Commit: de5d603ec1

9 changed files with 28 additions and 32 deletions
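The thrust of the change is that realization moves off LazyBuffer and onto Tensor: Tensor.realize() now builds a schedule from its LazyBuffer and runs it through run_schedule, and the new Tensor.corealize() merges the schedules of several tensors so they run in one pass. A minimal usage sketch (the tensors and shapes here are illustrative, not from the diff):

from tinygrad.tensor import Tensor

a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)
x, y = a + b, a * b

# before this commit: result.lazydata.realize() ran the schedule per-LazyBuffer
# after: Tensor.realize() schedules its own LazyBuffer and runs that schedule
x.realize()

# corealize builds one combined schedule; the shared `seen` set keeps
# LazyBuffers common to several tensors from being scheduled twice
Tensor.corealize([x, y])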


@@ -139,7 +139,7 @@ assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the
 assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
 # now we realize the LazyBuffer
-result.lazydata.realize()
+result.realize()
 assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
 
 # this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
 assert 'RawMallocBuffer' in str(type(result.lazydata.realized))

extra/dist/world.py (vendored, 11 changed lines)

@@ -56,11 +56,14 @@ def _recv_rb(x:RawBufferCopyIn, target_rank:int):
   CacheCollector.add(__recv_rb, [x, rb, target_rank], {})
 
 # sends a lazybuffer from our rank to the target rank
-def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)
+def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None:
+  assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
+  _send_rb(x.realized, target_rank, cache_id=cache_id)
 
 # receive a lazybuffer from the target rank
 def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
-  _recv_rb(x.contiguous().realize().realized, target_rank)
+  assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
+  _recv_rb(x.realized, target_rank)
   return x
 
 class Send(Function):
@@ -74,5 +77,5 @@ class Recv(Function):
     self.target_rank, self.cache_id = target_rank, cache_id
     return _recv_lb(x, target_rank)
 
-def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x, target_rank=target_rank, cache_id=cache_id)
-def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x, target_rank=target_rank, cache_id=cache_id)
+def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank, cache_id=cache_id)
+def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank, cache_id=cache_id)
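The net effect for multigpu is that realization is hoisted out of the LazyBuffer helpers: send()/recv() make the tensor contiguous and realized before applying the Function, and _send_lb/_recv_lb merely assert that precondition. A small sketch of the invariant being asserted (the shape and permute are illustrative, not from the diff):

from tinygrad.tensor import Tensor

x = Tensor.rand(4, 4).permute(1, 0)   # a non-contiguous view
y = x.contiguous().realize()          # what send()/recv() now do up front
# the precondition _send_lb/_recv_lb check on the LazyBuffer they receive
assert y.lazydata.st.contiguous and y.lazydata.realized is not None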


@@ -2,7 +2,6 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.lazy import LAZY
 from tinygrad.ops import GlobalCounters, Device
 from tinygrad.graph import nm
 from tinygrad.helpers import dtypes
@@ -20,7 +19,7 @@ class TestAssign(unittest.TestCase):
     a += b
     a.realize()
     ba2 = a.lazydata.realized
-    if LAZY: assert ba1 == ba2 and ba1 != bb1
+    assert ba1 == ba2 and ba1 != bb1
     np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
 
   @unittest.skipIf(Device.DEFAULT == "CPU" or Device.DEFAULT == "TORCH", "questionable tests")


@@ -15,7 +15,7 @@ class TestLazyBuffer(unittest.TestCase):
   def test_fromcpu_shape_tracker(self):
     def helper(a: np.ndarray):
       print(a.shape, a.strides, a.flags.c_contiguous)
-      b = LazyBuffer.fromCPU(a).realize()
+      b = LazyBuffer.fromCPU(a)
       #assert b.st.contiguous == a.flags.c_contiguous
       assert b.st.shape == a.shape
       np.testing.assert_equal(a, Tensor(b).numpy())


@@ -4,7 +4,7 @@ try:
 except ImportError:
   nx = None # graph won't work
 
 from collections import defaultdict
-from typing import Dict, List, TYPE_CHECKING, Tuple, cast
+from typing import Dict, List, TYPE_CHECKING, Tuple
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
 from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
@@ -50,7 +50,7 @@ def str_dtype(dtyp):
 
 def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
   show_graph = bool(GRAPH)
   if not DEBUG and not show_graph: return
-  if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(cast('LazyBuffer', iop.src[0]).base))
+  if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(inp[0].base))
   if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
   op: List[Op] = [x.op for x in iop.get_lazyops()]


@@ -16,7 +16,6 @@ from tinygrad.runtime.ops_cpu import RawNumpyBuffer
 sys.setrecursionlimit(10000)
 OPT = getenv("OPT", 2)
-LAZY = getenv("LAZY", 1)
 LAZYCACHE = getenv("LAZYCACHE", 1)
 
 # TODO: movement ops that only change shape are really nops. treat them as such
@@ -115,7 +114,6 @@ class LazyBuffer:
     self._base = base
     if base: base.views.add(self)
     else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
-    if not LAZY: self.realize()
 
   @property
   def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))
@@ -187,23 +185,12 @@ class LazyBuffer:
     if self.op.op == LoadOps.CONTIGUOUS:
       src = cast(LazyBuffer, self.op.src[0])
       if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
         return ret + [(self.op, self, (src,))]
     op = self.op
     # run the ast and log the op
     op, base_bufs = _replace_bufferops(op)
-    # confirm the LoadOps are contiguous and in order
-    if op.op in LoadOps:
-      for i,s in enumerate(op.src):
-        assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
     return ret + [(op, self, tuple(base_bufs))]
 
-  def realize(self:LazyBuffer) -> LazyBuffer:
-    from tinygrad.realize import run_schedule
-    if not self.realized: run_schedule(self.schedule())
-    return self
 
   # *** creation/special ops ***
   @staticmethod


@@ -17,10 +17,8 @@ class Optimizer:
     for param in self.params: param.grad = None
 
   def realize(self, extra=None):
-    # TODO: corealize
     # NOTE: in extra is too late for most of the params due to issues with assign
-    for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
-      p.realize()
+    Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
 
 class SGD(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
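Optimizer.realize no longer loops over tensors calling p.realize(); it hands the whole parameter/buffer list to Tensor.corealize, so the entire update is scheduled together. A hedged usage sketch (the toy parameter and loss are illustrative, not from the diff):

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

w = Tensor.rand(8, 8)
w.requires_grad = True
opt = SGD([w], lr=0.01)

loss = (w * w).sum()
loss.backward()
opt.step()   # the optimizers end step() with self.realize(), which now corealizes params and buffers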


@@ -1,6 +1,6 @@
 from typing import List, Tuple, cast, Dict, Callable
 import numpy as np
-from tinygrad.ops import LazyOp, LoadOps, Device
+from tinygrad.ops import LazyOp, LoadOps, BufferOps, Device
 from tinygrad.graph import log_schedule_item
 from tinygrad.lazy import LazyBuffer
 from tinygrad.helpers import DEBUG, prod, all_int, getenv
@@ -19,7 +19,8 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
       from extra.utils import print_tree # type: ignore
       print_tree(op)
     if op.op in LoadOps:
-      # NOTE: load op buffers are promised to be in order by the scheduler
+      # confirm the LoadOps are contiguous and in order
+      for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
       LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
     else:
       out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())


@@ -5,12 +5,13 @@ from collections import defaultdict
 from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import Device, LoadOps
 from tinygrad.shape.symbolic import sint
+from tinygrad.realize import run_schedule
 
 # An instantiation of the Function is the Context
 class Function:
@@ -89,8 +90,15 @@ class Tensor:
   # ***** data handlers ****
 
+  @staticmethod
+  def corealize(lst:Iterable[Tensor]):
+    seen:Set[LazyBuffer] = set()
+    sched = []
+    for t in lst: sched += t.lazydata.schedule(seen)
+    run_schedule(sched)
+
   def realize(self) -> Tensor:
-    self.lazydata.realize()
+    run_schedule(self.lazydata.schedule())
     return self
 
   def assign(self, x) -> Tensor:
@@ -111,7 +119,7 @@ class Tensor:
   def numpy(self) -> np.ndarray:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
-    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized._buf.reshape(self.shape)
+    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
 
   # TODO: if things are realized this won't work
   def to_(self, device:str):
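Finally, numpy() now asks the realized RawBuffer for a host copy via toCPU() instead of reaching into its private _buf, which keeps the accessor uniform across backends. A rough sketch of the path it takes (the dtype cast is elided and the tensor is illustrative, not from the diff):

from tinygrad.tensor import Tensor

t = Tensor.rand(2, 3)
cpu = t.detach().contiguous().to('CPU').realize()
host = cpu.lazydata.realized.toCPU().reshape(t.shape)   # roughly what t.numpy() returns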