3 convs are being recomputed

This commit is contained in:
George Hotz
2022-06-22 07:54:52 -07:00
parent ba2defcdef
commit b2d5df6049
3 changed files with 15 additions and 2 deletions

View File

@@ -47,6 +47,12 @@ def cmp(buf1:LazyBuffer, buf2:LazyBuffer):
expanded2.add(x2)
return 0
@functools.lru_cache(maxsize=None)
def depends(haystack:LazyBuffer, needle:LazyBuffer):
gen = get_lazybuffers(haystack.op)
if needle in gen: return True
return any(depends(x, needle) for x in gen if x.realized is None)
class LazyBuffer:
def __init__(self, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp):
self.st = ShapeTracker(shape)
@@ -138,12 +144,14 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer:
else:
order = cmp(srcs[0], srcs[1])
if order == -1:
#if depends(srcs[0], srcs[1]):
srcs = [srcs[0].op, srcs[1]]
elif order == 1:
#elif depends(srcs[1], srcs[0]):
srcs = [srcs[0], srcs[1].op]
else:
# all three are okay
#return Buffer(out_shape, BinaryOps, LazyOp(op, list(srcs)))
#return LazyBuffer(out_shape, BinaryOps, LazyOp(op, list(srcs)))
srcs = [srcs[0].op, srcs[1]]
#srcs = [srcs[0], srcs[1].op]
return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, srcs))

View File

@@ -107,6 +107,7 @@ class OpenCLBuffer(GPUBuffer):
self._buf = None
return self._image
seen = set()
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None):
if C is None:
# TODO: handle an opencl conv without the conv part
@@ -116,6 +117,10 @@ class OpenCLBuffer(GPUBuffer):
x,w = bufs[0][1], bufs[1][1]
ewbufs = bufs[2:]
if tuple(bufs[0:2]) in OpenCLBuffer.seen:
print("WARNING: recomputing CONV with", bufs[0], bufs[1])
OpenCLBuffer.seen.add(tuple(bufs[0:2]))
ewtypes = []
getters = []
for name, buf in ewbufs:

View File

@@ -24,9 +24,9 @@ if GRAPH:
global_num_max = 0
def log_op(optype, op, ret, inp):
cnts[optype] += 1
if DEBUG: print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
if GRAPH:
cnts[optype] += 1
def nm(x):
global global_num_max
if getattr(x, 'global_num', None) is None: