From b2d5df60494997a8ca0ef9b84a836eb4f0cfd892 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 22 Jun 2022 07:54:52 -0700 Subject: [PATCH] 3 convs are being recomputed --- accel/lazy/ops_lazy.py | 10 +++++++++- accel/opencl/ops_opencl.py | 5 +++++ tinygrad/ops.py | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/accel/lazy/ops_lazy.py b/accel/lazy/ops_lazy.py index c17b923980..683d15dab9 100644 --- a/accel/lazy/ops_lazy.py +++ b/accel/lazy/ops_lazy.py @@ -47,6 +47,12 @@ def cmp(buf1:LazyBuffer, buf2:LazyBuffer): expanded2.add(x2) return 0 +@functools.lru_cache(maxsize=None) +def depends(haystack:LazyBuffer, needle:LazyBuffer): + gen = get_lazybuffers(haystack.op) + if needle in gen: return True + return any(depends(x, needle) for x in gen if x.realized is None) + class LazyBuffer: def __init__(self, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp): self.st = ShapeTracker(shape) @@ -138,12 +144,14 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer: else: order = cmp(srcs[0], srcs[1]) if order == -1: + #if depends(srcs[0], srcs[1]): srcs = [srcs[0].op, srcs[1]] elif order == 1: + #elif depends(srcs[1], srcs[0]): srcs = [srcs[0], srcs[1].op] else: # all three are okay - #return Buffer(out_shape, BinaryOps, LazyOp(op, list(srcs))) + #return LazyBuffer(out_shape, BinaryOps, LazyOp(op, list(srcs))) srcs = [srcs[0].op, srcs[1]] #srcs = [srcs[0], srcs[1].op] return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, srcs)) diff --git a/accel/opencl/ops_opencl.py b/accel/opencl/ops_opencl.py index 8e8a51640d..e00cc1d3f6 100644 --- a/accel/opencl/ops_opencl.py +++ b/accel/opencl/ops_opencl.py @@ -107,6 +107,7 @@ class OpenCLBuffer(GPUBuffer): self._buf = None return self._image + seen = set() def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None): if C is None: # TODO: handle an opencl conv without the conv part @@ -116,6 +117,10 @@ class OpenCLBuffer(GPUBuffer): x,w = bufs[0][1], bufs[1][1] ewbufs = bufs[2:] + if tuple(bufs[0:2]) in OpenCLBuffer.seen: + print("WARNING: recomputing CONV with", bufs[0], bufs[1]) + OpenCLBuffer.seen.add(tuple(bufs[0:2])) + ewtypes = [] getters = [] for name, buf in ewbufs: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 3e547ed07c..d68df775a6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -24,9 +24,9 @@ if GRAPH: global_num_max = 0 def log_op(optype, op, ret, inp): - cnts[optype] += 1 if DEBUG: print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}") if GRAPH: + cnts[optype] += 1 def nm(x): global global_num_max if getattr(x, 'global_num', None) is None: