mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
more from indexer
This commit is contained in:
@@ -317,9 +317,9 @@ class OpenCLBuffer(GPUBuffer):
|
||||
{chr(10).join([f' float {name} = ' + late_views[name][2] for name in late_views])}
|
||||
output[gid] = {code};
|
||||
}}
|
||||
}}""")
|
||||
}}""", op_estimate=op_estimate)
|
||||
|
||||
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl, op_estimate=op_estimate)
|
||||
conv_prg([prod(ret.shape), inter_red, 1], [1, inter_red, 1] if inter_red > 1 else None, ret.cl, *buf_cl)
|
||||
return ret
|
||||
|
||||
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None, op=ReduceOps.SUM, reduce_shape=None, earlybufs:Set[str]=set(), earlycode:str="acc", op_estimate=0):
|
||||
|
||||
@@ -7,6 +7,8 @@ def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0
|
||||
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
|
||||
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
|
||||
def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line
|
||||
def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if not fxn(x)]
|
||||
def modn(x, a): return -((-x)%a) if x < 0 else x%a
|
||||
|
||||
def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
|
||||
def shape_to_axis(old_shape, new_shape):
|
||||
|
||||
Reference in New Issue
Block a user