process replayable ops.py changes from delete_lazy [pr] (#7994)

* process replayable ops.py changes from delete_lazy [pr]

* hotfix: seed tiny_jit
This commit is contained in:
qazal
2024-12-02 06:38:31 -05:00
committed by GitHub
parent 0c7477b108
commit bb606e5bcf
2 changed files with 6 additions and 10 deletions

View File

@@ -41,6 +41,7 @@ class TestTiny(unittest.TestCase):
def test_jit(self):
cnt = 0
random.seed(0)
def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)]
@TinyJit

View File

@@ -277,7 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
@property
def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
@property
@@ -338,8 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(arg, out_dtype, (self,)+src)
@staticmethod
def const(dtype:DType, b:ConstLike):
@@ -384,13 +383,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
buffer_num = itertools.count(0)
@staticmethod
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
@functools.cached_property
def device(self) -> str:
match self.op:
case Ops.COPY: return self.arg
case Ops.BUFFER: return self.arg[1][0]
case _: return self.src[0].device
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
@@ -627,7 +622,7 @@ class UPat(MathTrait):
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, op:Ops, *src:UPat):
asrc = (self,)+src
return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
def printable(self:UPat) -> str:
try: return lines(self.location[0])[self.location[1]-1].strip()