Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
No extra vars call (#3054)
* remove unused reciprocal
* comment
* remove unneeded call to vars
* free speedup
.gitignore | 1 +

@@ -49,3 +49,4 @@ model.safetensors
 quickstart.py
 .hypothesis
 weights
+*.lprof
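The new ignore pattern covers *.lprof files, the per-line timing dumps written by the line_profiler package (kernprof saves them as <script>.lprof), which presumably came from profiling work behind this speedup. A minimal sketch of where such a file comes from, assuming line_profiler is installed; hot_loop is a placeholder function, not anything in tinygrad:

# line_profiler records per-line timings for wrapped functions and can dump
# them to a .lprof file, the kind of file the new .gitignore rule covers.
from line_profiler import LineProfiler

def hot_loop(n: int) -> int:
  total = 0
  for i in range(n):
    total += i * i
  return total

lp = LineProfiler()
profiled = lp(hot_loop)          # wrap the function so its lines are timed
profiled(100_000)
lp.dump_stats("hot_loop.lprof")  # matches the new *.lprof ignore pattern
lp.print_stats()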
test/external/external_test_speed_llama.py | 3 ++-

@@ -54,4 +54,5 @@ class TestLLaMASpeed(unittest.TestCase):
     Device[Device.DEFAULT].compiler = backup_compiler

 if __name__ == '__main__':
-  unittest.main()
+  TestLLaMASpeed().test_llama_compile()
+  #unittest.main()
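The __main__ block now runs the compile-speed test directly instead of going through unittest's runner, which makes the whole compile path easy to wrap in a profiler. A minimal sketch of that kind of invocation, assuming it is run from the repo root; the cProfile wrapper is an addition for illustration, not part of this commit:

# Profile the llama compile path by calling the test method directly,
# exactly as the new __main__ block does, but under cProfile.
import cProfile
from test.external.external_test_speed_llama import TestLLaMASpeed

cProfile.run("TestLLaMASpeed().test_llama_compile()", sort="cumtime")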
@@ -155,37 +155,38 @@ class LazyBuffer:
 
 # recursively create a lazyop
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], first=True, cache=None) -> LazyOp:
-  if cache is None: cache = {}
+                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
-    var_vals.update(merge_dicts([var_vals, buf.st.var_vals]))
-    st = buf.st.unbind()+st
+    st = buf.st + st
     buf = buf.base
   # all buffers here are base now
   assert buf.op is not None
 
   # consts are always fused and generated
   if buf.op == LoadOps.CONST:
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify()))
+    # TODO: make shapetracker unbind also return var_vals
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, st.simplify().unbind()))
 
   # if we aren't fusing it, it's a load and we add it to the inputs
   if buf.realized or (buf in realizes and not first):
     if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify()))
+    var_vals.update(merge_dicts([var_vals, st.var_vals]))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, st.simplify().unbind()))
 
   # if a CONTIGUOUS made it all the way here, just skip it
   if buf.op == LoadOps.CONTIGUOUS:
     assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, False, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
     assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape).unbind()
+    st = ShapeTracker.from_shape(buf.srcs[0].shape)
 
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, False, cache) for x in buf.srcs), buf.arg)
+  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
   return ret
 
 # recursively walk back in the graph to create the schedule
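Two things change in this hunk: the memo cache is now created once at the top-level call (cache={} in _recursive_schedule below) and threaded through the recursion as a required argument, instead of being defaulted to None and re-created on every call, and var_vals is only merged from the ShapeTracker at the leaves (CONST and LOAD), with unbind() applied right before the ConstBuffer/MemBuffer is built rather than eagerly on every view. A toy, self-contained sketch of the explicit-cache convention (the walk/Node names are made up, not tinygrad code):

# The caller owns the memo dict and passes it down; shared subtrees are
# computed once, and no per-call `if cache is None: cache = {}` is needed.
from typing import Dict, Tuple, Union

Node = Union[int, Tuple["Node", ...]]

def walk(node: Node, cache: Dict[int, int]) -> int:
  if id(node) in cache: return cache[id(node)]   # memo hit: shared subtree already done
  ret = node if isinstance(node, int) else sum(walk(c, cache) for c in node)
  cache[id(node)] = ret
  return ret

shared: Node = (2, 3)
print(walk((1, shared, shared), cache={}))  # 11; cache={} at the call site, like _recursive_schedule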
@@ -204,12 +205,11 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
   elif out.op == LoadOps.EMPTY:
     op = LazyOp(LoadOps.EMPTY)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape).unbind()
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify()))
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()))
 
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + \
-    [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
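This is the change the commit title refers to: the old code built each ScheduleItem's variable dict with {k:var_vals[k] for k in op.vars()}, and op.vars() has to walk the entire LazyOp tree for every schedule item; the new code hands var_vals through unchanged, which is the "free speedup". A toy sketch of the cost being removed (ToyOp and the start_pos variable are made up, not tinygrad's classes):

# Filtering through op.vars() means a full recursive walk per schedule item;
# passing var_vals straight through is what the new code does instead.
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(frozen=True)
class ToyOp:
  srcs: Tuple["ToyOp", ...] = ()
  local_vars: Tuple[str, ...] = ()
  def vars(self) -> Tuple[str, ...]:
    seen = dict.fromkeys(self.local_vars)        # ordered, deduplicated
    for s in self.srcs: seen.update(dict.fromkeys(s.vars()))
    return tuple(seen)

load = ToyOp(local_vars=("start_pos",))
op = ToyOp(srcs=(load, ToyOp(srcs=(load,))))
var_vals: Dict[str, int] = {"start_pos": 17}

old_style = {k: var_vals[k] for k in op.vars()}  # tree walk for every ScheduleItem
new_style = var_vals                             # pass-through, no walk
print(old_style, new_style)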