From 7151382364c91553ba8db9b0d1772beaed281724 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 8 Jul 2023 15:54:58 -0700 Subject: [PATCH] Refactor load/store before tensor cores (#1193) * minor cleanups * render_const * now that's a nice refactor * clean up vload/vstore * clean up render_load * debugs there * dumb * err, this? * const float4 * what's failing * bugfix * statement includes semicolon * bugfix --- .github/workflows/test.yml | 8 +-- .pre-commit-config.yaml | 2 +- test/test_ops.py | 2 +- test/unit/test_shapetracker.py | 2 +- tinygrad/codegen/cstyle.py | 123 +++++++++++++-------------------- tinygrad/codegen/linearizer.py | 29 +++++++- tinygrad/helpers.py | 6 +- tinygrad/runtime/ops_cuda.py | 2 +- tinygrad/runtime/ops_gpu.py | 1 - 9 files changed, 85 insertions(+), 90 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6b270baa0d..5239df559f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -185,14 +185,14 @@ jobs: python-version: 3.8 - name: Install Dependencies run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu + - name: Test openpilot model compile and size + run: | + DEBUG=2 ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py + python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - name: Test GPU IMAGE ops run: | GPU=1 IMAGE=1 python3 test/test_ops.py FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py - - name: Test openpilot model compile and size - run: | - ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py - python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - name: Test openpilot model correctness (float32) run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 198054e917..0bf1290f20 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: pass_filenames: false - id: tests name: subset of (CPU) tests - entry: env CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py + entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py language: system always_run: true pass_filenames: false diff --git a/test/test_ops.py b/test/test_ops.py index a0e447adc3..4986b660d5 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -36,7 +36,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra except Exception: raise Exception(f"{s} failed shape {x.shape}") - if DEBUG >= 4: + if DEBUG >= 6: np.set_printoptions(linewidth=200, suppress=True) print(ret.numpy()) print(out.detach().numpy()) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 1ada73447f..7e62227f7f 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -3,7 +3,7 @@ import unittest import numpy as np from tinygrad.helpers import prod, all_same from tinygrad.shape.shapetracker import ShapeTracker, View, merge_views, get_contraction -from tinygrad.codegen.cstyle import to_image_idx +from tinygrad.codegen.linearizer import to_image_idx def shapetracker_getitem(st, val): locals = {"idx": val, "valid": 1} diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 04834ce432..9ab77d788a 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -2,9 +2,9 @@ from typing import Final, Dict, Callable, ClassVar, List, 
Optional, NamedTuple, import math, collections from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps -from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored +from tinygrad.helpers import ImageDType, dtypes, colored from tinygrad.runtime.lib import RawConst -from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode +from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable from tinygrad.lazy import LazyBuffer # div is different in cl than python @@ -23,28 +23,38 @@ class CStyleLanguage(NamedTuple): extra_args: List[str] = [] float4: Optional[str] = None half_prekernel: Optional[str] = None - double_prekernel: Optional[str] = None uses_vload: bool = False -def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]: - idy = (idxy//(4*base_shape[1])) - if validhacks and valid.min == 0: - idx = (idxy//4) + (idy*-base_shape[1]) - # find the ones in idx that didn't factorize and remove them (TODO: this is not universal) - if isinstance(idx, SumNode): - unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1]) - assert len(unfactored) <= 1 - idx = Variable.sum(idx_nodes) - unfactored = (Variable.sum(unfactored) // base_shape[1]) - idy += unfactored - # ugh really...handtuned garbage - if idx.min >= (base_shape[1]*3)//4: - idx -= base_shape[1] - idy += 1 - else: - idx = (idxy//4)%base_shape[1] - if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) - return idx, idy + # returns a str expression of the const with the given type + def render_const(self, x:Union[float,int], var_dtype) -> str: + if math.isnan(x): val = "NAN" + elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" + else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f") + return f"{self.float4}({val}, {val}, {val}, {val})" if var_dtype == dtypes._float4 else val + + # returns a str expression of the loaded value with the output type + def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: + if isinstance(buf_dtype, ImageDType): + assert output_dtype == dtypes._float4, "images must be float4" + return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))" + elif self.uses_vload and buf_dtype == dtypes.float16: + return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl)})" + elif output_dtype == dtypes._float4: + return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl)})))" + else: + return f"{buf_name}[{idx.render(render_cl)}]" + + # returns a str statement that does the store + def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str: + if isinstance(buf_dtype, ImageDType): + assert var_dtype == dtypes._float4, "images must be float4" + return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});" + elif self.uses_vload and buf_dtype == dtypes.float16: + return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl)});" + elif var_dtype.sz > 1: + return f"*(({self.smem_prefix if local else 
self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" + else: + return f"{buf_name}[{idx.render(render_cl)}] = {var_name};" code_for_op: Final[Dict[Op, Callable]] = { UnaryOps.EXP2: lambda x: f"exp2({x})", @@ -105,70 +115,33 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan kk("}"*len(args[0]) + f" /* {args[1]} */") elif uop == UOps.CONST: assert newvar is not None - if args == -math.inf: - kk(f"{newvar.render(True)} = -INFINITY;") - elif newvar.dtype == dtypes._float4: - kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};") - else: - kk(f"{newvar.render(True)} = {args}f;") + kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};") elif uop == UOps.ALU: assert newvar is not None - if newvar in vin: - kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};") + kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};") + elif uop == UOps.LOAD: + assert newvar is not None + # valids are handled here + if args.valid.max == 0: + val = lang.render_const(0.0, newvar.dtype) + elif isinstance(bufs[args.i].realized, RawConst): + val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype) else: - kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};") - elif uop == UOps.LOAD and newvar is not None: - # TODO: merge with CONST? - if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst): - assert newvar.dtype == dtypes.float, "const can't be float4" - x = bufs[args.i].realized._buf - if math.isnan(x): val = "NAN" - elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" - else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "") - elif isinstance(bufs[args.i].dtype, ImageDType): - assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}" - prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") - idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid) - val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))" - else: - if lang.uses_vload and bufs[args.i].dtype == dtypes.float16: - if newvar.dtype == dtypes._float4: - val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})" - else: - val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})" - else: - if newvar.dtype == dtypes._float4: - val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))" - else: - val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]" - # NOTE: if min and max are both 0, it should be a CONST in the Linearizer - if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};") - else: - casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype] - kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? 
{casts[0]}({val}) : {casts[1]};") - elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)): - assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4" + val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer)) + if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}" + kk(f"{newvar.render(True)} = {val};") + elif uop == UOps.STORE: assert args.valid.min == 1, "store must be valid" - if lang.uses_vload and bufs[args.i].dtype == dtypes.float16: - kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});") - else: - kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};") + # TODO: instead of dtypes.float, a base type + kk(lang.render_store(bufnames[args.i], bufs[args.i].dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, isinstance(bufs[args.i], LocalBuffer))) elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4: kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});") - elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None: - assert args.valid.min == 1, "store must be valid" - if isinstance(bufs[args[0]].dtype, ImageDType): - idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2]) - kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});") - elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16: - kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});") - else: - kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};") elif uop == UOps.DEFINE_LOCAL: kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];") else: raise RuntimeError(f"failed to render {uop}") + if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs) if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)] diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index d41c88b41a..44a6c070d9 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -3,18 +3,38 @@ import itertools, math from collections import defaultdict from enum import Enum, auto -from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same +from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps from tinygrad.lazy import LazyBuffer from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps from tinygrad.runtime.lib import RawConst from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape -from tinygrad.shape.symbolic import Variable, NumNode +from tinygrad.shape.symbolic import 
Variable, NumNode, Node, SumNode, MulNode VariableOrNum = Union[Variable, NumNode] class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \ SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702 +def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]: + idy = (idxy//(4*base_shape[1])) + if validhacks and valid.min == 0: + idx = (idxy//4) + (idy*-base_shape[1]) + # find the ones in idx that didn't factorize and remove them (TODO: this is not universal) + if isinstance(idx, SumNode): + unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1]) + assert len(unfactored) <= 1 + idx = Variable.sum(idx_nodes) + unfactored = (Variable.sum(unfactored) // base_shape[1]) + idy += unfactored + # ugh really...handtuned garbage + if idx.min >= (base_shape[1]*3)//4: + idx -= base_shape[1] + idy += 1 + else: + idx = (idxy//4)%base_shape[1] + if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) + return idx, idy + class LocalBuffer(NamedTuple): name: str size: int @@ -185,6 +205,7 @@ class Linearizer: localtype = dtypes.float key = f"{localtype}{idx.render()}{valid.render()}" if key not in cache: + if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid) cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \ self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const) ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key]) @@ -212,7 +233,9 @@ class Linearizer: store_offset = store_offset_new for idx, var in store_offset.items(): - self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx))) + idx, valid = self.sts[i].expr_idxs(idx) + if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid) + self.uop(UOps.STORE, None, [var], MemOp(i, idx, valid)) def linearize(self): # uops diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index dd01bc60ad..f742cb3363 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -97,9 +97,9 @@ class dtypes: int8: Final[DType] = DType(0, 1, "char", np.int8) int32: Final[DType] = DType(1, 4, "int", np.int32) int64: Final[DType] = DType(2, 8, "long", np.int64) - uint8: Final[DType] = DType(0, 1, "uchar", np.uint8) - uint32: Final[DType] = DType(1, 4, "uint", np.uint32) - uint64: Final[DType] = DType(2, 8, "ulong", np.uint64) + uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8) + uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32) + uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64) # NOTE: these are internal dtypes, should probably check for that _half4: Final[DType] = DType(0, 2*4, "half4", None, 4) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index f95753c71a..f37c2299e5 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -79,7 +79,7 @@ class CUDAProgram: class CUDACodegen(CStyleCodegen): lang = CStyleLanguage( - kernel_prefix = "typedef unsigned char uchar;\ntypedef unsigned int uint;\ntypedef unsigned long ulong;\n__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", + kernel_prefix = 
"__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4", gid = [f'blockIdx.{chr(120+i)}' for i in range(3)], lid = [f'threadIdx.{chr(120+i)}' for i in range(3)], half_prekernel = """ diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 3f488d40ac..af5138fe2d 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -87,7 +87,6 @@ class CLProgram: class CLCodegen(CStyleCodegen): lang = CStyleLanguage( kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ", - double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif", half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable", barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)", gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)