2 stage reduce

This commit is contained in:
George Hotz
2022-08-20 12:20:32 -07:00
parent b020fc2225
commit c5433db050
2 changed files with 18 additions and 10 deletions

View File

@@ -112,9 +112,9 @@ class GPUBuffer:
def binary_op(x, op:BinaryOps, y:GPUBuffer): return type(x)(x.shape)._processing_op([("A", x), ("B", y)], GPUBuffer.code_for_op[op])
def contiguous_op(x): return x if x.st.contiguous else x.unary_op(UnaryOps.NOOP)
def movement_op(x, op:MovementOps, arg) -> GPUBuffer: return type(x)(ShapeTracker(x.st).movement_op(op, arg), x)
def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), start=GPUBuffer.start_for_op[op])
def reduce_op(x, op:ReduceOps, new_shape:Tuple[int, ...]): return type(x)(new_shape)._processing_op([("A", x)], code="acc", earlycode=GPUBuffer.code_for_op[op], earlybufs=set("A"), op=op)
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, start="0.0", reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
def _processing_op(ret, bufs: List[Tuple[str, GPUBuffer]]=[], code:str="acc", C:Optional[ConvArgs]=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc") -> GPUBuffer:
assert C is None
# get the input/output shape and the reduce amount
@@ -123,19 +123,27 @@ class GPUBuffer:
# if it's a partial reduce, assert last non reduced axis is before the first reduced axis
if red > 1 and prod(ret.shape) != 1: assert max([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s == n and n != 1]) < min([i for i,(s,n) in enumerate(zip(*reduce_shape)) if s != 1 and n == 1])
inter_red = 256 if (prod(ret.shape) < 256 and red >= 256) else 1
kernel_name = "reduce" if red > 1 else "elementwise"
views = {name:buf.contiguous_view_constant_fold(name) for name, buf in bufs}
buf_types = [f"__global const float *{name}_g" for name, _ in bufs if name not in views or views[name][1]]
conv_prg = CLProgram(kernel_name, f"""{chr(10).join([x[0] for x in views.values()])}
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types)}) {{
float acc = {start}; int gid = get_global_id(0);
for (int idx = gid * {red}; idx < gid * {red} + {red}; idx++) {{
__kernel void {kernel_name}({','.join(["__global float* restrict output"] + buf_types + (["__local float *temp"] if inter_red > 1 else []))}) {{
float acc = {GPUBuffer.start_for_op[op]}; int gid = get_global_id(0); int mid = get_global_id(1);
for (int idx = gid * {red} + {red//inter_red + 1} * mid; idx < gid * {red} + min({red}, {red//inter_red + 1} * (mid+1)); idx++) {{
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, idx);' if views[name][1] else f'get_{name}(idx);') for name, _ in bufs if name in earlybufs])}
acc = {earlycode};
}}"""+(f"""
temp[mid] = acc; barrier(CLK_LOCAL_MEM_FENCE);
if (mid == 0) {{ acc = {GPUBuffer.start_for_op[op]};
for (int rdx = 0; rdx < {inter_red}; rdx++) {{ acc = {GPUBuffer.code_for_op[op].replace('A', 'temp[rdx]')}; }}
""" if inter_red != 1 else "{")+f"""
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
output[gid] = {code};
}}
{chr(10).join([f' float {name} = ' + (f'get_{name}({name}_g, gid);' if views[name][1] else f'get_{name}(gid);') for name, _ in bufs if name not in earlybufs])}
output[gid] = {code};
}}""", argdtypes=tuple(None if i < 1+len(buf_types) else np.int32 for i in range(1+len(buf_types))))
conv_prg([prod(ret.shape), 1, 1], None, ret.cl, *[buf.cl for name, buf in bufs if name not in views or views[name][1]], op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
}}""")
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl,
*([buf.cl for name, buf in bufs if name not in views or views[name][1]] + ([cl.LocalMemory(inter_red*4)] if inter_red > 1 else [])),
op_estimate=prod(reduce_shape[0])*len(earlybufs) + prod(reduce_shape[1])*len(bufs))
return ret

View File

@@ -123,7 +123,7 @@ def _realize_reduceops(self:LazyBuffer) -> Tuple[DeviceBuffer, List[DeviceBuffer
buf_names : Dict[LazyBuffer, str] = {x:f"arg_{i}" for i,x in enumerate(real_srcs.keys())}
return self.dbuffer(self.shape)._processing_op([(buf_names[lb], db) for lb,db in real_srcs.items()],
earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), start=self.dbuffer.start_for_op[self.op.op]), \
earlycode=_ast(LazyOp(self.op.op, (src.op,), self.op.arg), buf_names, self.dbuffer.code_for_op), earlybufs=buf_names.values(), op=self.op.op), \
list(real_srcs.values()), ReduceOps
else:
real_src = src.realize(self.device)