diff --git a/.gitignore b/.gitignore
index dfdbbb4749..7f5d7c5362 100644
--- a/.gitignore
+++ b/.gitignore
@@ -49,3 +49,4 @@ model.safetensors
 quickstart.py
 .hypothesis
 weights
+*.lprof
diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py
index 92fd93f9cd..31647b3332 100644
--- a/test/external/external_test_speed_llama.py
+++ b/test/external/external_test_speed_llama.py
@@ -54,4 +54,5 @@ class TestLLaMASpeed(unittest.TestCase):
     Device[Device.DEFAULT].compiler = backup_compiler
 
 if __name__ == '__main__':
-  unittest.main()
+  TestLLaMASpeed().test_llama_compile()
+  #unittest.main()
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 54f69583ce..84d0707463 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -155,37 +155,38 @@ class LazyBuffer:
 
 # recursively create a lazyop
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], first=True, cache=None) -> LazyOp:
-  if cache is None: cache = {}
+                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
-    var_vals.update(merge_dicts([var_vals, buf.st.var_vals]))
-    st = buf.st.unbind()+st
+    st = buf.st + st
     buf = buf.base
   # all buffers here are base now
   assert buf.op is not None
 
   # consts are always fused and generated
   if buf.op == LoadOps.CONST:
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify()))
+    # TODO: make shapetracker unbind also return var_vals
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
 
   # if we aren't fusing it, it's a load and we add it to the inputs
   if buf.realized or (buf in realizes and not first):
     if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify()))
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
 
   # if a CONTIGUOUS made it all the way here, just skip it
   if buf.op == LoadOps.CONTIGUOUS:
     assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, False, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
     assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape).unbind()
+    st = ShapeTracker.from_shape(buf.srcs[0].shape)
 
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, False, cache) for x in buf.srcs), buf.arg)
+  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
   return ret
 
 # recursively walk back in the graph to create the schedule
@@ -204,12 +205,11 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
   elif out.op == LoadOps.EMPTY: op = LazyOp(LoadOps.EMPTY)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape).unbind()
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify()))
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
 
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \
-    [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],