From 50e780a588ab598f2a56d23aa0ebab0c760f2140 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 13 Apr 2024 00:03:48 -0700 Subject: [PATCH] multitensor shouldn't recompile (#4164) * multitensor shouldn't recompile * type annotations * fix tests * outcount in reduce --- test/external/external_test_speed_llama.py | 4 ++-- test/test_multitensor.py | 14 +++++++++++++- tinygrad/device.py | 21 ++++++++++++++++----- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 768d9683d3..3939bbb650 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -4,7 +4,7 @@ from examples.llama import Transformer, MODEL_PARAMS from tinygrad.tensor import Tensor from tinygrad import Device from tinygrad.nn.state import get_state_dict -from tinygrad.device import Allocator +from tinygrad.device import Allocator, method_cache from tinygrad.helpers import Profiling class FakeProgram: @@ -31,7 +31,7 @@ class TestLLaMASpeed(unittest.TestCase): print("assigned empty tensors, doing warmup") def run_llama(st, empty_method_cache=True): - if empty_method_cache: Device[Device.DEFAULT].get_runner.cache_clear() + if empty_method_cache: method_cache.clear() tms = [time.perf_counter()] for i in range(5): model(Tensor([[1,2,3,4]]), i).realize() diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 974ec92bda..486da76fa6 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,11 +1,12 @@ import unittest, functools, random from typing import List from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit -from tinygrad.device import BufferCopy +from tinygrad.device import BufferCopy, CompiledRunner from tinygrad.ops import LoadOps, ReduceOps from tinygrad.helpers import CI, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.realize import lower_schedule from tinygrad.features.multi import all_reduce, MultiLazyBuffer from random import randint import numpy as np @@ -40,6 +41,17 @@ class TestMultiTensor(unittest.TestCase): assert lb.shape == (128,) (X + X).realize() + def test_shard_no_recompile(self): + X = Tensor.ones(256).contiguous().realize() + X.shard_((d0, d1), 0) + out = (X + X) + sched = create_schedule(out.lazydata.lbs) + names = [] + for si, ei in zip(sched[:], lower_schedule(sched)): + if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.name) + ei.run() + assert names[-2] == names[-1], "function was relinearized" + def test_sharded_memory(self): # Buffer may be stuck in track_cross_buffer for x in (d_zero, d0, d1, d2, d3): Device[x].synchronize() diff --git a/tinygrad/device.py b/tinygrad/device.py index 909c9b9c57..0a10ec319a 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -169,12 +169,16 @@ class CompiledRunner(Runner): self.vars: List[Variable] = [] if variables is None else variables self.op_estimate, self.mem_estimate = op_estimate, mem_estimate + def to_other_device(self, dname:str): + return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size, + self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount) + @property def device(self): return Device[self.dname] def __reduce__(self): - return self.__class__, (self.name, self.prg, self.dname, self.global_size, self.local_size, - self.vars, self.op_estimate, self.mem_estimate, self.lib) + return self.__class__, (self.display_name, self.prg, self.dname, self.global_size, self.local_size, + self.vars, self.op_estimate, self.mem_estimate, self.lib, self.outcount) def launch_dims(self, var_vals): global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size @@ -201,6 +205,7 @@ class MultiDeviceJITGraph(Runner): def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: raise NotImplementedError("override this") +method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], bool], CompiledRunner] = {} logkern, logkern_level = open(getenv("LOGKERN", ""), "a") if getenv("LOGKERN", "") else None, getenv("LOGKERN_LEVEL", 1) class Compiled: def __init__(self, device:str, allocator:Allocator, compiler:Optional[Compiler], runtime, graph=None): @@ -215,7 +220,7 @@ class Compiled: run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else [])) # NOTE: we use min here to ignore the indexing FLOPS ret = CompiledRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size, - k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs)) + k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs)) return ret def get_linearizer(self, *ast:LazyOp) -> Linearizer: @@ -249,5 +254,11 @@ class Compiled: if DEBUG >= 4: print((k.ast, k.applied_opts)) # print here to show final applied_opts for all kernels instead of just in beam_search return k - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, *ast:LazyOp) -> CompiledRunner: return self.to_program(self.get_linearizer(*ast)) + def get_runner(self, *ast:LazyOp) -> CompiledRunner: + if cret:=method_cache.get((self.dname, ast, False)): return cret + if bret:=method_cache.get((self.dname.split(":")[0], ast, True)): + method_cache[(self.dname, ast, False)] = ret = bret.to_other_device(self.dname) + else: + method_cache[(self.dname.split(":")[0], ast, True)] = method_cache[(self.dname, ast, False)] = ret = self.to_program(self.get_linearizer(*ast)) + return ret +