more from indexer

This commit is contained in:
George Hotz
2023-01-18 18:11:51 -08:00
parent 9245f4650a
commit 3a3400e3a2
2 changed files with 4 additions and 2 deletions

View File

@@ -317,9 +317,9 @@ class OpenCLBuffer(GPUBuffer):
{chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
output[gid] = {code};
}}
}}""")
}}""", op_estimate=op_estimate)
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl)
return ret
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0):

View File

@@ -7,6 +7,8 @@ def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
def modn(x, a): return -((-x)%a) if x < 0 else x%a
def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
def shape_to_axis(old_shape, new_shape):