deletions from an ops.py "instant rule" audit [pr] (#8424)

* UOp.st cleanup 2 [pr]

* deletions from an ops.py instant rule audit [pr]

* note
qazal authored on 2024-12-26 18:49:04 +02:00, committed by GitHub
parent 22abd9dc03
commit b5820a5209
2 changed files with 16 additions and 17 deletions


@@ -1930,7 +1930,7 @@ class TestSwizzle(unittest.TestCase):
base = ShapeTracker.from_shape((32, 16, 1))
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
add = r.reshape((16, 32, 1)) + UOp.const(r.dtype, 0)
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.permute((1, 0, 2)).contiguous()
to_store = graph_rewrite(to_store, remove_movement_ops)
@@ -1941,6 +1941,8 @@ class TestSwizzle(unittest.TestCase):
self.assertEqual(swizzle_cnt(ret), 1)
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
# TODO: we only need valid on ast consts if it's masked, can fold this early to UOp.const
zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))
class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST Ops
@@ -1948,8 +1950,7 @@ class TestView(unittest.TestCase):
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
sched = check_schedule(b.contiguous(), 1)
# TODO: this VALID can clean up, where do we need st?
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
@@ -1960,7 +1961,7 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
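
The rewritten assertions above check the stored value against the zero_pm UPat instead of comparing it to a concrete UOp.const_with_shape result. As a rough sketch of the kind of UOp that pattern accepts (hand-built here; the (4,) shape, dtypes.int, and the variable names are arbitrary choices for illustration):

    from tinygrad import dtypes
    from tinygrad.ops import Ops, UOp, UPat
    from tinygrad.shape.shapetracker import ShapeTracker

    # build a masked zero constant by hand: WHERE(VALID(view), CONST 0, CONST 1),
    # the same structure const_with_shape(dtype, 0, shape) used to produce
    view = ShapeTracker.from_shape(()).reshape((1,)).expand((4,)).to_uop()
    masked_zero = UOp(Ops.VALID, dtypes.bool, (view,)).where(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))

    zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))
    assert zero_pm.match(masked_zero, {})  # matches: WHERE of a VALID, a zero CONST, and any const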


@@ -276,15 +276,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
def st(self) -> ShapeTracker|None:
# these uops define ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
# buffer ops can have a non contiguous shapetracker
if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
# otherwise we derive the st from sources
if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
# all other ops have a contiguous shapetracker
# st_arg on buffer uops defines the ShapeTracker, it's allowed to be non contiguous
if self.op in GroupOp.Buffer: return self.st_arg
# all other uops have a contiguous ShapeTracker
from tinygrad.shape.shapetracker import ShapeTracker
# only reduceop is allowed to change shape
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
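
In the reworked st property above, VIEW gets its ShapeTracker from the arg, movement uops apply themselves to their source's st, buffer uops return st_arg, and everything else gets a contiguous ShapeTracker derived from its sources. A small sketch of those last two cases (assuming Device.DEFAULT is usable; the (4, 4) shape and the names are arbitrary):

    from tinygrad import Device, dtypes
    from tinygrad.ops import Ops, UOp
    from tinygrad.shape.shapetracker import ShapeTracker

    base = ShapeTracker.from_shape((4, 4))
    buf = UOp.new_buffer(Device.DEFAULT, base.size, dtypes.float)
    ld = UOp(Ops.LOAD, dtypes.float, (buf, base.to_uop()))

    assert ld.st == base                                     # buffer uop: ShapeTracker comes from st_arg (the VIEW source)
    assert (ld + ld).st == ShapeTracker.from_shape((4, 4))   # ALU uop: contiguous ShapeTracker derived from its sources
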
@@ -292,7 +295,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg[-1] if self.op is Ops.BUFFER else unwrap(self.st).size
def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size
# *** uop evaluation ***
@@ -338,8 +341,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):
if self._device is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
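
With the change above, const_like branches only on whether the uop is backed by a DEVICE: device-backed uops go through metaop and keep their shape, everything else becomes a plain CONST. A minimal sketch of the deviceless branch (the values and names are arbitrary):

    from tinygrad import dtypes
    from tinygrad.ops import Ops, UOp

    x = UOp.const(dtypes.int, 5)
    two = x.const_like(2)
    assert two.op is Ops.CONST and two.arg == 2 and two.dtype == x.dtype  # no DEVICE source -> plain UOp.const
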
@@ -429,10 +432,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** from LazyBuffer ***
@staticmethod
def const_with_shape(dtype:DType, val:ConstLike, shape:tuple[sint,...]) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0)
@staticmethod
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -506,8 +505,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType) -> UOp:
return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
@property
def device(self) -> str: return unwrap(self._device)
@functools.cached_property
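
new_buffer packs its arg as (buffer number, size), which is what the size property above now reads as arg[1]. A quick sketch of that layout (assuming Device.DEFAULT; the size and dtype are arbitrary):

    from tinygrad import Device, dtypes
    from tinygrad.ops import Ops, UOp

    buf = UOp.new_buffer(Device.DEFAULT, 16, dtypes.char)
    assert buf.op is Ops.BUFFER
    assert buf.size == buf.arg[1] == 16   # BUFFER size comes from the arg, not a ShapeTracker
    assert buf.device == Device.DEFAULT   # device resolved from the DEVICE source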