mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-23 13:58:00 -05:00

2 stage reduce
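Previously each output element was reduced by a single work-item that walked all {red} inputs serially. This change adds a second launch dimension of inter_red = 256 work-items when the output is small (prod(ret.shape) < 256) and the reduction is large (red >= 256): in stage 1 every work-item folds its own slice of the inputs into a private accumulator, and in stage 2 work-item 0 combines the per-item partials through local memory after a barrier. A minimal pure-Python model of the partitioning (hypothetical helper, for illustration only; the real code below generates an OpenCL kernel instead):

def two_stage_reduce(values, inter_red=256, combine=lambda a, b: a + b, start=0.0):
  red = len(values)
  chunk = red // inter_red + 1  # slice per stage-1 work-item, matching red//inter_red + 1 in the kernel
  # stage 1: work-item "mid" reduces values[chunk*mid : min(red, chunk*(mid+1))]
  partials = [start] * inter_red
  for mid in range(inter_red):
    for idx in range(chunk * mid, min(red, chunk * (mid + 1))):
      partials[mid] = combine(partials[mid], values[idx])
  # stage 2: what "if (mid == 0)" does on the GPU, via temp[] and a barrier
  acc = start
  for p in partials:
    acc = combine(acc, p)
  return acc

assert two_stage_reduce(list(range(1000))) == sum(range(1000))
assert two_stage_reduce([3.0, 1.0, 2.0], combine=max, start=float("-inf")) == 3.0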
@@ -112,9 +112,9 @@ class GPUBuffer:
   def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op])
   def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
   def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
-  def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), start=GPUBuffer.start_for_op[op])
+  def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), op=op)
 
-  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
+  def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
     assert C is None
 
     # get the input/output shape and the reduce amount
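Passing the ReduceOps enum through instead of a pre-rendered start string matters because the generator now needs the op twice: once for the identity value that seeds both the stage-1 and stage-2 accumulators, and once to rebuild the combine expression for stage 2 by rewriting A to temp[rdx]. The lookup tables involved have roughly this shape (illustrative, not copied verbatim from ops_gpu.py):

from tinygrad.ops import ReduceOps

start_for_op = {ReduceOps.SUM: "0.0", ReduceOps.MAX: "-INFINITY"}
code_for_op  = {ReduceOps.SUM: "(acc + A)", ReduceOps.MAX: "max(A, acc)"}
# stage 2 reuses the same combine with the partials buffer substituted in:
#   code_for_op[op].replace('A', 'temp[rdx]')  ->  "(acc + temp[rdx])" for SUM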
@@ -123,19 +123,27 @@ class GPUBuffer:
 
     # if it's a partial reduce, assert last non reduced axis is before the first reduced axis
     if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
+    inter_red = 256 if (prod(ret.shape) < 256 and red >= 256) else 1
 
     kernel_name = "reduce" if red > 1 else "elementwise"
     views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
     buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
     conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
-__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
-  float acc = {start}; int gid = get_global_id(0);
-  for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
+__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
+  float acc = {GPUBuffer.start_for_op[op]}; int gid = get_global_id(0); int mid = get_global_id(1);
+  for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
 {chr(10).join([f'    float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])}
     acc = {earlycode};
-  }}
-{chr(10).join([f'  float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
-  output[gid] = {code};
-}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
-    conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
+  }}"""+(f"""
+  temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
+  if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
+    for (int rdx = 0; rdx < {inter_red}; rdx++) {{ acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')}; }}
+""" if inter_red != 1 else "{")+f"""
+{chr(10).join([f'  float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
+  output[gid] = {code};
+  }}
+}}""")
+    conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl,
+      *([buf.cl for name, buf in bufs if name not in views or views[name][1]] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])),
+      op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
     return ret
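For concreteness: a full sum of 1024 elements into a single output takes the two-stage path, so inter_red = 256, each stage-1 slice covers 1024//256 + 1 = 5 indices, and the launch is global [1, 256, 1] / local [1, 256, 1] with a 256*4-byte local buffer. Hand-expanding the f-string above for that case (assuming an emitted get_A loader and the illustrative op tables; a sketch, not captured compiler output):

__kernel void reduce(__global float* restrict output, __global const float *A_g, __local float *temp) {
  float acc = 0.0; int gid = get_global_id(0); int mid = get_global_id(1);
  // stage 1: each of the 256 work-items reduces its own 5-element slice
  for (int idx = gid * 1024 + 5 * mid; idx < gid * 1024 + min(1024, 5 * (mid+1)); idx++) {
    float A = get_A(A_g, idx);
    acc = (acc + A);
  }
  // stage 2: publish partials, then work-item 0 folds them
  temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
  if (mid == 0) { acc = 0.0;
    for (int rdx = 0; rdx < 256; rdx++) { acc = (acc + temp[rdx]); }
    output[gid] = acc;
  }
}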
@@ -123,7 +123,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
     buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
 
     return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
-      earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), start=self.dbuffer.start_for_op[self.op.op]), \
+      earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), op=self.op.op), \
       list(real_srcs.values()), ReduceOps
   else:
     real_src = src.realize(self.device)
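The lazy call site gets the same start-to-op substitution. Note that earlybufs/earlycode fuse the elementwise source expression straight into the reduce loop; for a fused (x*y).sum() the arguments would look roughly like this (hypothetical rendering of _ast, for illustration only):

earlybufs = {"arg_0", "arg_1"}           # both inputs are read inside the reduce loop
earlycode = "(acc + (arg_0 * arg_1))"    # ReduceOps.SUM wrapped around the BinaryOps.MUL source
op        = ReduceOps.SUM                # now also selects the start value and the stage-2 combine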