Refactor load/store before tensor cores (#1193)

* minor cleanups

* render_const

* now that's a nice refactor

* clean up vload/vstore

* clean up render_load

* debugs there

* dumb

* err, this?

* const float4

* what's failing

* bugfix

* statement includes semicolon

* bugfix
George Hotz authored on 2023-07-08 15:54:58 -07:00, committed by GitHub
parent ef1909500e
commit 7151382364
9 changed files with 85 additions and 90 deletions


@@ -185,14 +185,14 @@ jobs:
         python-version: 3.8
     - name: Install Dependencies
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
-    - name: Test openpilot model compile and size
-      run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
-        python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - name: Test GPU IMAGE ops
       run: |
         GPU=1 IMAGE=1 python3 test/test_ops.py
         FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
+    - name: Test openpilot model compile and size
+      run: |
+        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+        python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - name: Test openpilot model correctness (float32)
       run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py


@@ -21,7 +21,7 @@ repos:
       pass_filenames: false
     - id: tests
       name: subset of (CPU) tests
-      entry: env CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py
+      entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py
       language: system
       always_run: true
      pass_filenames: false


@@ -36,7 +36,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
     except Exception:
       raise Exception(f"{s} failed shape {x.shape}")
-    if DEBUG >= 4:
+    if DEBUG >= 6:
       np.set_printoptions(linewidth=200, suppress=True)
       print(ret.numpy())
       print(out.detach().numpy())


@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 from tinygrad.helpers import prod, all_same
 from tinygrad.shape.shapetracker import ShapeTracker, View, merge_views, get_contraction
-from tinygrad.codegen.cstyle import to_image_idx
+from tinygrad.codegen.linearizer import to_image_idx
 def shapetracker_getitem(st, val):
   locals = {"idx": val, "valid": 1}
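The test helper above evaluates a rendered index expression at a concrete flat offset. A minimal sketch of the idea, with an assumed expression string (the real test renders it from a ShapeTracker via render_python):

import unittest  # not needed for the sketch itself, just mirrors the test file

expr = "(idx//4)%16"                # assumed: a string produced by render_python
locals_ = {"idx": 37, "valid": 1}   # mirrors the locals dict in shapetracker_getitem
print(eval(expr, None, locals_))    # -> 9, i.e. (37//4) % 16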


@@ -2,9 +2,9 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
 import math, collections
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
+from tinygrad.helpers import ImageDType, dtypes, colored
 from tinygrad.runtime.lib import RawConst
-from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
+from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
 from tinygrad.lazy import LazyBuffer
 # div is different in cl than python
@@ -23,28 +23,38 @@ class CStyleLanguage(NamedTuple):
   extra_args: List[str] = []
   float4: Optional[str] = None
   half_prekernel: Optional[str] = None
-  double_prekernel: Optional[str] = None
   uses_vload: bool = False
-def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
-  idy = (idxy//(4*base_shape[1]))
-  if validhacks and valid.min == 0:
-    idx = (idxy//4) + (idy*-base_shape[1])
-    # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
-    if isinstance(idx, SumNode):
-      unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
-      assert len(unfactored) <= 1
-      idx = Variable.sum(idx_nodes)
-      unfactored = (Variable.sum(unfactored) // base_shape[1])
-      idy += unfactored
-      # ugh really...handtuned garbage
-      if idx.min >= (base_shape[1]*3)//4:
-        idx -= base_shape[1]
-        idy += 1
-  else:
-    idx = (idxy//4)%base_shape[1]
-  if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
-  return idx, idy
+  # returns a str expression of the const with the given type
+  def render_const(self, x:Union[float,int], var_dtype) -> str:
+    if math.isnan(x): val = "NAN"
+    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
+    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    return f"{self.float4}({val}, {val}, {val}, {val})" if var_dtype == dtypes._float4 else val
+  # returns a str expression of the loaded value with the output type
+  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
+    if isinstance(buf_dtype, ImageDType):
+      assert output_dtype == dtypes._float4, "images must be float4"
+      return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
+    elif self.uses_vload and buf_dtype == dtypes.float16:
+      return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl)})"
+    elif output_dtype == dtypes._float4:
+      return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl)})))"
+    else:
+      return f"{buf_name}[{idx.render(render_cl)}]"
+  # returns a str statement that does the store
+  def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
+    if isinstance(buf_dtype, ImageDType):
+      assert var_dtype == dtypes._float4, "images must be float4"
+      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
+    elif self.uses_vload and buf_dtype == dtypes.float16:
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl)});"
+    elif var_dtype.sz > 1:
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    else:
+      return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP2: lambda x: f"exp2({x})",
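For reference, a standalone sketch that mirrors what render_const above produces, with FLOAT4 standing in for the backend's float4 field (the CL backend sets it to "(float4)"); the flags and assertions are illustrative, not from the diff:

import math

FLOAT4 = "(float4)"  # assumed: matches CLCodegen's float4 setting

def render_const_sketch(x, is_int=False, is_float4=False):
  # mirrors render_const: NAN/INFINITY literals, "f" suffix for floats, splat for float4
  if math.isnan(x): val = "NAN"
  elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
  else: val = f"{x}" + ("" if is_int else "f")
  return f"{FLOAT4}({val}, {val}, {val}, {val})" if is_float4 else val

assert render_const_sketch(0.0) == "0.0f"
assert render_const_sketch(-math.inf) == "-INFINITY"
assert render_const_sketch(1.0, is_float4=True) == "(float4)(1.0f, 1.0f, 1.0f, 1.0f)"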
@@ -105,70 +115,33 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
       kk("}"*len(args[0]) + f" /* {args[1]} */")
     elif uop == UOps.CONST:
       assert newvar is not None
-      if args == -math.inf:
-        kk(f"{newvar.render(True)} = -INFINITY;")
-      elif newvar.dtype == dtypes._float4:
-        kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
-      else:
-        kk(f"{newvar.render(True)} = {args}f;")
+      kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
     elif uop == UOps.ALU:
       assert newvar is not None
-      if newvar in vin:
-        kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
+      kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
+    elif uop == UOps.LOAD:
+      assert newvar is not None
+      # valids are handled here
+      if args.valid.max == 0:
+        val = lang.render_const(0.0, newvar.dtype)
+      elif isinstance(bufs[args.i].realized, RawConst):
+        val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
       else:
-        kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
-    elif uop == UOps.LOAD and newvar is not None:
-      # TODO: merge with CONST?
-      if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
-        assert newvar.dtype == dtypes.float, "const can't be float4"
-        x = bufs[args.i].realized._buf
-        if math.isnan(x): val = "NAN"
-        elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-        else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
-      elif isinstance(bufs[args.i].dtype, ImageDType):
-        assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}"
-        prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
-        idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
-        val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
-      else:
-        if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-          if newvar.dtype == dtypes._float4:
-            val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})"
-          else:
-            val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
-        else:
-          if newvar.dtype == dtypes._float4:
-            val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))"
-          else:
-            val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
-      # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
-      if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
-      else:
-        casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
-        kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
-    elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
-      assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
+        val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
+      if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
+      kk(f"{newvar.render(True)} = {val};")
+    elif uop == UOps.STORE:
       assert args.valid.min == 1, "store must be valid"
-      if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-        kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
-      else:
-        kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
+      # TODO: instead of dtypes.float, a base type
+      kk(lang.render_store(bufnames[args.i], bufs[args.i].dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, isinstance(bufs[args.i], LocalBuffer)))
     elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
       kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
-    elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
-      assert args.valid.min == 1, "store must be valid"
-      if isinstance(bufs[args[0]].dtype, ImageDType):
-        idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
-        kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
-      elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-        kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
-      else:
-        kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};")
     elif uop == UOps.DEFINE_LOCAL:
       kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
     else:
       raise RuntimeError(f"failed to render {uop}")
+  if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
   buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
               ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
               if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
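The new LOAD branch folds the valid guard into the rendered expression: when the access can be invalid (valid.min == 0, valid.max == 1), the value is wrapped in a ternary on the rendered valid expression. A sketch of the emitted string, with illustrative names (val0, data1, gidx0) and the assumption that Token.render(True) prepends the declaration type, as the surrounding kk(...) lines suggest:

val   = "data1[gidx0]"              # what lang.render_load returned
valid = "(gidx0<10)"                # args.valid rendered with render_cl
zero  = "0.0f"                      # lang.render_const(0.0, dtypes.float)
val   = f"({valid}) ? ({val}) : {zero}"
print(f"float val0 = {val};")       # float val0 = ((gidx0<10)) ? (data1[gidx0]) : 0.0f;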


@@ -3,18 +3,38 @@ import itertools, math
 from collections import defaultdict
 from enum import Enum, auto
-from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
+from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
 VariableOrNum = Union[Variable, NumNode]
 class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
+  idy = (idxy//(4*base_shape[1]))
+  if validhacks and valid.min == 0:
+    idx = (idxy//4) + (idy*-base_shape[1])
+    # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
+    if isinstance(idx, SumNode):
+      unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
+      assert len(unfactored) <= 1
+      idx = Variable.sum(idx_nodes)
+      unfactored = (Variable.sum(unfactored) // base_shape[1])
+      idy += unfactored
+      # ugh really...handtuned garbage
+      if idx.min >= (base_shape[1]*3)//4:
+        idx -= base_shape[1]
+        idy += 1
+  else:
+    idx = (idxy//4)%base_shape[1]
+  if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
+  return idx, idy
 class LocalBuffer(NamedTuple):
   name: str
   size: int
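A worked example of the non-validhacks index math above, using plain ints where the real function takes symbolic nodes. With W = base_shape[1], the image width in float4 texels (assumed), a row holds 4*W floats, so flat float index idxy maps to texel coordinates ((idxy//4) % W, idxy // (4*W)):

W = 64                               # image width in float4 texels (assumed)
for idxy in (0, 4, 256, 260):
  idx, idy = (idxy // 4) % W, idxy // (4 * W)
  print(idxy, "->", (idx, idy))      # 0->(0,0), 4->(1,0), 256->(0,1), 260->(1,1)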
@@ -185,6 +205,7 @@ class Linearizer:
         localtype = dtypes.float
       key = f"{localtype}{idx.render()}{valid.render()}"
       if key not in cache:
+        if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
         cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \
                      self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
       ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key])
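The cache here deduplicates loads: two accesses that render to the same (dtype, idx, valid) key reuse one LOAD uop. A minimal sketch of the pattern, with a string standing in for the real Token returned by self.uop:

cache = {}
def load_uop(localtype, idx, valid):
  key = f"{localtype}{idx}{valid}"
  if key not in cache: cache[key] = f"val_{len(cache)}"  # stand-in for self.uop(UOps.LOAD, ...)
  return cache[key]

assert load_uop("float", "gidx0", "1") is load_uop("float", "gidx0", "1")  # one uop, reused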
@@ -212,7 +233,9 @@ class Linearizer:
       store_offset = store_offset_new
     for idx, var in store_offset.items():
-      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx)))
+      idx, valid = self.sts[i].expr_idxs(idx)
+      if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
+      self.uop(UOps.STORE, None, [var], MemOp(i, idx, valid))
   def linearize(self):
     # uops
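Both hunks above route image indices through to_image_idx before the MemOp is built, so for an ImageDType buffer MemOp.idx carries the (idx, idy) pair and the cstyle renderers consume it as idx[0]/idx[1]. A sketch of the resulting strings, reusing the pair from the worked example above (buffer names are assumed):

idx, idy = 1, 1  # e.g. to_image_idx output for idxy=260, W=64 (see the earlier sketch)
print(f"read_imagef(data1, smp, (int2)({idx}, {idy}))")    # load side, via render_load
print(f"write_imagef(data0, (int2)({idx}, {idy}), val0);") # store side, via render_store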


@@ -97,9 +97,9 @@ class dtypes:
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "long", np.int64)
-  uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
-  uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "ulong", np.uint64)
+  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
+  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
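Renaming the unsigned dtypes to the builtin C spellings lets every C-family backend declare buffer parameters without typedefs, which is what allows the CUDA kernel_prefix below to shrink to plain __global__. A sketch of the parameter declaration, mirroring the buftypes expression in the cstyle diff above:

def buf_decl(i, buffer_prefix, dtype_name):
  # output buffer (i == 0) is mutable; inputs are declared const
  return ("const " if i > 0 else "") + buffer_prefix + dtype_name + "*"

assert buf_decl(1, "__global ", "unsigned char") == "const __global unsigned char*"
assert buf_decl(0, "", "float") == "float*"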


@@ -79,7 +79,7 @@ class CUDAProgram:
 class CUDACodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "typedef unsigned char uchar;\ntypedef unsigned int uint;\ntypedef unsigned long ulong;\n__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
     gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
     lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
     half_prekernel = """


@@ -87,7 +87,6 @@ class CLProgram:
 class CLCodegen(CStyleCodegen):
   lang = CStyleLanguage(
     kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
-    double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif",
     half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
     barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
     gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
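With uses_vload=True, the CL backend reads and writes half buffers through vload_half/vstore_half rather than dereferencing half pointers, since dereferencing half directly requires cl_khr_fp16 while the vload/vstore builtins do not. A sketch mirroring the float16 branches of render_load/render_store, with assumed buffer and index names:

def render_half_load(buf, idx, sz):
  # sz == 1 emits vload_half, sz == 4 emits vload_half4, etc.
  return f"vload_half{'' if sz == 1 else str(sz)}(0, {buf}+{idx})"
def render_half_store(buf, idx, var, sz):
  return f"vstore_half{'' if sz == 1 else str(sz)}({var}, 0, {buf}+{idx});"

assert render_half_load("data1", "gidx0*4", 4) == "vload_half4(0, data1+gidx0*4)"
assert render_half_store("data0", "gidx0*4", "val0", 4) == "vstore_half4(val0, 0, data0+gidx0*4);"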