const fold cast unrealized_unpadded_const (#4047)

* const fold unrealized_unpadded_const

changed the underlying arg directly

* CAST_BEFORE_VIEW folds some

* fix const index in getitem
This commit is contained in:
chenyu
2024-04-03 12:31:24 -04:00
committed by GitHub
parent e5a9bff899
commit fe03725b21
5 changed files with 19 additions and 6 deletions

View File

@@ -78,6 +78,8 @@ class LazyBuffer:
def cast(self, dtype:DType, bitcast:bool=False):
if self.dtype == dtype: return self
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unpadded_const() and not bitcast:
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
# TODO: applying this makes gpt2 slower
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
return self.base.cast(dtype, bitcast)._view(self.st)

View File

@@ -422,7 +422,7 @@ class Tensor:
else: indices = [indices]
# turn scalar Tensors into const val for int indexing if possible
indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
# move Tensor indices to the same device as self
indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]