Use ShapeTracker for tracking shapes in kernels (#485)

* local is a normal buffer

* remove extra shapes and strides

* fix opt

* fix llvm
This commit is contained in:
George Hotz
2023-01-28 11:56:32 -08:00
committed by GitHub
parent 259c48f235
commit b3e4e678e8
4 changed files with 73 additions and 87 deletions

View File

@@ -192,13 +192,10 @@ class LLVMBuffer(ExplicitExecAST):
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
return k.ret
# cache miss, we have to process the kernel
k.process()
if DEBUG >= 2:
print(k.ast)
print("old:", k.shapes)
print("old:", k.strides)
print("old:", [x.shape for x in k.sts])
print("old:", [x.views[-1].strides for x in k.sts])
# this stuff can't be hand coded
kernel_output_axis : List[int] = []
@@ -242,12 +239,12 @@ class LLVMBuffer(ExplicitExecAST):
"""
# the 4x4 need to go all the way at the end, even after reduce
output_shape = k.shapes[0]
full_shape = [x for x in k.shapes if x != output_shape]
full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
output_shape = k.sts[0].shape
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
kernel_output_dim = prod([k.shapes[0][a] for a in kernel_output_axis])
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
def get_idxs(builder, idx, buf_index):
@@ -279,13 +276,13 @@ class LLVMBuffer(ExplicitExecAST):
loop_exit = loop_exit[::-1]
# add the buffer indexing
idx_level = [[int_const(o)] for o in k.offsets]
idx_level = [[int_const(st.offset)] for st in k.sts]
for i in range(len(full_shape)):
for j in range(len(k.bufs)):
# stride
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
si.add_incoming(si_ps, loop_exit[i+1]._block)
idx_level[j].append(si)