Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-24 06:18:01 -05:00
deletions from an ops.py "instant rule" audit [pr] (#8424)
* UOp.st cleanup 2 [pr]
* deletions from an ops.py instant rule audit [pr]
* note
@@ -1930,7 +1930,7 @@ class TestSwizzle(unittest.TestCase):
base = ShapeTracker.from_shape((32, 16, 1))
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
add = r.reshape((16, 32, 1)) + UOp.const(r.dtype, 0)
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.permute((1, 0, 2)).contiguous()
to_store = graph_rewrite(to_store, remove_movement_ops)
@@ -1941,6 +1941,8 @@ class TestSwizzle(unittest.TestCase):
self.assertEqual(swizzle_cnt(ret), 1)

def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
# TODO: we only need valid on ast consts if it's masked, can fold this early to UOp.const
zero_pm = UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat(Ops.CONST, arg=0), UPat.cvar()))
class TestView(unittest.TestCase):
def test_all_masked_out(self):
# start with non CONST Ops
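The zero_pm pattern introduced above is how these tests now check for a masked-out zero: a WHERE whose sources are a VALID, the constant 0, and another constant. A minimal sketch of a tree it matches, built only from APIs that appear elsewhere in this diff (exact behavior may differ between tinygrad revisions):

# sketch: hand-build a masked zero the way the deleted const_with_shape helper did, then match it
shape = (4, 4)
valid = UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),))
masked_zero = valid.where(UOp.const(dtypes.char, 0), UOp.const(dtypes.char, 0))
assert zero_pm.match(masked_zero, {})  # truthy when the WHERE(VALID, CONST 0, CONST) structure matches, as in the tests below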
@@ -1948,8 +1950,7 @@ class TestView(unittest.TestCase):
# all masked out, degrades to const 0
b = a.pad(((0, 10), None))[10:]
sched = check_schedule(b.contiguous(), 1)
# TODO: this VALID can clean up, where do we need st?
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
@@ -1960,7 +1961,7 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
assert zero_pm.match(store_val(sched[-1]), {})
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
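End to end, these two tests check that a view which is entirely masked out schedules to a single kernel whose store value is that masked zero, and that running it really produces zeros. A rough standalone version of the first case, assuming the public Tensor API (Tensor.ones is just an arbitrary non-CONST starting point, matching the "start with non CONST Ops" comment above, and the (10, 10) shape is illustrative):

# sketch: pad 10 extra rows onto a (10, 10) tensor, then keep only the padded rows
from tinygrad import Tensor
import numpy as np
a = Tensor.ones(10, 10).contiguous()
b = a.pad(((0, 10), None))[10:]        # shape (10, 10), every element lies in the padded (masked) region
np.testing.assert_equal(b.numpy(), 0)  # degrades to const 0, as the schedules above assert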
@@ -276,15 +276,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
def st(self) -> ShapeTracker|None:
# these uops define ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
# buffer ops can have a non contiguous shapetracker
if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
# otherwise we derive the st from sources
if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
# all other ops have a contiguous shapetracker
# st_arg on buffer uops defines the ShapeTracker, it's allowed to be non contiguous
if self.op in GroupOp.Buffer: return self.st_arg
# all other uops have a contiguous ShapeTracker
from tinygrad.shape.shapetracker import ShapeTracker
# only reduceop is allowed to change shape
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op in (Ops.REDUCE_AXIS, Ops.WMMA) else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
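Read top to bottom, the st property above resolves a uop's ShapeTracker in a fixed order: VIEW carries it in its arg, movement ops transform their source's, buffer ops read it off their VIEW source (the st_arg, which may be non-contiguous), and everything else derives it from whichever sources have one, with only REDUCE_AXIS/WMMA allowed to change shape. A worked sketch of how that plays out for the swizzle test at the top of this diff (reusing `r` from there; illustrative only):

lhs = r.reshape((16, 32, 1))   # movement op: st comes from r's st, mop'd to shape (16, 32, 1)
rhs = UOp.const(r.dtype, 0)    # bare CONST: no sources carry an st, so rhs.st is None and it is skipped
add = lhs + rhs                # derived from lhs alone -> contiguous ShapeTracker of shape (16, 32, 1)
assert add.st == ShapeTracker.from_shape((16, 32, 1))  # same check as the updated test above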
@@ -292,7 +295,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg[-1] if self.op is Ops.BUFFER else unwrap(self.st).size
def size(self) -> int: return self.arg[1] if self.op is Ops.BUFFER else unwrap(self.st).size
# *** uop evaluation ***
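On the size change just above: new_buffer (at the bottom of this diff) creates BUFFER uops with arg=(buffer_num, size), a 2-tuple, so for buffers built that way arg[1] and arg[-1] name the same field and the behavior is unchanged. A small sketch under that assumption:

buf = UOp.new_buffer(Device.DEFAULT, 100, dtypes.char)  # arg is (next buffer_num, 100)
assert buf.arg[1] == 100
assert buf.size == 100  # BUFFER ops report size straight from arg, not from a ShapeTracker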
@@ -338,8 +341,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):
if self._device is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
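The const_like body above reduces the decision to one question: is there a DEVICE anywhere in this uop's graph? If not, it returns a bare shapeless CONST; if so, it goes through UOp.metaop with the same shape, dtype and device. A sketch of the two paths (illustrative only; _device is an internal cached property and metaop appears further down in this diff):

x = UOp.const(dtypes.int, 2)           # no DEVICE source anywhere in the graph
assert x.const_like(3).op is Ops.CONST

y = UOp.metaop(Ops.CONST, (4, 4), dtypes.int, Device.DEFAULT, 1)  # device-backed const
z = y.const_like(0)  # equivalent to UOp.metaop(Ops.CONST, (4, 4), dtypes.int, y.device, 0)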
@@ -429,10 +432,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

# *** from LazyBuffer ***

@staticmethod
def const_with_shape(dtype:DType, val:ConstLike, shape:tuple[sint,...]) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0)
@staticmethod
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -506,8 +505,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType) -> UOp:
return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
@property
def device(self) -> str: return unwrap(self._device)
@functools.cached_property