mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
process replayable ops.py changes from delete_lazy [pr] (#7994)
* process replayable ops.py changes from delete_lazy [pr] * hotfix: seed tiny_jit
This commit is contained in:
@@ -41,6 +41,7 @@ class TestTiny(unittest.TestCase):
|
||||
|
||||
def test_jit(self):
|
||||
cnt = 0
|
||||
random.seed(0)
|
||||
def new_rand_list(ln=10): return [random.randint(0, 100000) for _ in range(ln)]
|
||||
|
||||
@TinyJit
|
||||
|
||||
@@ -277,7 +277,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
@property
|
||||
def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
|
||||
@property
|
||||
@@ -338,8 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None:
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
if arg in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(arg, out_dtype, (self,)+src)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike):
|
||||
@@ -384,13 +383,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
buffer_num = itertools.count(0)
|
||||
@staticmethod
|
||||
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
|
||||
def new_buffer(device:str, size:int, dtype:DType) -> UOp: return UOp(Ops.BUFFER, dtype.ptr(), (), (next(UOp.buffer_num), (device, size, dtype)))
|
||||
@functools.cached_property
|
||||
def device(self) -> str:
|
||||
match self.op:
|
||||
case Ops.COPY: return self.arg
|
||||
case Ops.BUFFER: return self.arg[1][0]
|
||||
case _: return self.src[0].device
|
||||
def device(self) -> str: return self.arg[1][0] if self.op is Ops.BUFFER else self.src[0].device
|
||||
@property
|
||||
def buf_uop(self) -> UOp:
|
||||
if self.op is Ops.BUFFER: return self
|
||||
@@ -627,7 +622,7 @@ class UPat(MathTrait):
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, op:Ops, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return UPat(op, None if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
return UPat(op, dtypes.bool if op in {Ops.CMPLT, Ops.CMPNE} else asrc[-1].dtype, list(asrc) if op in GroupOp.Commutative else asrc)
|
||||
|
||||
def printable(self:UPat) -> str:
|
||||
try: return lines(self.location[0])[self.location[1]-1].strip()
|
||||
|
||||
Reference in New Issue
Block a user