diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index a26bc1bf27..a8c182e4ed 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -3,7 +3,7 @@ import os, functools import numpy as np import pyopencl as cl # type: ignore from collections import defaultdict -from typing import List, Tuple, Optional, Dict +from typing import List, Tuple, Optional, Dict, Union from tinygrad.helpers import prod, ConvArgs from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape @@ -66,8 +66,8 @@ class GPUBuffer: BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", } - def __init__(self, shape, hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None): - self.st = ShapeTracker(shape) + def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None): + self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape)) self.shape = self.st.shape self._buf : Optional[CLBuffer] = hostbuf._buf if hostbuf is not None else None self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape @@ -84,7 +84,7 @@ class GPUBuffer: return self._buf.cl def __repr__(self): return f"" - def shapeTrackerView(self, st:ShapeTracker): return GPUBuffer(st, hostbuf=self) + def shapeTrackerView(self, st:ShapeTracker): return GPUBuffer(ShapeTracker(st), hostbuf=self) @staticmethod def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel()) @@ -106,20 +106,14 @@ class GPUBuffer: def unary_op(x, op:UnaryOps): return type(x)(x.shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op]) def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op]) def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP) - - def movement_op(x, op:MovementOps, arg) -> GPUBuffer: - ret = type(x)(x.st, x) - ret.shape = ret.st.movement_op(op, arg).shape - return ret + def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x) def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs): assert op == ProcessingOps.CONV, f"{op} isn't supported" return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C) def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): - if op == ReduceOps.SUM: code, start = "acc + A", "0.0" - elif op == ReduceOps.MAX: code, start = "max(A, acc)", "-INFINITY" - return type(x)(new_shape)._processing_op([("A", x)], code, None, start) + return type(x)(new_shape)._processing_op([("A", x)], {ReduceOps.SUM: "acc + A", ReduceOps.MAX: "max(A, acc)"}[op], None, {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}[op]) def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer: ints, params, ewbufs, conv_src = '', [], bufs, '' @@ -132,13 +126,14 @@ class GPUBuffer: ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"]) params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]] global_size = [C.bs*C.cout, C.oy, C.ox] - assert ret.shape == C.out_shape, "output shape is wrong (can't reduce and conv together)" + assert ret.shape == C.out_shape, "output shape is wrong (NOTE: you can't reduce and conv together)" # now input and weight can be anywhere in bufs bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs] ewbufs = [x for x in bufs if x[0] not in ["input", "weight"]] assert len(bufs) == len(ewbufs)+2, "input or weight missing" + # TODO: is there a way to unify this with reduce? it looks very similar conv_src = """ int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout; int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X; idx = gid; @@ -168,7 +163,7 @@ class GPUBuffer: __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints} float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; {conv_src} {''.join([ls for ls, _ in loop[::-1]])} - {''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])} + {''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])} acc = {code}; {''.join([le for _, le in loop])} output[gid] = acc;