mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 06:58:11 -05:00)
multitensor shouldn't recompile (#4164)
* multitensor shouldn't recompile
* type annotations
* fix tests
* outcount in reduce
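Context for the change: get_runner used to be memoized per Compiled instance with functools.lru_cache, so realizing a tensor sharded across several instances of the same backend (e.g. GPU:0 and GPU:1) linearized and compiled the same kernel once per device. This commit replaces that with a module-level method_cache keyed by (device name, ast, is-base-device), so a program compiled once for the base device can be re-pointed at sibling devices via to_other_device without recompiling. A minimal standalone sketch of that lookup pattern (FakeRunner, expensive_compile and cache below are illustrative names, not tinygrad API):

from typing import Dict, Tuple

# illustrative stand-in: a "compiled program" that is cheap to re-point at a sibling device
class FakeRunner:
  def __init__(self, ast: str, dname: str, lib: bytes):
    self.ast, self.dname, self.lib = ast, dname, lib
  def to_other_device(self, dname: str) -> "FakeRunner":
    return FakeRunner(self.ast, dname, self.lib)  # reuse the compiled lib, no recompile

compiles = 0
def expensive_compile(ast: str, dname: str) -> FakeRunner:
  global compiles; compiles += 1
  return FakeRunner(ast, dname, b"<binary>")

# key: (device name, ast, is_base_device) -- mirrors the method_cache in the diff below
cache: Dict[Tuple[str, str, bool], FakeRunner] = {}

def get_runner(dname: str, ast: str) -> FakeRunner:
  if (cret := cache.get((dname, ast, False))): return cret
  base = dname.split(":")[0]
  if (bret := cache.get((base, ast, True))):
    cache[(dname, ast, False)] = ret = bret.to_other_device(dname)  # base-device hit: cheap clone
  else:
    cache[(base, ast, True)] = cache[(dname, ast, False)] = ret = expensive_compile(ast, dname)
  return ret

# the same elementwise kernel requested for two shards compiles only once
get_runner("GPU:0", "ADD(x,x)")
get_runner("GPU:1", "ADD(x,x)")
assert compiles == 1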
test/external/external_test_speed_llama.py
@@ -4,7 +4,7 @@ from examples.llama import Transformer, MODEL_PARAMS
 from tinygrad.tensor import Tensor
 from tinygrad import Device
 from tinygrad.nn.state import get_state_dict
-from tinygrad.device import Allocator
+from tinygrad.device import Allocator, method_cache
 from tinygrad.helpers import Profiling

 class FakeProgram:
@@ -31,7 +31,7 @@ class TestLLaMASpeed(unittest.TestCase):
     print("assigned empty tensors, doing warmup")

     def run_llama(st, empty_method_cache=True):
-      if empty_method_cache: Device[Device.DEFAULT].get_runner.cache_clear()
+      if empty_method_cache: method_cache.clear()
       tms = [time.perf_counter()]
       for i in range(5):
         model(Tensor([[1,2,3,4]]), i).realize()
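Because get_runner no longer carries its own functools.lru_cache, the speed test above clears the shared method_cache dict to get a cold-compile measurement. A rough usage sketch of the same idea, pinned to the imports this commit provides (timings and kernel counts depend on the backend):

import time
from tinygrad import Tensor
from tinygrad.device import method_cache  # module-level cache added in this commit

def timed_realize() -> float:
  t0 = time.perf_counter()
  (Tensor.ones(64).contiguous() + 1).realize()
  return time.perf_counter() - t0

method_cache.clear()    # cold: the next realize has to re-linearize (and possibly recompile)
cold = timed_realize()
warm = timed_realize()  # warm: the same ASTs hit method_cache
print(f"cold {cold*1e3:.1f} ms, warm {warm*1e3:.1f} ms")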
test/test_multitensor.py

@@ -1,11 +1,12 @@
 import unittest, functools, random
 from typing import List
 from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
-from tinygrad.device import BufferCopy
+from tinygrad.device import BufferCopy, CompiledRunner
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import lower_schedule
 from tinygrad.features.multi import all_reduce, MultiLazyBuffer
 from random import randint
 import numpy as np
@@ -40,6 +41,17 @@ class TestMultiTensor(unittest.TestCase):
       assert lb.shape == (128,)
     (X + X).realize()

+  def test_shard_no_recompile(self):
+    X = Tensor.ones(256).contiguous().realize()
+    X.shard_((d0, d1), 0)
+    out = (X + X)
+    sched = create_schedule(out.lazydata.lbs)
+    names = []
+    for si, ei in zip(sched[:], lower_schedule(sched)):
+      if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.name)
+      ei.run()
+    assert names[-2] == names[-1], "function was relinearized"
+
   def test_sharded_memory(self):
     # Buffer may be stuck in track_cross_buffer
     for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
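The new test_shard_no_recompile walks the lowered schedule and asserts that the last two CompiledRunner names are equal; a kernel that gets linearized a second time picks up a numeric suffix in its name (the linearizer counts repeated kernel names), so equal names mean the second shard reused the first shard's program instead of being relinearized. At the user level the same scenario looks roughly like this sketch, where d0 and d1 are stand-ins for two instances of the default backend (the test file uses similar fixtures):

from tinygrad import Tensor, Device

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"  # assumed: backend supports multiple instances

X = Tensor.ones(256).contiguous().realize()
X.shard_((d0, d1), 0)      # split on axis 0: each device holds a (128,) shard
out = (X + X).realize()    # the same elementwise ADD kernel runs once per shard;
                           # with method_cache it should be compiled only once
print(out.numpy().sum())   # expect 512.0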
tinygrad/device.py

@@ -169,12 +169,16 @@ class CompiledRunner(Runner):
     self.vars: List[Variable] = [] if variables is None else variables
     self.op_estimate, self.mem_estimate = op_estimate, mem_estimate

+  def to_other_device(self, dname:str):
+    return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size,
+                          self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
+
   @property
   def device(self): return Device[self.dname]

   def __reduce__(self):
-    return self.__class__, (self.name, self.prg, self.dname, self.global_size, self.local_size,
-                            self.vars, self.op_estimate, self.mem_estimate, self.lib)
+    return self.__class__, (self.display_name, self.prg, self.dname, self.global_size, self.local_size,
+                            self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)

   def launch_dims(self, var_vals):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
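Two details in this hunk: to_other_device re-points an already-compiled program at a sibling device by reusing self.lib, and __reduce__ now round-trips display_name and outcount so a pickled runner does not silently lose them (the "outcount in reduce" bullet). A standalone sketch of the pattern; MiniRunner is illustrative, not tinygrad's class:

import pickle

class MiniRunner:
  def __init__(self, name: str, lib: bytes, dname: str, outcount: int = 1):
    self.name, self.lib, self.dname, self.outcount = name, lib, dname, outcount
  def __reduce__(self):
    # only what is returned here survives pickling; forgetting outcount would
    # reconstruct the runner with the default outcount=1
    return self.__class__, (self.name, self.lib, self.dname, self.outcount)
  def to_other_device(self, dname: str) -> "MiniRunner":
    # reuse the compiled lib as-is; only the target device name changes
    return self.__class__(self.name, self.lib, dname, self.outcount)

r = MiniRunner("E_256", b"<compiled binary>", "GPU:1", outcount=2)
r2 = pickle.loads(pickle.dumps(r))
assert (r2.dname, r2.outcount) == ("GPU:1", 2)
r3 = r.to_other_device("GPU:2")
assert r3.lib is r.lib and r3.dname == "GPU:2"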
@@ -201,6 +205,7 @@ class MultiDeviceJITGraph(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")

+method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], bool], CompiledRunner] = {}
 logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1)
 class Compiled:
   def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
@@ -215,7 +220,7 @@ class Compiled:
     run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
     ret = CompiledRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
                          k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs))
     return ret

   def get_linearizer(self, *ast:LazyOp) -> Linearizer:
@@ -249,5 +254,11 @@ class Compiled:
     if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
     return k

-  @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
-  def get_runner(self, *ast:LazyOp) -> CompiledRunner: return self.to_program(self.get_linearizer(*ast))
+  def get_runner(self, *ast:LazyOp) -> CompiledRunner:
+    if cret:=method_cache.get((self.dname, ast, False)): return cret
+    if bret:=method_cache.get((self.dname.split(":")[0], ast, True)):
+      method_cache[(self.dname, ast, False)] = ret = bret.to_other_device(self.dname)
+    else:
+      method_cache[(self.dname.split(":")[0], ast, True)] = method_cache[(self.dname, ast, False)] = ret = self.to_program(self.get_linearizer(*ast))
+    return ret
+
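Reading the new get_runner: (dname, ast, False) is the exact-device entry, and (dname.split(":")[0], ast, True) is the base-device entry shared by every ":N" instance of that backend. A miss on both compiles once and populates both entries; a hit on the base entry only pays for a cheap to_other_device clone. One rough way to watch it happen at this commit, assuming the default backend can host multiple instances (as the multitensor tests do); note the cache will also contain entries for the fill/contiguous kernels that ran on the default device:

from tinygrad import Tensor, Device
from tinygrad.device import method_cache

d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"  # two instances of the same backend

method_cache.clear()
X = Tensor.ones(256).contiguous().realize()
X.shard_((d0, d1), 0)
(X + X).realize()

# expect one base-device entry (is_base=True) for the shared ADD kernel plus one
# concrete entry per device (is_base=False); the compiler should have run once for it
for (dname, _ast, is_base) in method_cache: print(dname, is_base)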