mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
corealize + remove realize from lazybuffer (#1968)
* corealize + remove realize from lazybuffer * fix multigpu * fix graph
This commit is contained in:
@@ -139,7 +139,7 @@ assert lazyop.src[0].op.src[0].op.src[0].realized._buf[0] == 2, "the src of the
|
||||
assert result.lazydata.realized is None, "the LazyBuffer is not realized yet"
|
||||
|
||||
# now we realize the LazyBuffer
|
||||
result.lazydata.realize()
|
||||
result.realize()
|
||||
assert result.lazydata.realized is not None, "the LazyBuffer is realized!"
|
||||
# this brings us nicely to DeviceBuffer, of which the realized ClangBuffer is a subclass
|
||||
assert 'RawMallocBuffer' in str(type(result.lazydata.realized))
|
||||
|
||||
11
extra/dist/world.py
vendored
11
extra/dist/world.py
vendored
@@ -56,11 +56,14 @@ def _recv_rb(x:RawBufferCopyIn, target_rank:int):
|
||||
CacheCollector.add(__recv_rb, [x, rb, target_rank], {})
|
||||
|
||||
# sends a lazybuffer from our rank to the target rank
|
||||
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None: _send_rb(x.contiguous().realize().realized, target_rank, cache_id=cache_id)
|
||||
def _send_lb(x:LazyBuffer, target_rank:int, cache_id:Optional[str]=None) -> None:
|
||||
assert x.st.contiguous and x.realized, "sending buffer must be contiguous and realized"
|
||||
_send_rb(x.realized, target_rank, cache_id=cache_id)
|
||||
|
||||
# receive a lazybuffer from the target rank
|
||||
def _recv_lb(x:LazyBuffer, target_rank:int) -> LazyBuffer:
|
||||
_recv_rb(x.contiguous().realize().realized, target_rank)
|
||||
assert x.st.contiguous and x.realized, "receiving buffer must be contiguous and realized"
|
||||
_recv_rb(x.realized, target_rank)
|
||||
return x
|
||||
|
||||
class Send(Function):
|
||||
@@ -74,5 +77,5 @@ class Recv(Function):
|
||||
self.target_rank, self.cache_id = target_rank, cache_id
|
||||
return _recv_lb(x, target_rank)
|
||||
|
||||
def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x, target_rank=target_rank, cache_id=cache_id)
|
||||
def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x, target_rank=target_rank, cache_id=cache_id)
|
||||
def send(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Send.apply(x.contiguous().realize(), target_rank=target_rank, cache_id=cache_id)
|
||||
def recv(x:Tensor, target_rank:int, cache_id:Optional[str]=None) -> Tensor: return Recv.apply(x.contiguous().realize(), target_rank=target_rank, cache_id=cache_id)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LAZY
|
||||
from tinygrad.ops import GlobalCounters, Device
|
||||
from tinygrad.graph import nm
|
||||
from tinygrad.helpers import dtypes
|
||||
@@ -20,7 +19,7 @@ class TestAssign(unittest.TestCase):
|
||||
a += b
|
||||
a.realize()
|
||||
ba2 = a.lazydata.realized
|
||||
if LAZY: assert ba1 == ba2 and ba1 != bb1
|
||||
assert ba1 == ba2 and ba1 != bb1
|
||||
np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU" or Device.DEFAULT == "TORCH", "questionable tests")
|
||||
|
||||
@@ -15,7 +15,7 @@ class TestLazyBuffer(unittest.TestCase):
|
||||
def test_fromcpu_shape_tracker(self):
|
||||
def helper(a: np.ndarray):
|
||||
print(a.shape, a.strides, a.flags.c_contiguous)
|
||||
b = LazyBuffer.fromCPU(a).realize()
|
||||
b = LazyBuffer.fromCPU(a)
|
||||
#assert b.st.contiguous == a.flags.c_contiguous
|
||||
assert b.st.shape == a.shape
|
||||
np.testing.assert_equal(a, Tensor(b).numpy())
|
||||
|
||||
@@ -4,7 +4,7 @@ try:
|
||||
except ImportError:
|
||||
nx = None # graph won't work
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, cast
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
|
||||
|
||||
@@ -50,7 +50,7 @@ def str_dtype(dtyp):
|
||||
def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
|
||||
show_graph = bool(GRAPH)
|
||||
if not DEBUG and not show_graph: return
|
||||
if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(cast('LazyBuffer', iop.src[0]).base))
|
||||
if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(inp[0].base))
|
||||
if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
|
||||
|
||||
op: List[Op] = [x.op for x in iop.get_lazyops()]
|
||||
|
||||
@@ -16,7 +16,6 @@ from tinygrad.runtime.ops_cpu import RawNumpyBuffer
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
OPT = getenv("OPT", 2)
|
||||
LAZY = getenv("LAZY", 1)
|
||||
LAZYCACHE = getenv("LAZYCACHE", 1)
|
||||
|
||||
# TODO: movement ops that only change shape are really nops. treat them as such
|
||||
@@ -115,7 +114,6 @@ class LazyBuffer:
|
||||
self._base = base
|
||||
if base: base.views.add(self)
|
||||
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
|
||||
if not LAZY: self.realize()
|
||||
|
||||
@property
|
||||
def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))
|
||||
@@ -187,23 +185,12 @@ class LazyBuffer:
|
||||
if self.op.op == LoadOps.CONTIGUOUS:
|
||||
src = cast(LazyBuffer, self.op.src[0])
|
||||
if src.st.contiguous and src.st.size() == src.base.st.size() and not src.is_unrealized_const():
|
||||
return ret + [(self.op, self, (src,))]
|
||||
op = self.op
|
||||
|
||||
# run the ast and log the op
|
||||
op, base_bufs = _replace_bufferops(op)
|
||||
|
||||
# confirm the LoadOps are contiguous and in order
|
||||
if op.op in LoadOps:
|
||||
for i,s in enumerate(op.src):
|
||||
assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
|
||||
return ret + [(op, self, tuple(base_bufs))]
|
||||
|
||||
def realize(self:LazyBuffer) -> LazyBuffer:
|
||||
from tinygrad.realize import run_schedule
|
||||
if not self.realized: run_schedule(self.schedule())
|
||||
return self
|
||||
|
||||
# *** creation/special ops ***
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -17,10 +17,8 @@ class Optimizer:
|
||||
for param in self.params: param.grad = None
|
||||
|
||||
def realize(self, extra=None):
|
||||
# TODO: corealize
|
||||
# NOTE: in extra is too late for most of the params due to issues with assign
|
||||
for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
|
||||
p.realize()
|
||||
Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers)
|
||||
|
||||
class SGD(Optimizer):
|
||||
def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Tuple, cast, Dict, Callable
|
||||
import numpy as np
|
||||
from tinygrad.ops import LazyOp, LoadOps, Device
|
||||
from tinygrad.ops import LazyOp, LoadOps, BufferOps, Device
|
||||
from tinygrad.graph import log_schedule_item
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import DEBUG, prod, all_int, getenv
|
||||
@@ -19,7 +19,8 @@ def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]
|
||||
from extra.utils import print_tree # type: ignore
|
||||
print_tree(op)
|
||||
if op.op in LoadOps:
|
||||
# NOTE: load op buffers are promised to be in order by the scheduler
|
||||
# confirm the LoadOps are contiguous and in order
|
||||
for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
|
||||
else:
|
||||
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
|
||||
|
||||
@@ -5,12 +5,13 @@ from collections import defaultdict
|
||||
from functools import partialmethod, reduce
|
||||
from itertools import accumulate
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
|
||||
|
||||
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import Device, LoadOps
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.realize import run_schedule
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
@@ -89,8 +90,15 @@ class Tensor:
|
||||
|
||||
# ***** data handlers ****
|
||||
|
||||
@staticmethod
|
||||
def corealize(lst:Iterable[Tensor]):
|
||||
seen:Set[LazyBuffer] = set()
|
||||
sched = []
|
||||
for t in lst: sched += t.lazydata.schedule(seen)
|
||||
run_schedule(sched)
|
||||
|
||||
def realize(self) -> Tensor:
|
||||
self.lazydata.realize()
|
||||
run_schedule(self.lazydata.schedule())
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
@@ -111,7 +119,7 @@ class Tensor:
|
||||
def numpy(self) -> np.ndarray:
|
||||
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
|
||||
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
|
||||
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized._buf.reshape(self.shape)
|
||||
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
|
||||
|
||||
# TODO: if things are realized this won't work
|
||||
def to_(self, device:str):
|
||||
|
||||
Reference in New Issue
Block a user