refactor GenericShape for a big line reduction

George Hotz
2023-02-08 18:01:08 -06:00
parent c656513591
commit cfd13c083b
3 changed files with 27 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
import numpy as np
from typing import ClassVar
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
UnaryOps.RELU: lambda x: np.maximum(x, 0), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x), BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32),
@@ -9,9 +9,8 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
MovementOps.STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(x.ravel().reshape(x.shape), shape=[y[0] for y in arg], strides=[y[1]*x.dtype.itemsize for y in arg])
})
-class CPUBuffer(GenericBufExecAST):
+class CPUBuffer(GenericExecAST):
fxn_for_op : ClassVar = specialized_fxn_for_op
def __init__(self, lbuf:np.ndarray): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
def fromCPU(x): return CPUBuffer(x)
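A note on the `(lambda d: d.update(base_fxn_for_op) or d)({...})` expression kept above: `dict.update` returns None, so `None or d` evaluates to the now-merged dict, which lets the shared base_fxn_for_op table be folded into each backend's specialized table in a single expression. A minimal sketch of the idiom, using toy string keys instead of tinygrad's Op enums:

```python
# Toy illustration of the one-expression dict merge used for specialized_fxn_for_op.
# `base` and `specialized` are stand-ins here, not tinygrad's real tables.
base = {"NOOP": lambda x: x, "NEG": lambda x: -x}
specialized = (lambda d: d.update(base) or d)({
  "RELU": lambda x: max(x, 0.0),  # backend-specific entry
})
# dict.update returns None, so `None or d` evaluates to the merged dict d.
# On a key collision the entries from `base` would win, since update() copies base into d.
assert set(specialized) == {"NOOP", "NEG", "RELU"}
```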

View File

@@ -1,6 +1,6 @@
import torch
from typing import ClassVar
-from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericBufExecAST, base_fxn_for_op
+from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ProcessingOps, GenericExecAST, base_fxn_for_op
from tinygrad.helpers import getenv
specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
@@ -10,9 +10,8 @@ specialized_fxn_for_op = (lambda d: d.update(base_fxn_for_op) or d)({
})
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
-class TorchBuffer(GenericBufExecAST):
+class TorchBuffer(GenericExecAST):
fxn_for_op : ClassVar = specialized_fxn_for_op
def __init__(self, lbuf:torch.Tensor): self.buf, self.shape = lbuf, tuple(lbuf.shape)
@staticmethod
def fromCPU(data): return TorchBuffer(torch.from_numpy(data).requires_grad_(False).to(device))
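With GenericBufExecAST merged into GenericExecAST (see the tinygrad.ops changes below), CPUBuffer and TorchBuffer inherit all of the per-op plumbing and only need to supply a fxn_for_op table plus the host-transfer helpers. A self-contained sketch of that plug-in shape, using toy names (ToyOp, ToyExec, ToyCPU) rather than tinygrad's classes:

```python
# The generic class dispatches every op through fxn_for_op, so a backend only
# has to fill in that table. Names here are illustrative, not tinygrad's.
from enum import Enum, auto
import numpy as np

class ToyOp(Enum): NEG = auto(); ADD = auto()

class ToyExec:
  fxn_for_op = {}                       # backends override this ClassVar
  def __init__(self, buf): self.buf, self.shape = buf, tuple(buf.shape)
  def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
  def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))

class ToyCPU(ToyExec):
  fxn_for_op = {ToyOp.NEG: lambda x: -x, ToyOp.ADD: lambda x, y: x + y}
  @staticmethod
  def fromCPU(x): return ToyCPU(x)

a, b = ToyCPU.fromCPU(np.ones((2, 2))), ToyCPU.fromCPU(np.full((2, 2), 3.0))
print(a.binary_op(ToyOp.ADD, b).unary_op(ToyOp.NEG).buf)  # every entry is -4.0
```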

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import numpy as np
from enum import Enum, auto
-from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional
+from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict
import functools, operator
from tinygrad.helpers import prod, shape_to_axis
from tinygrad.shape import ShapeTracker
@@ -44,8 +44,28 @@ class DeviceBuffer:
@classmethod
def exec_ast(cls, ast:LazyOp): raise NotImplementedError("must be implemented")
# extend this if you don't have an exec_ast function
+# this is a quick "buffer" class for flop tracking
+class GenericShape(NamedTuple):
+shape : Tuple[int, ...]
+flops : int = 0
+shape_fxn_for_op : Dict[Op, Callable] = {
+**{op:lambda self: GenericShape(self.shape, self.flops + prod(self.shape)) for op in UnaryOps},
+**{op:lambda self,y: GenericShape(self.shape, self.flops + y.flops + prod(self.shape)) for op in BinaryOps},
+**{op:lambda self,new_shape: GenericShape(new_shape, self.flops + prod(self.shape)) for op in ReduceOps},
+**{op:(lambda mop: (lambda self,arg: GenericShape(ShapeTracker(self.shape).movement_op(mop, arg).shape, self.flops)))(op) for op in MovementOps},
+# https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
+**{op:lambda self,w,C: GenericShape(C.out_shape, 2 * (C.bs * C.cout * C.oy * C.ox) * (C.cin * C.H * C.W)) for op in ProcessingOps}}
+# used in CPUBuffer and TorchBuffer
class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
+fxn_for_op : ClassVar = shape_fxn_for_op
+# TODO: use generic types here to remove __init__ in specialized classes
+def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
+def contiguous(self): return self.unary_op(UnaryOps.NOOP)
+def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
+def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))
+def reduce_op(self, op, new_shape): return type(self)(self.fxn_for_op[op](self.buf, new_shape))
+def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
@classmethod
def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
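One detail in the new shape_fxn_for_op table: the MovementOps entries wrap their lambda in `(lambda mop: ...)(op)` so each entry captures the current op by value, while the UnaryOps/BinaryOps/ReduceOps entries skip the wrapper because their lambdas never mention the loop variable. Python closures bind names late, so without the wrapper every MovementOps entry would call movement_op with the last op of the comprehension. A tiny illustration of the pitfall, with toy strings rather than tinygrad ops:

```python
# Why the MovementOps entries use (lambda mop: ...)(op): Python closures capture
# the variable, not its value, so late-bound lambdas all see the last loop value.
ops = ["RESHAPE", "PERMUTE", "EXPAND"]

late  = {op: (lambda: op) for op in ops}                 # every value closes over the same `op`
bound = {op: (lambda o: (lambda: o))(op) for op in ops}  # wrapper freezes the current value

print([f() for f in late.values()])   # ['EXPAND', 'EXPAND', 'EXPAND']
print([f() for f in bound.values()])  # ['RESHAPE', 'PERMUTE', 'EXPAND']
```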
@@ -64,17 +84,7 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
else:
raise TypeError("unknown op")
return ret
-# used in CPUBuffer and TorchBuffer
-class GenericBufExecAST(GenericExecAST): # pylint: disable=abstract-method
-fxn_for_op : ClassVar
-# TODO: use generic types here to remove __init__ in specialized classes
-def __init__(self, lbuf:Any): self.buf, self.shape = lbuf, tuple(lbuf.shape)
-def contiguous(self): return self.unary_op(UnaryOps.NOOP)
-def unary_op(self, op): return type(self)(self.fxn_for_op[op](self.buf))
-def binary_op(self, op, y): return type(self)(self.fxn_for_op[op](self.buf, y.buf))
-def reduce_op(self, op, new_shape): return type(self)(self.fxn_for_op[op](self.buf, new_shape))
-def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
+def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
base_fxn_for_op = {
UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
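The payoff of the refactor is that flop tracking becomes just another fxn_for_op table: get_lazyop_info runs the same GenericExecAST.exec_ast machinery, but with every leaf wrapped in a GenericShape, so "execution" only propagates shapes and accumulates a flop count. A self-contained sketch of that idea with a hand-rolled two-op table (the toy ops and the example expression are illustrative, not tinygrad's):

```python
# Flop counting by running the dispatch machinery over a shape-only "buffer".
# GenericShape mirrors the NamedTuple in the diff; the two-op table and the
# hand-built expression below are illustrative, not tinygrad's real tables.
from typing import NamedTuple, Tuple
from functools import reduce
import operator

def prod(t): return reduce(operator.mul, t, 1)

class GenericShape(NamedTuple):
  shape: Tuple[int, ...]
  flops: int = 0

shape_fxn = {
  "MUL": lambda x, y: GenericShape(x.shape, x.flops + y.flops + prod(x.shape)),
  "SUM": lambda x, new_shape: GenericShape(new_shape, x.flops + prod(x.shape)),
}

a, b = GenericShape((64, 128)), GenericShape((64, 128))
out = shape_fxn["SUM"](shape_fxn["MUL"](a, b), (64, 1))  # roughly (a*b).sum over the last axis
print(out.shape, out.flops)  # (64, 1) 16384 -> 8192 flops for MUL + 8192 for SUM
```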
@@ -93,16 +103,6 @@ class GlobalCounters:
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0,0,None
-class GenericShape(GenericExecAST): # pylint: disable=abstract-method
-def __init__(self, shape, flops=0): self.shape, self.flops = shape, flops
-def unary_op(self, op:UnaryOps): return GenericShape(self.shape, self.flops + prod(self.shape))
-def binary_op(self, op:BinaryOps, y): return GenericShape(self.shape, self.flops + y.flops + prod(self.shape))
-def reduce_op(self, op:ReduceOps, new_shape:Tuple[int, ...]): return GenericShape(new_shape, self.flops + prod(self.shape))
-def movement_op(self, op:MovementOps, arg): return GenericShape(ShapeTracker(self.shape).movement_op(op, arg).shape, self.flops)
-# https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
-def processing_op(self, op:ProcessingOps, w, C): return GenericShape(C.out_shape, 2 * (C.bs * C.cout * C.oy * C.ox) * (C.cin * C.H * C.W))
-def get_lazyop_info(ast:LazyOp): return GenericShape.exec_ast(ast, lambda x: GenericShape(x.shape))
# assumes you are using ShapeTracker
# used in GPUBuffer and LLVMBuffer
class ExplicitExecAST(DeviceBuffer): # pylint: disable=abstract-method