From 453e437598efdb55840502c439cf4603c1ea1073 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:42:09 -0700 Subject: [PATCH] move stuff in the linearizer (#1726) * move stuff in linearizer * move stuff in linearizer * minor * fix opts import --- docs/abstractions.py | 3 +- tinygrad/codegen/linearizer.py | 56 +++++++++++++++------------------- tinygrad/runtime/ops_clang.py | 2 +- tinygrad/runtime/ops_cuda.py | 2 +- tinygrad/runtime/ops_gpu.py | 2 +- tinygrad/runtime/ops_hip.py | 2 +- tinygrad/runtime/ops_llvm.py | 2 +- tinygrad/runtime/ops_metal.py | 2 +- tinygrad/runtime/ops_webgpu.py | 2 +- 9 files changed, 34 insertions(+), 39 deletions(-) diff --git a/docs/abstractions.py b/docs/abstractions.py index 8c1e37b55c..add58ccac4 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -270,7 +270,8 @@ result = Tensor(2).realize() + Tensor(3).realize() result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype) # use the real Linearizer to linearize 2+3 -from tinygrad.codegen.linearizer import Linearizer, LinearizerOptions +from tinygrad.codegen.linearizer import Linearizer +from tinygrad.codegen.kernel import LinearizerOptions linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions()) linearizer.linearize() diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 7c22680f86..3162e1eefb 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -11,7 +11,7 @@ from tinygrad.runtime.lib import RawConst from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename from tinygrad.codegen.optimizer import OptimizedKernel -from tinygrad.codegen.kernel import LocalBuffer, LinearizerOptions # noqa: F401 # pylint:disable=unused-import +from tinygrad.codegen.kernel import LocalBuffer VariableOrNum = Union[Variable, NumNode, Node] # bottom ones are asm only @@ -205,8 +205,6 @@ class Linearizer(OptimizedKernel): # uops self.uops: List[UOp] = [] - self.load_cache: Dict[str, Token] = {} - self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict() # add global buffers for buf,name in self.arg_bufs.items(): @@ -214,18 +212,16 @@ class Linearizer(OptimizedKernel): # add variables from symbolic shapes for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key): self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32)) - - # add a local buffer for multistage reduce + # define local buffers + for lb in self.local_alias.values(): + self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size())) + # add a local buffer for multistage reduce. 
# TODO: use local alias if self.group_for_reduce: # TODO: the strides of this can be controlled self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) self.bufs.append(LocalBuffer("temp", self.sts[-1].size())) self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size())) - # define local buffers - for lb in self.local_alias.values(): - self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size())) - # print if DEBUG >= 3: self.printbufs() @@ -233,9 +229,26 @@ class Linearizer(OptimizedKernel): self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape]) self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())]) + # name the function something unique + Linearizer.kernel_cnt[self.function_name] += 1 + suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else "" + self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK') + + # define indexes + global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)] + local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))] + full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]] + upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]] + + # global and local loops + self.uop(UOps.LOOP, None, [], (global_idxs, "global")) + self.uop(UOps.LOOP, None, [], (local_idxs, "local")) + # parse AST loaded_buffers = {} acc = [] + self.load_cache: Dict[str, Token] = {} + self.saved_exprs: Dict[Tuple[Op, Tuple[Token, ...]], Token] = dict() # ssa _ssa:DefaultDict[str,int] = defaultdict(int) @@ -243,18 +256,6 @@ class Linearizer(OptimizedKernel): _ssa[name] += 1 return Token(f"{name}{_ssa[name]-1}", ltype) - # global loop - global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1) for i in range(0, self.first_reduce-self.local_dims)] - self.uop(UOps.LOOP, None, [], (global_idxs, "global")) - - # local loop - local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce-self.local_dims, self.first_reduce+len(self.group_for_reduce))] - self.uop(UOps.LOOP, None, [], (local_idxs, "local")) - - # upcast indexes - full_upcast_idxs = [Variable(None, 0, s-1) for s in self.full_shape[self.shape_len-self.upcasted:]] - upcast_idxs = [Variable(None, 0, s-1) for s in self.output_shape[self.shape_len-self.upcasted:]] - # reduce op fake_reduce_idxs = [] if self.reduceop is not None: @@ -272,6 +273,7 @@ class Linearizer(OptimizedKernel): if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ()) # compute local aliases + # TODO: this is garbage code and should be at least moved elsewhere locals_to_store = [] for i in self.local_alias: strides = self.sts[i].real_strides() @@ -381,17 +383,9 @@ class Linearizer(OptimizedKernel): # store self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val, ssa) - if not self.group_for_reduce: - # end the global+local loop - self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, 
"global+local")) - else: - # end the global loop - self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global")) + # end the global (and maybe local) loop + self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local") if not self.group_for_reduce else (global_idxs, "global")) - # name the function something unique - Linearizer.kernel_cnt[self.function_name] += 1 - suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else "" - self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK') return self _OT = TypeVar("_OT") diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 016a890fda..363f4965e8 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -3,7 +3,7 @@ from functools import partial, reduce from tinygrad.ops import Compiled from tinygrad.helpers import fromimport, getenv, DEBUG, CI from tinygrad.runtime.lib import RawMallocBuffer -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage import struct import numpy as np diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index cfb4fe8add..88eac03ccf 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -6,7 +6,7 @@ from pycuda.compiler import compile as cuda_compile # type: ignore from tinygrad.helpers import DEBUG, getenv, colored, fromimport from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage def pretty_ptx(s): diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 75f658cb29..656defc468 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -6,7 +6,7 @@ from typing import Optional, List from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. 
it's in like GPU clocks or something diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 5452d0cbbe..a663928d44 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -4,7 +4,7 @@ import extra.hip_wrapper as hip from tinygrad.helpers import DEBUG from tinygrad.ops import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage # TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait() diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 7a2b140829..22eb8c7d1a 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -3,7 +3,7 @@ from typing import ClassVar from tinygrad.ops import Compiled from tinygrad.helpers import getenv, DEBUG from ctypes import CFUNCTYPE -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.llvmir import uops_to_llvm_ir from tinygrad.runtime.lib import RawMallocBuffer diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index f093322734..fffae4ba34 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -2,7 +2,7 @@ import os, subprocess, pathlib, functools, ctypes import Metal, Cocoa, libdispatch # type: ignore from typing import List, Any -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes from tinygrad.ops import Compiled diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 2124f83621..4d58b65204 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -4,7 +4,7 @@ from wgpu.utils._device import get_default_device # type: ignore from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator from tinygrad.helpers import dtypes, DType from tinygrad.ops import Compiled -from tinygrad.codegen.linearizer import LinearizerOptions +from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle from tinygrad.renderer.wgsl import WGSLLanguage import wgpu # type: ignore
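
Notes on the import move: every file above now imports LinearizerOptions from tinygrad.codegen.kernel rather than tinygrad.codegen.linearizer, and the "# noqa: F401" re-export in linearizer.py is dropped, so the old combined import path stops working (that is the "fix opts import" commit in the message). A minimal sketch of the updated pattern, assuming a tinygrad checkout that includes this patch:

    # LinearizerOptions now lives in codegen.kernel; Linearizer itself is unchanged.
    from tinygrad.codegen.kernel import LinearizerOptions
    from tinygrad.codegen.linearizer import Linearizer

    opts = LinearizerOptions()  # default options, as constructed in docs/abstractions.py above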
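Inside linearizer.py, the kernel-naming block moves from the tail of linearize() to right after the base name is computed, before any uops are emitted. Its suffix logic is plain Python; below is a standalone sketch of what it does, with hypothetical kernel names and a module-level counter standing in for the Linearizer.kernel_cnt class attribute:

    from collections import defaultdict

    kernel_cnt: defaultdict = defaultdict(int)

    def unique_name(function_name: str) -> str:
        # first occurrence keeps the bare name; repeats get an n1, n2, ... suffix
        kernel_cnt[function_name] += 1
        suffix = f"n{kernel_cnt[function_name]-1}" if kernel_cnt[function_name] > 1 else ""
        return function_name + suffix

    print(unique_name("r_256_16"))  # r_256_16
    print(unique_name("r_256_16"))  # r_256_16n1
    print(unique_name("r_256_16"))  # r_256_16n2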
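The gidx/lidx index definitions and the global/local LOOP uops are likewise hoisted so they sit together ahead of the AST walk, and load_cache/saved_exprs now get initialized right next to the parse instead of in setup; the DEFINE_LOCAL uops for local aliases are also emitted before the multistage-reduce temp buffer rather than after it. A sketch of the index construction with a hypothetical shape; the Variable bounds come straight from full_shape, exactly as in the hunk above:

    from tinygrad.shape.symbolic import Variable

    full_shape = (4, 8, 16)          # hypothetical kernel shape
    first_reduce, local_dims = 2, 1  # hypothetical split between global and local dims
    group_for_reduce: list = []

    global_idxs = [Variable(f"gidx{i}", 0, full_shape[i]-1)
                   for i in range(0, first_reduce-local_dims)]
    local_idxs = [Variable(f"lidx{i}", 0, full_shape[i]-1)
                  for i in range(first_reduce-local_dims, first_reduce+len(group_for_reduce))]
    print(global_idxs, local_idxs)   # gidx0 ranges over [0,3], lidx0 over [0,7]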
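At the other end of linearize(), the two ENDLOOP branches collapse into a single self.uop(...) call whose argument tuple is chosen by a conditional expression. A behavior-preserving sketch of that refactor in isolation (toy index lists, no tinygrad imports):

    def endloop_args(global_idxs, local_idxs, group_for_reduce):
        # one emission site; the tuple picks global+local vs. global only
        return ((global_idxs + local_idxs, "global+local")
                if not group_for_reduce else (global_idxs, "global"))

    assert endloop_args(["gidx0"], ["lidx0"], []) == (["gidx0", "lidx0"], "global+local")
    assert endloop_args(["gidx0"], ["lidx0"], [16]) == (["gidx0"], "global")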