allow symbolic shape in tensor const parents [pr] (#8699)

This commit is contained in:
qazal
2025-01-21 05:01:25 -05:00
committed by GitHub
parent 2b239db5d2
commit e2008c98c3
2 changed files with 8 additions and 2 deletions

View File

@@ -2273,6 +2273,12 @@ class TestTensorUOpSpec(unittest.TestCase):
t = graph_rewrite(a.lazydata.sink(), remove_movement_ops+merge_views)
create_schedule_with_vars(t)
def test_symbolic_shape_ok(self):
a = Tensor.ones(4)
vi = UOp.variable("i", 1, 10).bind(4)
t = graph_rewrite(a.reshape(vi).sum().lazydata, remove_movement_ops+merge_views)
create_schedule_with_vars(t)
class TestBufferUOp(unittest.TestCase):
# BUFFER has a ShapeTracker of shape=(n,) and stride=(1,)
def test_buffer_has_buffer(self):

View File

@@ -31,9 +31,9 @@ tensor_uop_spec = PatternMatcher([
(UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, arg=ShapeTracker.from_shape(()))), arg=None), lambda: True),
# Tensor const has an unmasked ShapeTracker of stride 0 and a device
# Tensor const has a device and an unmasked ShapeTracker of stride 0 or a ShapeTracker with symbolic shape
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
lambda st: len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides) and st.st.views[0].mask is None),
lambda st: st.st.views[0].mask is None and ((len(st.st.views) == 1 and all(s == 0 for s in st.st.views[0].strides)) or not all_int(st.shape))),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes