is_unrealized_unpadded_const -> is_unrealized_unmasked_const (#4071)

realized #3580 was doing the same thing. unmasked is more accurate
This commit is contained in:
chenyu
2024-04-04 14:25:17 -04:00
committed by GitHub
parent 82b7b9655f
commit f836d6a03f
3 changed files with 13 additions and 13 deletions

View File

@@ -70,9 +70,9 @@ class MultiLazyBuffer:
@staticmethod
def from_sharded(lb:LazyBuffer, devices:Tuple[str, ...], axis:Optional[int]=None):
lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unpadded_const() else lb] * len(devices)
lbs = [lb.contiguous() if lb.base != lb and not lb.is_unrealized_unmasked_const() else lb] * len(devices)
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis) if axis is not None else lbs, devices)]
return MultiLazyBuffer([lb if lb.is_unrealized_unpadded_const() else lb.contiguous() for lb in sharded_lbs], axis)
return MultiLazyBuffer([lb if lb.is_unrealized_unmasked_const() else lb.contiguous() for lb in sharded_lbs], axis)
def copy_to_device(self, device:str) -> LazyBuffer:
if self.axis is None: return self.lbs[self.real.index(True)].copy_to_device(device)

View File

@@ -78,7 +78,7 @@ class LazyBuffer:
def cast(self, dtype:DType, bitcast:bool=False):
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unpadded_const() and not bitcast:
if self.is_unrealized_unmasked_const() and not bitcast:
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
# TODO: applying this makes gpt2 slower
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
@@ -93,7 +93,7 @@ class LazyBuffer:
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST
def is_unrealized_unpadded_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def _copy(self, device:str) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
@@ -131,17 +131,17 @@ class LazyBuffer:
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
# const folding
if op in python_alu and all(s.is_unrealized_unpadded_const() for s in srcs):
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
if op in BinaryOps: x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
if x.is_unrealized_unpadded_const() and x.base.arg == 0: return y
if op is BinaryOps.SUB and y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
if op is BinaryOps.MUL:
if x.is_unrealized_unpadded_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
if y.is_unrealized_unpadded_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unpadded_const() and y.base.arg != 0:
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unmasked_const() and y.base.arg != 0:
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
@@ -160,7 +160,7 @@ class LazyBuffer:
# TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
if self.is_unrealized_unpadded_const():
if self.is_unrealized_unmasked_const():
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?

View File

@@ -885,7 +885,7 @@ class Tensor:
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
# TODO: update with multi
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unpadded_const() \
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Add.apply(*self._broadcasted(x, reverse))