mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
Use ShapeTracker for tracking shapes in kernels (#485)
* local is a normal buffer
* remove extra shapes and strides
* fix opt
* fix llvm
This commit is contained in:
@@ -192,13 +192,10 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
|
||||
return k.ret
|
||||
|
||||
# cache miss, we have to process the kernel
|
||||
k.process()
|
||||
|
||||
if DEBUG >= 2:
|
||||
print(k.ast)
|
||||
print("old:", k.shapes)
|
||||
print("old:", k.strides)
|
||||
print("old:", [x.shape for x in k.sts])
|
||||
print("old:", [x.views[-1].strides for x in k.sts])
|
||||
|
||||
# this stuff can't be hand coded
|
||||
kernel_output_axis : List[int] = []
|
||||
@@ -242,12 +239,12 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
"""
|
||||
|
||||
# the 4x4 need to go all the way at the end, even after reduce
|
||||
output_shape = k.shapes[0]
|
||||
full_shape = [x for x in k.shapes if x != output_shape]
|
||||
full_shape = output_shape if len(full_shape) == 0 else full_shape[0]
|
||||
output_shape = k.sts[0].shape
|
||||
full_shape_options = [x.shape for x in k.sts if x.shape != output_shape]
|
||||
full_shape = output_shape if len(full_shape_options) == 0 else full_shape_options[0]
|
||||
|
||||
full_shape = full_shape if not kernel_output_axis else full_shape[:-len(kernel_output_axis)]
|
||||
kernel_output_dim = prod([k.shapes[0][a] for a in kernel_output_axis])
|
||||
kernel_output_dim = prod([k.sts[0].shape[a] for a in kernel_output_axis])
|
||||
kernel_output_type = ir.FloatType() if kernel_output_dim == 1 else ir.VectorType(ir.FloatType(), kernel_output_dim)
|
||||
|
||||
def get_idxs(builder, idx, buf_index):
|
||||
@@ -279,13 +276,13 @@ class LLVMBuffer(ExplicitExecAST):
|
||||
loop_exit = loop_exit[::-1]
|
||||
|
||||
# add the buffer indexing
|
||||
idx_level = [[int_const(o)] for o in k.offsets]
|
||||
idx_level = [[int_const(st.offset)] for st in k.sts]
|
||||
for i in range(len(full_shape)):
|
||||
for j in range(len(k.bufs)):
|
||||
# stride
|
||||
si = loop_entry[i+1].phi(ir.IntType(64), name=f"idx_{j}_{i}")
|
||||
si.add_incoming(idx_level[j][-1], loop_entry[i]._block)
|
||||
si_ps = loop_exit[i+1].add(si, int_const(k.strides[j][i]))
|
||||
si_ps = loop_exit[i+1].add(si, int_const(k.sts[j].views[-1].strides[i]))
|
||||
si.add_incoming(si_ps, loop_exit[i+1]._block)
|
||||
idx_level[j].append(si)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user