diff --git a/test/test_net_speed.py b/test/test_net_speed.py
index eaf1f90a80..1013d4959d 100644
--- a/test/test_net_speed.py
+++ b/test/test_net_speed.py
@@ -4,7 +4,7 @@ import cProfile
 import pstats
 import unittest
 import torch
-from tinygrad.tensor import Tensor
+from tinygrad.tensor import Tensor, Device
 
 def start_profile():
   import time
@@ -19,6 +19,7 @@ def stop_profile(pr, sort='cumtime'):
   ps.sort_stats(sort)
   ps.print_stats(0.2)
 
+@unittest.skipUnless(getattr(Device, "OPENCL", None) is None or Device.DEFAULT != Device.OPENCL, "OOM on OpenCL")
 class TestConvSpeed(unittest.TestCase):
 
   def test_mnist(self):
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 0645447ce8..1d0039dac1 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -115,13 +115,12 @@ class GPUBuffer:
     return type(x)(new_shape)._processing_op([("A", x)], GPUBuffer.code_for_op[op], None, GPUBuffer.start_for_op[op])
 
   def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0") -> GPUBuffer:
-    ints, params, ewbufs, conv_src = '', [], bufs, ''
-    global_size = [prod(ret.shape), 1, 1]
-    loop : List[Tuple[str, str]] = []
+    assert C is None
+
     # this takes a ret index to an inp index, indexing 0 on the reduced strides
     # if it's not a reduce, this should be a NOOP
     view = View(ret.shape, strides_for_shape(bufs[0][1].shape))
-    assert C is None
+    loop : List[Tuple[str, str]] = []
     if ret.shape != bufs[0][1].shape:  # this is a reduce
       # reverse operation of expand, this validates inputs
       # generate loops with combined adjacent reduce axis
@@ -131,16 +130,16 @@ class GPUBuffer:
         acc *= shp
 
     kernel_name = "reduce" if len(loop) > 0 else "elementwise"
-    views = {name:buf.contiguous_view_constant_fold(name) for name, buf in ewbufs}
+    views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
-    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + [x[0] for x in params])}) {{ {ints}
-      float acc = {start}; int gid = get_global_id(0); {conv_src} int idx = gid; {view.expr.replace('//', '/')};
+    __kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
+      float acc = {start}; int gid = get_global_id(0); int idx = gid; {view.expr.replace('//', '/')};
       {' '.join([ls for ls, _ in loop[::-1]])}
-{chr(10).join([f'      float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in ewbufs])}
+{chr(10).join([f'      float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs])}
       acc = {code};
       {' '.join([le for _, le in loop])}
       output[gid] = acc;
-    }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types)+len(params))))
-    conv_prg(global_size, None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], *[x[1] for x in params])
+    }}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
+    conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]])
     return ret
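
Note (not part of the diff): with the conv path removed, _processing_op only ever emits elementwise or reduce kernels from the template above. The standalone Python sketch below approximates the OpenCL source that template would expand to for a (4, 3) -> (4, 1) sum reduce; the buffer name A, the view_expr string, and the loop begin/end strings are illustrative assumptions, not values taken from the repository.

# Standalone sketch (assumption, not tinygrad code): mimics the simplified
# kernel template to show roughly what a (4, 3) -> (4, 1) sum reduce compiles to.
buf_types = ["__global const float *A_g"]           # one input buffer, named "A"
start, code = "0.0", "acc + A"                      # init value and combine expression for a sum
view_expr = "idx = gid*3"                           # output row gid -> start of input row gid (stride 3)
loop = [("for (int i1 = 0; i1 < 3; i1++) {", "idx++; }")]  # walk the reduced axis

kernel_name = "reduce" if len(loop) > 0 else "elementwise"
src = f"""__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
  float acc = {start}; int gid = get_global_id(0); int idx = gid; {view_expr};
  {' '.join([ls for ls, _ in loop[::-1]])}
  float A = A_g[idx];
  acc = {code};
  {' '.join([le for _, le in loop])}
  output[gid] = acc;
}}"""
print(src)  # the real program is launched with global_size [prod(ret.shape), 1, 1], here [4, 1, 1]

For an elementwise op, loop stays empty, kernel_name becomes "elementwise", and the body collapses to the loads, acc = {code}, and the store.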