From cfd13c083b85b5da20eafcd835d49f075eae1a2c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 8 Feb 2023 18:01:08 -0600 Subject: [PATCH] refactor GenericShape for a big line reduction --- tinygrad/llops/ops_cpu.py | 5 ++-- tinygrad/llops/ops_torch.py | 5 ++-- tinygrad/ops.py | 46 ++++++++++++++++++------------------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 11265eeb42..63e8cf7e9d 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -1,6 +1,6 @@ import numpy as np from typing import ClassVar -from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op +from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({ UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), @@ -9,9 +9,8 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({ MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg]) }) -class CPUBuffer(GenericBufExecAST): +class CPUBuffer(GenericExecAST): fxn_for_op : ClassVar = specialized_fxn_for_op - def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape) @staticmethod def fromCPU(x): return CPUBuffer(x) diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index de185f961b..3d341a5215 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -1,6 +1,6 @@ import torch from typing import ClassVar -from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op +from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op from tinygrad.helpers import getenv specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({ @@ -10,9 +10,8 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({ }) device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu")) -class TorchBuffer(GenericBufExecAST): +class TorchBuffer(GenericExecAST): fxn_for_op : ClassVar = specialized_fxn_for_op - def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape) @staticmethod def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 67c0aefdb8..80f2461ed4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np from enum import Enum, auto -from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional +from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict import functools, operator from tinygrad.helpers import prod, shape_to_axis from tinygrad.shape import ShapeTracker @@ -44,8 +44,28 @@ class DeviceBuffer: @classmethod def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented") -# extend this if you don't have an exec_ast function +# this is a quick "buffer" class for flop tracking +class GenericShape(NamedTuple): + shape : Tuple[int, ...] 
+  flops : int = 0
+shape_fxn_for_op : Dict[Op, Callable] = {
+  **{op:lambda self: GenericShape(self.shape, self.flops + prod(self.shape)) for op in UnaryOps},
+  **{op:lambda self,y: GenericShape(self.shape, self.flops + y.flops + prod(self.shape)) for op in BinaryOps},
+  **{op:lambda self,new_shape: GenericShape(new_shape, self.flops + prod(self.shape)) for op in ReduceOps},
+  **{op:(lambda mop: (lambda self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops)))(op) for op in MovementOps},
+  # https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
+  **{op:lambda self,w,C: GenericShape(C.out_shape, 2 * (C.bs * C.cout * C.oy * C.ox) * (C.cin * C.H * C.W)) for op in ProcessingOps}}
+
+# used in CPUBuffer and TorchBuffer
 class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
+  fxn_for_op : ClassVar = shape_fxn_for_op
+  # TODO: use generic types here to remove __init__ in specialized classes
+  def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
+  def contiguous(self): return self.unary_op(UnaryOps.NOOP)
+  def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
+  def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))
+  def reduce_op(self, op, new_shape): return type(self)(self.fxn_for_op[op](self.buf, new_shape))
+  def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
   @classmethod
   def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
     srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
@@ -64,17 +84,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
     else: raise TypeError("unknown op")
     return ret
-
-# used in CPUBuffer and TorchBuffer
-class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
-  fxn_for_op : ClassVar
-  # TODO: use generic types here to remove __init__ in specialized classes
-  def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
-  def contiguous(self): return self.unary_op(UnaryOps.NOOP)
-  def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
-  def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))
-  def reduce_op(self, op, new_shape): return type(self)(self.fxn_for_op[op](self.buf, new_shape))
-  def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
+def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
 
 base_fxn_for_op = {
   UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
@@ -93,16 +103,6 @@ class GlobalCounters:
   @staticmethod
   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None
 
-class GenericShape(GenericExecAST): # pylint: disable=abstract-method
-  def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
-  def unary_op(self, op:UnaryOps): return GenericShape(self.shape, self.flops + prod(self.shape))
-  def binary_op(self, op:BinaryOps, y): return GenericShape(self.shape, self.flops + y.flops + prod(self.shape))
-  def reduce_op(self, op:ReduceOps, new_shape:Tuple[int, ...]): return GenericShape(new_shape, self.flops + prod(self.shape))
-  def movement_op(self, op:MovementOps, arg): return GenericShape(ShapeTracker(self.shape).movement_op(op, arg).shape, self.flops)
-  # https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
-  def processing_op(self, op:ProcessingOps, w, C): return GenericShape(C.out_shape, 2 * (C.bs * C.cout * C.oy * C.ox) * (C.cin * C.H * C.W))
-def get_lazyop_info(ast:LazyOp): return GenericShape.exec_ast(ast, lambda x: GenericShape(x.shape))
-
 # assumes you are using ShapeTracker
 # used in GPUBuffer and LLVMBuffer
 class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method
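
A note on the `(lambda d: d.update(base_fxn_for_op) or d)({...})` one-liner that builds specialized_fxn_for_op in both ops_cpu.py and ops_torch.py: dict.update returns None, so the `or d` makes the whole expression evaluate to the merged dict. A minimal standalone sketch, with string keys standing in for the real Op enums:

  import numpy as np

  base_fxn = {"NEG": lambda x: -x, "ADD": lambda x, y: x + y}
  specialized = (lambda d: d.update(base_fxn) or d)({
    "RELU": lambda x: np.maximum(x, 0),
    "EXP": lambda x: np.exp(x),
  })

  # dict.update() returns None, so `d.update(...) or d` mutates d in place
  # and then evaluates to d: build, merge, and name the table in one expression.
  assert set(specialized) == {"NEG", "ADD", "RELU", "EXP"}
  # Subtlety: update() copies base_fxn INTO the specialized dict, so if a key
  # appeared in both, the base entry would win over the specialized one.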
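
The heart of the refactor: flop tracking moves from a GenericExecAST subclass with one method per op family into a NamedTuple plus the shape_fxn_for_op table of lambdas. A self-contained sketch of the counting scheme, with free functions standing in for the BinaryOps and ReduceOps table entries:

  from typing import NamedTuple, Tuple
  import functools, operator

  def prod(xs): return functools.reduce(operator.mul, xs, 1)  # stand-in for tinygrad.helpers.prod

  class GenericShape(NamedTuple):
    shape : Tuple[int, ...]
    flops : int = 0

  # An elementwise binary op costs one flop per output element and carries
  # both inputs' running totals forward.
  def binary_op(x: GenericShape, y: GenericShape) -> GenericShape:
    return GenericShape(x.shape, x.flops + y.flops + prod(x.shape))

  # A reduce is charged one flop per INPUT element and yields the new shape.
  def reduce_op(x: GenericShape, new_shape: Tuple[int, ...]) -> GenericShape:
    return GenericShape(new_shape, x.flops + prod(x.shape))

  a, b = GenericShape((4, 4)), GenericShape((4, 4))
  c = binary_op(a, b)       # 4*4 = 16 flops for the elementwise op
  d = reduce_op(c, (4, 1))  # +16 more: the reduce reads every input element
  assert d == GenericShape((4, 1), 32)

Because GenericShape quacks like a buffer (it has a .shape), get_lazyop_info can reuse the generic GenericExecAST.exec_ast walker to tally shapes and flops over a LazyOp tree without computing anything.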
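
One subtle detail in shape_fxn_for_op: the MovementOps entry is the only one wrapped in an extra immediately-invoked lambda, `(lambda mop: ...)(op)`, because it is the only lambda whose body refers to the comprehension variable `op`, and Python closures capture variables, not values. A sketch of the bug the wrapper avoids:

  from enum import Enum, auto

  class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto()

  # BROKEN: every lambda closes over the single comprehension variable `op`,
  # so after the comprehension finishes they all see the final member.
  late = {op: (lambda: op.name) for op in MovementOps}
  assert late[MovementOps.RESHAPE]() == "EXPAND"

  # FIXED, as in the patch: an immediately-invoked outer lambda binds the
  # current value to a fresh parameter `mop` on every iteration.
  bound = {op: (lambda mop: (lambda: mop.name))(op) for op in MovementOps}
  assert bound[MovementOps.RESHAPE]() == "RESHAPE"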
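
Also worth noting is the dispatch in the new GenericExecAST.movement_op: an op missing from fxn_for_op falls back to a method of the same lowercased name on the wrapped buffer, which is how, e.g., TorchBuffer can lean on torch.Tensor methods like reshape and permute. A rough sketch of that fallback path against numpy, with a hypothetical free function movement_op standing in for the method:

  import numpy as np
  from enum import Enum, auto

  class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto()

  fxn_for_op = {}  # left empty here to force the getattr fallback
  def movement_op(buf, op, arg=None):
    if op in fxn_for_op: return fxn_for_op[op](buf, arg)
    # MovementOps.RESHAPE -> buf.reshape(arg), etc.
    return getattr(buf, op.name.lower())(arg)

  x = np.zeros((2, 3))
  assert movement_op(x, MovementOps.RESHAPE, (3, 2)).shape == (3, 2)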