fix _to_const_val and const folding around it (#4017)

* fix _to_const_val and const folding around it

is_unrealized_contiguous_const is too strict and is almost never hit if the const is expanded.
It suffices to check that there is no pad.

* that test is folded

* test_const_folding
This commit is contained in:
chenyu
2024-03-31 13:09:23 -04:00
committed by GitHub
parent 2abb474d43
commit 7f859593b8
5 changed files with 67 additions and 8 deletions

View File

@@ -83,7 +83,7 @@ class MultiLazyBuffer:
return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
# TODO: fix this
def is_unrealized_contiguous_const(self): return False
def is_unrealized_unpadded_const(self): return False
# passthroughs
def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])

View File

@@ -91,7 +91,7 @@ class LazyBuffer:
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST
def is_unrealized_contiguous_const(self): return self.base == self and self.base.realized is None and self.op is LoadOps.CONST
def is_unrealized_unpadded_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
def _copy(self, device:str) -> LazyBuffer:
if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):

View File

@@ -885,7 +885,7 @@ class Tensor:
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
# TODO: update with multi
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unpadded_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: