touchups from multibuffer branch (#2958)

George Hotz
2024-01-01 11:33:41 -08:00
committed by GitHub
parent 45247385eb
commit e0ecab3797
2 changed files with 8 additions and 1 deletion

@@ -33,6 +33,7 @@ class LazyBuffer:
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
                op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                base:Optional[LazyBuffer]=None):
+    assert isinstance(device, str)
     self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
     if base is None:
       # properties on base
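
Note on the new assert: with the multibuffer work a Tensor's device can be a collection of device strings, but each underlying LazyBuffer must still live on exactly one device, so the constructor now fails fast on anything that is not a plain str. A minimal sketch of the failure mode it catches (the tuple device is an assumed illustration, not the real multibuffer plumbing):

def make_buffer(device):
  # mirrors the new check: a LazyBuffer is single-device, so device must be a plain str
  assert isinstance(device, str), f"expected a device string, got {device!r}"
  return f"buffer on {device}"

make_buffer("CLANG")            # ok
make_buffer(("CLANG", "GPU"))   # AssertionError here, instead of a confusing failure later
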
@@ -86,6 +87,9 @@ class LazyBuffer:
     return ret
   def copy_to_device(self, device:str) -> LazyBuffer:
+    # no COPY
+    if self.device == device: return self
+    # COPY there and back = no COPY at all
     if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device: return self.srcs[0]
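
Both copy shortcuts can be exercised in isolation. A standalone sketch with a toy stand-in class (not the real LazyBuffer) showing that a same-device copy returns self and that a copy back to the source device collapses to the original buffer:

class ToyBuf:
  def __init__(self, device, src=None): self.device, self.src = device, src
  def copy_to_device(self, device):
    # no COPY
    if self.device == device: return self
    # COPY there and back = no COPY at all
    if self.src is not None and self.src.device == device: return self.src
    return ToyBuf(device, src=self)

a = ToyBuf("CLANG")
b = a.copy_to_device("GPU")
assert a.copy_to_device("CLANG") is a   # same device: no copy node created
assert b.copy_to_device("CLANG") is a   # round trip: collapses back to the original

The same-device case matters for sharding, where a buffer may be "copied" to a set of devices that includes the one it already lives on.
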
@@ -109,6 +113,7 @@ class LazyBuffer:
       else:
         srcs.append(s)
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
+    assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
     ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
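
At the Tensor level these invariants are what the common compare-then-select pattern relies on; a short sketch assuming the usual Tensor frontend (CMPLT output dtype forced to bool, which WHERE then accepts as its condition):

from tinygrad.tensor import Tensor

a, b = Tensor([1.0, 5.0, 3.0]), Tensor([2.0, 4.0, 3.0])
mask = a < b            # BinaryOps.CMPLT: out_dtype is dtypes.bool
out = mask.where(a, b)  # TernaryOps.WHERE: first arg must be bool, satisfied by mask
print(out.numpy())      # elementwise min of a and b: [1. 4. 3.]
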
@@ -122,6 +127,8 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+    assert len(self.shape) == len(new_shape) and all(s == ns or ns == 1 for s,ns in zip(self.shape, new_shape)), \
+      f"reduce shape lens must match {self.shape} {new_shape}"
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, new_shape)
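
The guard above picks between a single-stage reduce and a split two-stage reduce. A pure-Python sketch of just the predicate, assuming the default REDUCEOP_SPLIT_THRESHOLD of 32768:

from math import prod

def wants_split(shape, new_shape, threshold=32768):
  if not all(isinstance(s, int) for s in shape): return False  # symbolic dim: can't split (see TODO)
  if 0 in shape: return False                                  # zero-size reduce: nothing to split
  return prod(shape) // prod(new_shape) >= threshold           # only big reduction factors pay off

assert not wants_split((1024, 16), (1024, 1))  # factor 16: single _reduce_op call
assert wants_split((65536, 256), (1, 256))     # factor 65536: worth a two-stage split
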

@@ -49,7 +49,7 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
 class LAMB(Optimizer):
   def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
     super().__init__(params, lr)
-    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize()
+    self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize()
     self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
     self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params]
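
The optim.py change pins the LAMB step counter t to the optimizer's device rather than the default device, matching how m and v already follow the params. A quick check, assuming a CLANG backend is available and that the Optimizer base class exposes the device derived from the params as self.device:

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import LAMB

w = Tensor.zeros(4, 4, device="CLANG", requires_grad=True)
opt = LAMB([w], lr=0.001)
assert opt.t.device == w.device  # step counter lives with the params, not the default device
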