From e0ecab3797ff6aa5bdbf48791cd4fe9e8ecd8383 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 1 Jan 2024 11:33:41 -0800
Subject: [PATCH] touchups from multibuffer branch (#2958)

---
 tinygrad/lazy.py     | 7 +++++++
 tinygrad/nn/optim.py | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 78abc9a630..5067e857c0 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -33,6 +33,7 @@ class LazyBuffer:
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
                op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                base:Optional[LazyBuffer]=None):
+    assert isinstance(device, str)
     self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
     if base is None:
       # properties on base
@@ -86,6 +87,9 @@ class LazyBuffer:
     return ret

   def copy_to_device(self, device:str) -> LazyBuffer:
+    # no COPY
+    if self.device == device: return self
+
     # COPY there and back = no COPY at all
     if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device:
       return self.srcs[0]
@@ -109,6 +113,7 @@
       else: srcs.append(s)
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
+    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
     ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
@@ -122,6 +127,8 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))

   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+    assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
+      f"reduce shape lens must match {self.shape} {new_shape}"
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, new_shape)
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index dc274471c1..97dcafe312 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -49,7 +49,7 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM

 class LAMB(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
+    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
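
Editor's note (not part of the patch): a minimal sketch of the behavior these touchups pin down, assuming the tinygrad Python API at this revision (Tensor, Tensor.lazydata as a LazyBuffer, ReduceOps from tinygrad.ops).

    from tinygrad.tensor import Tensor
    from tinygrad.ops import ReduceOps

    # same-device copy: the new early return makes copy_to_device a no-op,
    # handing back the identical LazyBuffer instead of scheduling a COPY
    a = Tensor([1.0, 2.0, 3.0])
    lb = a.lazydata
    assert lb.copy_to_device(lb.device) is lb

    # reduce shape contract made explicit by the new assert in r():
    # new_shape must keep the same rank, with reduced axes collapsed to 1
    x = Tensor.ones(4, 8).lazydata        # shape (4, 8)
    summed = x.r(ReduceOps.SUM, (4, 1))   # ok: rank preserved
    # x.r(ReduceOps.SUM, (4,)) would now raise "reduce shape lens must match"

The optim.py change is in the same spirit: LAMB's step counter self.t is now created on the optimizer's device (self.device) rather than the default device, so the optimizer state stays co-located with the parameters it updates.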