mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
3 convs are being recomputed
This commit is contained in:
@@ -47,6 +47,12 @@ def cmp(buf1:LazyBuffer, buf2:LazyBuffer):
|
||||
expanded2.add(x2)
|
||||
return 0
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def depends(haystack:LazyBuffer, needle:LazyBuffer):
|
||||
gen = get_lazybuffers(haystack.op)
|
||||
if needle in gen: return True
|
||||
return any(depends(x, needle) for x in gen if x.realized is None)
|
||||
|
||||
class LazyBuffer:
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int]], optype:Op, op:LazyOp):
|
||||
self.st = ShapeTracker(shape)
|
||||
@@ -138,12 +144,14 @@ def elementwise_op(op, srcs:Tuple[LazyBuffer]) -> LazyBuffer:
|
||||
else:
|
||||
order = cmp(srcs[0], srcs[1])
|
||||
if order == -1:
|
||||
#if depends(srcs[0], srcs[1]):
|
||||
srcs = [srcs[0].op, srcs[1]]
|
||||
elif order == 1:
|
||||
#elif depends(srcs[1], srcs[0]):
|
||||
srcs = [srcs[0], srcs[1].op]
|
||||
else:
|
||||
# all three are okay
|
||||
#return Buffer(out_shape, BinaryOps, LazyOp(op, list(srcs)))
|
||||
#return LazyBuffer(out_shape, BinaryOps, LazyOp(op, list(srcs)))
|
||||
srcs = [srcs[0].op, srcs[1]]
|
||||
#srcs = [srcs[0], srcs[1].op]
|
||||
return LazyBuffer(out_shape, ProcessingOps, LazyOp(op, srcs))
|
||||
|
||||
@@ -107,6 +107,7 @@ class OpenCLBuffer(GPUBuffer):
|
||||
self._buf = None
|
||||
return self._image
|
||||
|
||||
seen = set()
|
||||
def _processing_op(ret, bufs: List[Tuple[str, OpenCLBuffer]]=[], code:str="acc", C=None):
|
||||
if C is None:
|
||||
# TODO: handle an opencl conv without the conv part
|
||||
@@ -116,6 +117,10 @@ class OpenCLBuffer(GPUBuffer):
|
||||
x,w = bufs[0][1], bufs[1][1]
|
||||
ewbufs = bufs[2:]
|
||||
|
||||
if tuple(bufs[0:2]) in OpenCLBuffer.seen:
|
||||
print("WARNING: recomputing CONV with", bufs[0], bufs[1])
|
||||
OpenCLBuffer.seen.add(tuple(bufs[0:2]))
|
||||
|
||||
ewtypes = []
|
||||
getters = []
|
||||
for name, buf in ewbufs:
|
||||
|
||||
@@ -24,9 +24,9 @@ if GRAPH:
|
||||
|
||||
global_num_max = 0
|
||||
def log_op(optype, op, ret, inp):
|
||||
cnts[optype] += 1
|
||||
if DEBUG: print(f"{op} : {', '.join([str(x.shape) for x in inp])} -> {ret.shape}")
|
||||
if GRAPH:
|
||||
cnts[optype] += 1
|
||||
def nm(x):
|
||||
global global_num_max
|
||||
if getattr(x, 'global_num', None) is None:
|
||||
|
||||
Reference in New Issue
Block a user