From 328cf2e86afb96c13d43e981b9acf66ccd650a05 Mon Sep 17 00:00:00 2001 From: Roelof van Dijk <3604013+roelofvandijk@users.noreply.github.com> Date: Sun, 27 Aug 2023 20:15:52 +0200 Subject: [PATCH] perf: remove cast and revert back to isinstance (#1694) Co-authored-by: Roelof van Dijk --- tinygrad/lazy.py | 7 ++++--- tinygrad/ops.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 8d1c0f784f..dd6bc5d85b 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -150,7 +150,7 @@ class LazyBuffer: for x in self.op.buffers: x.realize() # HACK: image shape can be wrong, hot cast it back to a normal float - if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): + if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): if self.op.op == MovementOps.RESHAPE: # put CAST before the final RESHAPE self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, (dtypes.float32, False)),), self.op.arg) @@ -315,7 +315,8 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1: assert isinstance(bx.op.op, MovementOps) mops.append((bx.op.op, bx.op.arg)) - bx = cast(LazyBuffer, bx.op.src[0]) + assert isinstance(bx.op.src[0], LazyBuffer) + bx = bx.op.src[0] # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0 if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(x[0] is not MovementOps.PAD for x in mops) or all(x.op not in UNSAFE_PAD_OPS for x in bx.op.get_lazyops())): new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1])) @@ -325,7 +326,7 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]: def _realize_contiguous(buffer: LazyBuffer) -> None: realized = buffer.op.src[0].realize().realized - if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and cast(RawBuffer, realized).size == prod(buffer.shape): + if buffer.op.src[0].st.contiguous and realized.__class__ is not RawConst and realized is not None and realized.size == prod(buffer.shape): # no need to run an AST, this is already contiguous buffer.realized = realized else: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 648a04677d..387d9956dd 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -96,12 +96,12 @@ class Interpreted: self.codegen = None def exec_ast(self, ast:LazyOp, output=None, context=None, **kwargs): - if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and ast.src[0].__class__ is LazyOp and ast.src[0].op == BinaryOps.MUL: - ast = LazyOp(TernaryOps.MULACC, cast(LazyOp, ast.src[0]).src, ast.arg) + if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: + ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) created_context = context is None if context is None: context = dict() if not created_context and ast in context: return context[ast] - srcs = [self.exec_ast(cast(LazyOp, x), context=context, **kwargs) if x.__class__ is LazyOp else self.from_lazybuffer(x) for x in ast.src] + srcs = [self.exec_ast(x, context=context, **kwargs) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src] if DEBUG >= 3: st = time.perf_counter() ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else [])))) if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.