touchups from multibuffer branch (#2958)
@@ -33,6 +33,7 @@ class LazyBuffer:
  def __init__(self, device:str, st:ShapeTracker, dtype:DType,
               op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
               base:Optional[LazyBuffer]=None):
    assert isinstance(device, str)
    self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
    if base is None:
      # properties on base
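
The base argument and the `if base is None:` branch reflect LazyBuffer's base/view split: a buffer either owns its storage (it is its own base) or is a view that defers storage-level properties to another buffer's base. A minimal standalone sketch of that idea, using hypothetical names (ToyLazyBuffer) rather than tinygrad's real class:

from typing import Optional, Tuple

# Standalone sketch of the base/view idea hinted at by `if base is None:` above:
# a buffer either owns its storage (it is its own base) or is a view that
# forwards storage-level questions to another buffer's base.
class ToyLazyBuffer:
  def __init__(self, device: str, shape: Tuple[int, ...], dtype: str,
               base: Optional["ToyLazyBuffer"] = None):
    assert isinstance(device, str)
    self.device, self.shape, self.dtype = device, shape, dtype
    self._base = base if base is not None else self  # a base points at itself

  @property
  def base(self) -> "ToyLazyBuffer": return self._base

a = ToyLazyBuffer("CPU", (4, 4), "float32")          # a base buffer
v = ToyLazyBuffer("CPU", (16,), "float32", base=a)   # a reshaped view of it
assert v.base is a and a.base is a
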
@@ -86,6 +87,9 @@ class LazyBuffer:
    return ret

  def copy_to_device(self, device:str) -> LazyBuffer:
    # no COPY
    if self.device == device: return self

    # COPY there and back = no COPY at all
    if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device: return self.srcs[0]
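
Two early returns in copy_to_device avoid needless transfers: copying to the device a buffer already lives on is the identity, and copying back the result of an unrealized COPY collapses to the original source. A simplified standalone sketch of that logic (hypothetical ToyBuffer; the `self.base == self` check is omitted), not tinygrad's actual implementation:

# Simplified standalone sketch of the two early-outs in copy_to_device above:
# 1) copying to the device a buffer already lives on is the identity, and
# 2) an unrealized COPY that is copied straight back to its source device is
#    elided entirely ("COPY there and back = no COPY at all").
class ToyBuffer:
  def __init__(self, device, src=None, op=None):
    self.device, self.op, self.realized = device, op, False
    self.srcs = (src,) if src is not None else ()

  def copy_to_device(self, device):
    if self.device == device: return self                        # no COPY
    if not self.realized and self.op == "COPY" and self.srcs[0].device == device:
      return self.srcs[0]                                        # round trip elided
    return ToyBuffer(device, src=self, op="COPY")                 # a real COPY node

x = ToyBuffer("CPU")
y = x.copy_to_device("GPU")          # builds a COPY node CPU -> GPU
assert y.copy_to_device("GPU") is y  # already on GPU: no extra COPY
assert y.copy_to_device("CPU") is x  # there and back collapses to the original
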
@@ -109,6 +113,7 @@ class LazyBuffer:
      else:
        srcs.append(s)
    assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
    if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
    out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
    ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
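
The asserts pin down the dtype rules for elementwise ops: all sources must agree on scalar dtype and shape, WHERE additionally requires a bool condition as its first source, and the comparison ops produce bool outputs. A standalone sketch of those rules, using plain strings for dtypes and op names rather than tinygrad's enums:

from typing import List

# Standalone sketch of the dtype rules the asserts above enforce for an
# elementwise op: all sources must share a dtype (shape checking omitted here),
# WHERE takes a bool condition as its first source, and comparisons return bool.
def elementwise_out_dtype(op: str, src_dtypes: List[str]) -> str:
  checked = src_dtypes[1:] if op == "WHERE" else src_dtypes
  assert all(d == checked[0] for d in checked), f"all dtypes must match {src_dtypes} on {op}"
  if op == "WHERE": assert src_dtypes[0] == "bool", "WHERE must have the first arg be bool"
  return "bool" if op in ("CMPLT", "CMPEQ") else src_dtypes[-1]

assert elementwise_out_dtype("ADD", ["float32", "float32"]) == "float32"
assert elementwise_out_dtype("CMPLT", ["int32", "int32"]) == "bool"
assert elementwise_out_dtype("WHERE", ["bool", "half", "half"]) == "half"
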
@@ -122,6 +127,8 @@ class LazyBuffer:
    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))

  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
    assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
      f"reduce shape lens must match {self.shape} {new_shape}"
    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
      return self._reduce_op(op, new_shape)
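
The guard decides whether a reduce stays a single kernel: symbolic or zero-sized shapes and small input-to-output ratios go straight to `_reduce_op`, and only reductions that shrink the element count by at least REDUCEOP_SPLIT_THRESHOLD (default 32768) fall through to the split path. A standalone sketch of that check, with a hypothetical `should_split_reduce` helper:

import os
from math import prod

# Standalone sketch of the guard above, with a hypothetical should_split_reduce
# helper: symbolic or zero-sized shapes and small input/output ratios stay on
# the single-kernel path; only large reductions are candidates for splitting.
def should_split_reduce(in_shape, out_shape, threshold=None):
  if threshold is None: threshold = int(os.getenv("REDUCEOP_SPLIT_THRESHOLD", "32768"))
  if not all(isinstance(s, int) for s in in_shape) or 0 in in_shape: return False
  return prod(in_shape) // prod(out_shape) >= threshold

assert should_split_reduce((1024, 1024), (1024, 1)) is False   # ratio 1024 < 32768
assert should_split_reduce((1, 1 << 20), (1, 1)) is True       # ratio 1048576 >= 32768
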
@@ -49,7 +49,7 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
 class LAMB(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
+    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
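
The one-line change pins the LAMB step counter `self.t` to `self.device` (presumably set from the parameters in the Optimizer base class), matching the `m` and `v` buffers created right after it. A standalone toy sketch (not tinygrad code) of the design choice of keeping all optimizer state on the parameters' device:

# Standalone toy sketch (not tinygrad) of the design choice: every piece of
# optimizer state, including the step counter, lives on the parameters' device.
class ToyTensor:
  def __init__(self, value, device="CPU"): self.value, self.device = value, device

class ToyOptimizer:
  def __init__(self, params):
    self.params = params
    self.device = params[0].device   # state follows the parameters' device

class ToyLAMB(ToyOptimizer):
  def __init__(self, params):
    super().__init__(params)
    # before the change the counter landed on the default device; after it,
    # it is pinned to self.device, matching the m/v buffers below
    self.t = ToyTensor([0], device=self.device)
    self.m = [ToyTensor([0.0], device=p.device) for p in params]
    self.v = [ToyTensor([0.0], device=p.device) for p in params]

opt = ToyLAMB([ToyTensor([1.0, 2.0], device="GPU")])
assert opt.t.device == opt.m[0].device == opt.params[0].device == "GPU"
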