fix _to_const_val and const folding around it (#4017)
* fix _to_const_val and const folding around it

  is_unrealized_contiguous_const is too strict and is almost never hit when the const is expanded; it suffices to check that there is no pad.

* that test is folded

* test_const_folding
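To illustrate the difference, a minimal sketch (not part of the diff, assuming the tinygrad Python API of this version; Tensor.full and pad are only used for illustration):

from tinygrad import Tensor

a = Tensor([1.0, 2.0, 3.0, 4.0])

# Tensor.full builds a scalar CONST and expands it to the target shape, so the
# resulting lazydata is a view of its base (base != self). The old "contiguous"
# check required base == self, so it almost never fired for an expanded const and
# _to_const_val handed back the Tensor instead of the python value 2.0.
b = Tensor.full((4,), 2.0)
out = a + b   # with is_unrealized_unpadded_const, b is treated as the constant 2.0

# Padding records a mask on the ShapeTracker view, so the new check still refuses
# to fold: the padded elements are 0, not 2.0, and must stay that way.
c = Tensor.full((2,), 2.0).pad(((1, 1),))
out2 = a + c  # c keeps its Tensor form here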
@@ -83,7 +83,7 @@ class MultiLazyBuffer:
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)

   # TODO: fix this
-  def is_unrealized_contiguous_const(self): return False
+  def is_unrealized_unpadded_const(self): return False

   # passthroughs
   def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
@@ -91,7 +91,7 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))

   def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST
-  def is_unrealized_contiguous_const(self): return self.base == self and self.base.realized is None and self.op is LoadOps.CONST
+  def is_unrealized_unpadded_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)

   def _copy(self, device:str) -> LazyBuffer:
     if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):
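The pad information lives in the view mask, which is what the new check inspects. A rough sketch (not part of the diff, assuming this era's internals, i.e. LazyBuffer.st and View.mask):

from tinygrad import Tensor

t = Tensor.full((2, 2), 3.0)                  # unrealized CONST expanded to (2, 2)
print([v.mask for v in t.lazydata.st.views])  # [None] -> unpadded, safe to fold

p = t.pad(((1, 1), (1, 1)))                   # pad records which region is valid
print([v.mask for v in p.lazydata.st.views])  # e.g. [((1, 3), (1, 3))] -> not foldable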
@@ -885,7 +885,7 @@ class Tensor:

   def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
     # TODO: update with multi
-    return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
+    return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unpadded_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

   def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
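In the spirit of the test_const_folding tests added with this change, a hedged sketch of what the relaxed check enables (the exact fast paths in Tensor.mul are an assumption; the guards on requires_grad and on the broadcasted shape in _to_const_val are unchanged):

from tinygrad import Tensor

x = Tensor.rand(4)

zeros = Tensor.full((4,), 0.0)  # unrealized, unpadded const
y = x * zeros                   # _to_const_val now returns 0.0, so mul can short-circuit to zeros

ones = Tensor.full((4,), 1.0)
z = x * ones                    # likewise, a folded 1.0 lets mul return x without an extra op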