mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 15:15:13 -05:00
const fold cast unrealized_unpadded_const (#4047)
* const fold unrealized_unpadded_const changed the underlying arg directly
* CAST_BEFORE_VIEW folds some
* fix const index in getitem
This commit is contained in:
@@ -78,6 +78,8 @@ class LazyBuffer:
|
||||
def cast(self, dtype:DType, bitcast:bool=False):
|
||||
if self.dtype == dtype: return self
|
||||
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
||||
if self.is_unrealized_unpadded_const() and not bitcast:
|
||||
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
|
||||
# TODO: applying this makes gpt2 slower
|
||||
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
|
||||
@@ -422,7 +422,7 @@ class Tensor:
|
||||
else: indices = [indices]
|
||||
|
||||
# turn scalar Tensors into const val for int indexing if possible
|
||||
indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
|
||||
indices = [self._to_const_val(i) if isinstance(i, Tensor) and i.shape == () else i for i in indices]
|
||||
# move Tensor indices to the same device as self
|
||||
indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user