mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
cleanups
@@ -3,7 +3,7 @@ import os, functools
 import numpy as np
 import pyopencl as cl # type: ignore
 from collections import defaultdict
-from typing import List, Tuple, Optional, Dict
+from typing import List, Tuple, Optional, Dict, Union
 from tinygrad.helpers import prod, ConvArgs
 from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
@@ -66,8 +66,8 @@ class GPUBuffer:
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)", BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
   }

-  def __init__(self, shape, hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None):
-    self.st = ShapeTracker(shape)
+  def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], hostbuf:Optional[GPUBuffer]=None, backing:Optional[np.ndarray]=None):
+    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
     self.shape = self.st.shape
     self._buf : Optional[CLBuffer] = hostbuf._buf if hostbuf is not None else None
     self._base_shape : Tuple[int, ...] = hostbuf._base_shape if hostbuf is not None else self.shape
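For readers skimming the hunk: the new signature lets callers hand GPUBuffer either a ready-made ShapeTracker or a plain shape tuple. A minimal standalone sketch of that dispatch pattern (Tracker and Buf are made-up stand-ins, not tinygrad classes):

# sketch only: shows the Union dispatch used by the patched __init__
from typing import Tuple, Union

class Tracker:
  # stand-in for ShapeTracker: just records a shape
  def __init__(self, shape: Tuple[int, ...]):
    self.shape = tuple(shape)

class Buf:
  # accept either an existing tracker (reused as-is) or a raw shape tuple
  def __init__(self, shape: Union[Tracker, Tuple[int, ...]]):
    self.st = shape if isinstance(shape, Tracker) else Tracker(tuple(shape))
    self.shape = self.st.shape

print(Buf((2, 3)).shape)           # built from a tuple -> (2, 3)
print(Buf(Tracker((4, 5))).shape)  # built from an existing tracker -> (4, 5)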
@@ -84,7 +84,7 @@ class GPUBuffer:
     return self._buf.cl

   def __repr__(self): return f"<GPUBuffer with shape {self.shape!r}>"
-  def shapeTrackerView(self, st:ShapeTracker): return GPUBuffer(st, hostbuf=self)
+  def shapeTrackerView(self, st:ShapeTracker): return GPUBuffer(ShapeTracker(st), hostbuf=self)

   @staticmethod
   def fromCPU(x): return GPUBuffer(x.shape, backing=x.view(np.ndarray).astype(np.float32).ravel())
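The extra ShapeTracker(st) wrap reads like a defensive copy (my reading, the commit doesn't say): the view buffer gets its own tracker, so later movement ops on it can't mutate the tracker the caller passed in. A toy illustration of that aliasing hazard with a stand-in class:

class Tracker:
  # copy-constructor style: accept another Tracker or a shape tuple
  def __init__(self, src):
    self.shape = tuple(src.shape) if isinstance(src, Tracker) else tuple(src)
  # mutates in place, the way movement ops rewrite a tracker
  def reshape(self, new_shape):
    self.shape = tuple(new_shape)
    return self

shared = Tracker((2, 6))
view_st = Tracker(shared)   # independent copy, as in GPUBuffer(ShapeTracker(st), hostbuf=self)
view_st.reshape((3, 4))
print(shared.shape, view_st.shape)  # (2, 6) (3, 4) -- the caller's tracker is untouched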
@@ -106,20 +106,14 @@ class GPUBuffer:
   def unary_op(x, op:UnaryOps): return type(x)(x.shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op])
   def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op])
   def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)

-  def movement_op(x, op:MovementOps, arg) -> GPUBuffer:
-    ret = type(x)(x.st, x)
-    ret.shape = ret.st.movement_op(op, arg).shape
-    return ret
+  def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)

   def processing_op(x, op:ProcessingOps, w:GPUBuffer, C:ConvArgs):
     assert op == ProcessingOps.CONV, f"{op} isn't supported"
     return type(x)(C.out_shape)._processing_op([("input", x.contiguous_op()), ("weight", w.contiguous_op())], "acc", C)

   def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]):
-    if op == ReduceOps.SUM: code, start = "acc + A", "0.0"
-    elif op == ReduceOps.MAX: code, start = "max(A, acc)", "-INFINITY"
-    return type(x)(new_shape)._processing_op([("A", x)], code, None, start)
+    return type(x)(new_shape)._processing_op([("A", x)], {ReduceOps.SUM: "acc + A", ReduceOps.MAX: "max(A, acc)"}[op], None, {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}[op])

   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
     ints, params, ewbufs, conv_src = '', [], bufs, ''
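The reduce_op body collapses the if/elif into two table lookups keyed by the op. A rough standalone version of the same pattern, with Python callables and seed values standing in for the OpenCL code/start strings (all names here are illustrative, not tinygrad's):

from enum import Enum, auto
from functools import reduce

class ReduceOps(Enum):
  SUM = auto()
  MAX = auto()

# analogous to the {op: code} and {op: start} dicts in the patch
CODE  = {ReduceOps.SUM: lambda acc, a: acc + a, ReduceOps.MAX: lambda acc, a: max(acc, a)}
START = {ReduceOps.SUM: 0.0, ReduceOps.MAX: float("-inf")}

def reduce_op(op, xs):
  # fold xs with the op's combiner, starting from its identity value
  return reduce(CODE[op], xs, START[op])

print(reduce_op(ReduceOps.SUM, [1.0, 2.0, 3.0]))  # 6.0
print(reduce_op(ReduceOps.MAX, [1.0, 5.0, 2.0]))  # 5.0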
@@ -132,13 +126,14 @@ class GPUBuffer:
       ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "sy", "sx", "dx", "dy", "px", "py", "groups", "rcout", "cin"])
       params = [(f"int {x}", getattr(C, x)) for x in ["oy", "ox", "iy", "ix"]]
       global_size = [C.bs*C.cout, C.oy, C.ox]
-      assert ret.shape == C.out_shape, "output shape is wrong (can't reduce and conv together)"
+      assert ret.shape == C.out_shape, "output shape is wrong (NOTE: you can't reduce and conv together)"

       # now input and weight can be anywhere in bufs
       bufs = [(x[0], x[1].contiguous_op()) if x[0] in ["input", "weight"] else x for x in bufs]
       ewbufs = [x for x in bufs if x[0] not in ["input", "weight"]]
       assert len(bufs) == len(ewbufs)+2, "input or weight missing"

       # TODO: is there a way to unify this with reduce? it looks very similar
       conv_src = """
       int B = gid/(groups*rcout); int g = (gid/rcout)%groups; int c = gid % rcout;
       int Y = get_global_id(1); int X = get_global_id(2); gid = gid*oy*ox + Y*ox + X; idx = gid;
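The bufs handling here no longer assumes input and weight come first: just those two are forced contiguous and peeled off by name, while everything else stays an element-wise buffer. The same filtering, shown as a sketch with plain strings instead of GPU buffers:

# (name, buffer) pairs in arbitrary order; strings stand in for GPUBuffers
bufs = [("A", "ew_buf"), ("weight", "w_buf"), ("input", "x_buf")]

# conv operands get special handling; the rest are element-wise inputs
ewbufs = [x for x in bufs if x[0] not in ["input", "weight"]]
assert len(bufs) == len(ewbufs) + 2, "input or weight missing"
print([name for name, _ in ewbufs])  # ['A']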
@@ -168,7 +163,7 @@ class GPUBuffer:
     __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
       float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')}; {conv_src}
       {''.join([ls for ls, _ in loop[::-1]])}
-      {''.join([f'float {name} = get_{name}({name}_g, idx);' if views[name][1] else f'float {name} = get_{name}(idx);' for name, _ in ewbufs])}
+      {''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
       acc = {code};
       {''.join([le for _, le in loop])}
       output[gid] = acc;
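The changed load line is string codegen: per-buffer OpenCL loads are built by joining f-strings, picking the two-argument getter when a buffer needs its __global pointer passed in. A standalone snippet of the same join, with fake names and a fake flag dict in place of views[name][1]:

# fake element-wise buffers; needs_g mimics the "pass the _g pointer" flag
ewbufs = [("A", None), ("B", None)]
needs_g = {"A": True, "B": False}

loads = ''.join([f'float {name} = ' + (f'get_{name}({name}_g, idx);' if needs_g[name] else f'get_{name}(idx);')
                 for name, _ in ewbufs])
print(loads)  # float A = get_A(A_g, idx);float B = get_B(idx);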