multitensor shouldn't recompile (#4164)

* multitensor shouldn't recompile

* type annotations

* fix tests

* outcount in reduce
Author: George Hotz
Date: 2024-04-13 00:03:48 -07:00
Committed by: GitHub
Parent: 599eb266b1
Commit: 50e780a588
3 changed files with 31 additions and 8 deletions


@@ -4,7 +4,7 @@ from examples.llama import Transformer, MODEL_PARAMS
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad.nn.state import get_state_dict
-from tinygrad.device import Allocator
+from tinygrad.device import Allocator, method_cache
from tinygrad.helpers import Profiling
class FakeProgram:
@@ -31,7 +31,7 @@ class TestLLaMASpeed(unittest.TestCase):
print("assigned empty tensors, doing warmup")
def run_llama(st, empty_method_cache=True):
-if empty_method_cache: Device[Device.DEFAULT].get_runner.cache_clear()
+if empty_method_cache: method_cache.clear()
tms = [time.perf_counter()]
for i in range(5):
model(Tensor([[1,2,3,4]]), i).realize()
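run_llama times five iterations and can empty the new module-level method_cache first, so the run includes the lowering and compile path rather than only cached kernel launches. A toy version of that timing loop against a stand-in cache (names here are illustrative, not tinygrad's API):

import time
from typing import Callable, Dict, List

compile_cache: Dict[str, Callable[[], None]] = {}  # stand-in for the module-level method_cache

def run_kernel(name: str) -> None:
  if name not in compile_cache:
    time.sleep(0.01)                    # pretend this is the expensive linearize + compile step
    compile_cache[name] = lambda: None
  compile_cache[name]()                 # cached launches are cheap

def run_model(empty_method_cache: bool = True) -> List[float]:
  if empty_method_cache: compile_cache.clear()   # same role as method_cache.clear() above
  tms = [time.perf_counter()]
  for _ in range(5):
    run_kernel("E_4")
    tms.append(time.perf_counter())
  return [b - a for a, b in zip(tms, tms[1:])]

print(run_model(True))    # first iteration pays the fake compile cost
print(run_model(False))   # warm cache: every iteration is fast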


@@ -1,11 +1,12 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit
-from tinygrad.device import BufferCopy
+from tinygrad.device import BufferCopy, CompiledRunner
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule
from tinygrad.features.multi import all_reduce, MultiLazyBuffer
from random import randint
import numpy as np
@@ -40,6 +41,17 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
+def test_shard_no_recompile(self):
+X = Tensor.ones(256).contiguous().realize()
+X.shard_((d0, d1), 0)
+out = (X + X)
+sched = create_schedule(out.lazydata.lbs)
+names = []
+for si, ei in zip(sched[:], lower_schedule(sched)):
+if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.name)
+ei.run()
+assert names[-2] == names[-1], "function was relinearized"
def test_sharded_memory(self):
# Buffer may be stuck in track_cross_buffer
for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize()
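test_shard_no_recompile shards a 256-element tensor across d0 and d1 on axis 0, so X + X runs the same elementwise AST once per device on a 128-element piece; the assert checks that the second device's kernel kept the first one's name instead of being lowered again. The shard arithmetic itself, as a toy numpy illustration (not tinygrad code):

import numpy as np

X = np.ones(256, dtype=np.float32)
shards = np.split(X, 2, axis=0)        # roughly what shard_((d0, d1), 0) does: two 128-element pieces
outs = [s + s for s in shards]         # the same "X + X" computation, once per shard/device
assert all(o.shape == (128,) for o in outs)
assert np.allclose(np.concatenate(outs), X + X)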


@@ -169,12 +169,16 @@ class CompiledRunner(Runner):
self.vars: List[Variable] = [] if variables is None else variables
self.op_estimate, self.mem_estimate = op_estimate, mem_estimate
+def to_other_device(self, dname:str):
+return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size,
+self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
@property
def device(self): return Device[self.dname]
def __reduce__(self):
-return self.__class__, (self.name, self.prg, self.dname, self.global_size, self.local_size,
-self.vars, self.op_estimate, self.mem_estimate, self.lib)
+return self.__class__, (self.display_name, self.prg, self.dname, self.global_size, self.local_size,
+self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount)
def launch_dims(self, var_vals):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
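__reduce__ has to mirror the constructor's argument order; it now passes display_name and carries lib and outcount so a pickled runner round-trips with its precompiled binary and output count intact. A minimal sketch of that pickling contract (Prog here is a made-up class, not tinygrad's):

import pickle

class Prog:
  def __init__(self, display_name: str, lib: bytes, outcount: int = 1):
    self.display_name, self.lib, self.outcount = display_name, lib, outcount
  def __reduce__(self):
    # rebuilt from constructor args on unpickle; keep this tuple in sync with __init__'s signature
    return self.__class__, (self.display_name, self.lib, self.outcount)

p = pickle.loads(pickle.dumps(Prog("r_256_16", b"\x00\x01", outcount=2)))
assert (p.display_name, p.outcount) == ("r_256_16", 2)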
@@ -201,6 +205,7 @@ class MultiDeviceJITGraph(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")
+method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], bool], CompiledRunner] = {}
logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1)
class Compiled:
def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None):
@@ -215,7 +220,7 @@ class Compiled:
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
ret = CompiledRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
-k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
+k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs))
return ret
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
@@ -249,5 +254,11 @@ class Compiled:
if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search
return k
-@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-def get_runner(self, *ast:LazyOp) -> CompiledRunner: return self.to_program(self.get_linearizer(*ast))
+def get_runner(self, *ast:LazyOp) -> CompiledRunner:
+if cret:=method_cache.get((self.dname, ast, False)): return cret
+if bret:=method_cache.get((self.dname.split(":")[0], ast, True)):
+method_cache[(self.dname, ast, False)] = ret = bret.to_other_device(self.dname)
+else:
+method_cache[(self.dname.split(":")[0], ast, True)] = method_cache[(self.dname, ast, False)] = ret = self.to_program(self.get_linearizer(*ast))
+return ret
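
Taken together, get_runner now does a two-level lookup: an exact (device, ast) entry first, then a (base-device, ast) entry whose runner is cloned onto the requested device via to_other_device, so shard devices like NV:1 and NV:2 reuse the program compiled once for NV. A self-contained sketch of that pattern with stand-in names (Runner and compile_for are illustrative, not tinygrad's API):

from typing import Dict, Tuple

class Runner:
  def __init__(self, name: str, lib: bytes, dname: str):
    self.name, self.lib, self.dname = name, lib, dname
  def to_other_device(self, dname: str) -> "Runner":
    # reuse the already-compiled binary, only rebind the target device
    return Runner(self.name, self.lib, dname)

# key: (device name, ast, is_base_device_entry)
cache: Dict[Tuple[str, str, bool], Runner] = {}

def compile_for(dname: str, ast: str) -> Runner:
  return Runner(f"k_{ast}", f"<binary of {ast}>".encode(), dname)  # stand-in for linearize + render + compile

def get_runner(dname: str, ast: str) -> Runner:
  if cret := cache.get((dname, ast, False)): return cret
  if bret := cache.get((dname.split(":")[0], ast, True)):
    cache[(dname, ast, False)] = ret = bret.to_other_device(dname)
  else:
    cache[(dname.split(":")[0], ast, True)] = cache[(dname, ast, False)] = ret = compile_for(dname, ast)
  return ret

r0 = get_runner("NV:0", "E_256")
r1 = get_runner("NV:1", "E_256")   # no second compile: just a clone of the base-device entry
assert r0.lib is r1.lib and r1.dname == "NV:1"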