mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
Refactor load/store before tensor cores (#1193)
* minor cleanups
* render_const
* now that's a nice refactor
* clean up vload/vstore
* clean up render_load
* debugs there
* dumb
* err, this?
* const float4
* what's failing
* bugfix
* statement includes semicolon
* bugfix
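Below is a minimal standalone sketch (hypothetical names, not the tinygrad classes themselves) of the shape of this refactor: backend-specific load/store strings move out of inline f-strings in the code generator and into methods on the language description, so every backend shares one rendering path.

# Sketch only: mimics the idea of CStyleLanguage.render_load, not the real API.
from typing import NamedTuple

class Lang(NamedTuple):
  buffer_prefix: str = ""

  # one place decides how a load is spelled, parameterized per backend
  def render_load(self, buf_name: str, idx: str, vec: int = 1) -> str:
    if vec == 1: return f"{buf_name}[{idx}]"
    # vectorized load via pointer cast, as the OpenCL/CUDA backends do
    return f"(float{vec})(*(({self.buffer_prefix}float{vec}*)({buf_name}+{idx})))"

opencl = Lang(buffer_prefix="__global ")
print(opencl.render_load("data0", "gidx0"))           # -> data0[gidx0]
print(opencl.render_load("data0", "gidx0*4", vec=4))  # -> float4 pointer-cast load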
.github/workflows/test.yml (vendored, 8 changed lines)
@@ -185,14 +185,14 @@ jobs:
         python-version: 3.8
     - name: Install Dependencies
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
-    - name: Test openpilot model compile and size
-      run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
-        python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - name: Test GPU IMAGE ops
       run: |
         GPU=1 IMAGE=1 python3 test/test_ops.py
         FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py
+    - name: Test openpilot model compile and size
+      run: |
+        ALLOWED_KERNEL_COUNT=199 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
+        python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - name: Test openpilot model correctness (float32)
       run: DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
@@ -21,7 +21,7 @@ repos:
       pass_filenames: false
   - id: tests
     name: subset of (CPU) tests
-    entry: env CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py
+    entry: env CPU=1 pytest test/unit/ test/test_ops.py test/test_dtype.py
    language: system
     always_run: true
     pass_filenames: false
@@ -36,7 +36,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
     except Exception:
       raise Exception(f"{s} failed shape {x.shape}")

-  if DEBUG >= 4:
+  if DEBUG >= 6:
     np.set_printoptions(linewidth=200, suppress=True)
     print(ret.numpy())
     print(out.detach().numpy())
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 from tinygrad.helpers import prod, all_same
 from tinygrad.shape.shapetracker import ShapeTracker, View, merge_views, get_contraction
-from tinygrad.codegen.cstyle import to_image_idx
+from tinygrad.codegen.linearizer import to_image_idx

 def shapetracker_getitem(st, val):
   locals = {"idx": val, "valid": 1}
@@ -2,9 +2,9 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
 import math, collections
 from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
+from tinygrad.helpers import ImageDType, dtypes, colored
 from tinygrad.runtime.lib import RawConst
-from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
+from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
 from tinygrad.lazy import LazyBuffer

 # div is different in cl than python
@@ -23,28 +23,38 @@ class CStyleLanguage(NamedTuple):
   extra_args: List[str] = []
   float4: Optional[str] = None
   half_prekernel: Optional[str] = None
-  double_prekernel: Optional[str] = None
+  uses_vload: bool = False

-def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
-  idy = (idxy//(4*base_shape[1]))
-  if validhacks and valid.min == 0:
-    idx = (idxy//4) + (idy*-base_shape[1])
-    # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
-    if isinstance(idx, SumNode):
-      unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
-      assert len(unfactored) <= 1
-      idx = Variable.sum(idx_nodes)
-      unfactored = (Variable.sum(unfactored) // base_shape[1])
-      idy += unfactored
-      # ugh really...handtuned garbage
-      if idx.min >= (base_shape[1]*3)//4:
-        idx -= base_shape[1]
-        idy += 1
-  else:
-    idx = (idxy//4)%base_shape[1]
-  if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
-  return idx, idy
+  # returns a str expression of the const with the given type
+  def render_const(self, x:Union[float,int], var_dtype) -> str:
+    if math.isnan(x): val = "NAN"
+    elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
+    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    return f"{self.float4}({val}, {val}, {val}, {val})" if var_dtype == dtypes._float4 else val
+
+  # returns a str expression of the loaded value with the output type
+  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
+    if isinstance(buf_dtype, ImageDType):
+      assert output_dtype == dtypes._float4, "images must be float4"
+      return f"read_imagef({buf_name}, smp, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}))"
+    elif self.uses_vload and buf_dtype == dtypes.float16:
+      return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx.render(render_cl)})"
+    elif output_dtype == dtypes._float4:
+      return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl)})))"
+    else:
+      return f"{buf_name}[{idx.render(render_cl)}]"
+
+  # returns a str statement that does the store
+  def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
+    if isinstance(buf_dtype, ImageDType):
+      assert var_dtype == dtypes._float4, "images must be float4"
+      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
+    elif self.uses_vload and buf_dtype == dtypes.float16:
+      return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx.render(render_cl)});"
+    elif var_dtype.sz > 1:
+      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
+    else:
+      return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"

 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.EXP2: lambda x: f"exp2({x})",
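To make the new interface concrete, here is a hedged sketch (not the real signature) of the strings this kind of render_store produces for an OpenCL-like configuration; idx is a pre-rendered string here, whereas tinygrad passes symbolic Nodes and renders them with render_cl.

def render_store(buf_name, buf_dtype, var_name, var_sz, idx,
                 uses_vload=True, buffer_prefix="__global "):
  if buf_dtype == "half" and uses_vload:
    # vstore_half/vstore_half4 let half buffers work without half arithmetic support
    return f"vstore_half{'' if var_sz == 1 else var_sz}({var_name}, 0, {buf_name}+{idx});"
  if var_sz > 1:
    # vectorized store via pointer cast
    return f"*(({buffer_prefix}{buf_dtype}{var_sz}*)({buf_name}+{idx})) = ({buf_dtype}{var_sz}){var_name};"
  return f"{buf_name}[{idx}] = {var_name};"

print(render_store("data0", "float", "acc0", 4, "gidx0*4"))
# -> *((__global float4*)(data0+gidx0*4)) = (float4)acc0;
print(render_store("data0", "half", "acc0", 4, "gidx0*4"))
# -> vstore_half4(acc0, 0, data0+gidx0*4);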
@@ -105,70 +115,33 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
       kk("}"*len(args[0]) + f" /* {args[1]} */")
     elif uop == UOps.CONST:
       assert newvar is not None
-      if args == -math.inf:
-        kk(f"{newvar.render(True)} = -INFINITY;")
-      elif newvar.dtype == dtypes._float4:
-        kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
-      else:
-        kk(f"{newvar.render(True)} = {args}f;")
+      kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
     elif uop == UOps.ALU:
       assert newvar is not None
-      if newvar in vin:
-        kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
-      else:
-        kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
-    elif uop == UOps.LOAD and newvar is not None:
-      # TODO: merge with CONST?
-      if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
-        assert newvar.dtype == dtypes.float, "const can't be float4"
-        x = bufs[args.i].realized._buf
-        if math.isnan(x): val = "NAN"
-        elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-        else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
-      elif isinstance(bufs[args.i].dtype, ImageDType):
-        assert newvar.dtype == dtypes._float4, f"image must be float4 {newvar}"
-        prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
-        idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
-        val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
-      else:
-        if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-          if newvar.dtype == dtypes._float4:
-            val = f"vload_half4(0, {bufnames[args.i]}+{(args.idx).render(render_cl)})"
-          else:
-            val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
-        else:
-          if newvar.dtype == dtypes._float4:
-            val = f"({newvar.dtype.name})(*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})))"
-          else:
-            val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
-      # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
-      if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
-      else:
-        casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
-        kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
-    elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
-      assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
-      assert args.valid.min == 1, "store must be valid"
-      if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-        kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
-      else:
-        kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
+      kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
+    elif uop == UOps.LOAD:
+      assert newvar is not None
+      # valids are handled here
+      if args.valid.max == 0:
+        val = lang.render_const(0.0, newvar.dtype)
+      elif isinstance(bufs[args.i].realized, RawConst):
+        val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
+      else:
+        val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
+      if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
+      kk(f"{newvar.render(True)} = {val};")
+    elif uop == UOps.STORE:
+      assert args.valid.min == 1, "store must be valid"
+      # TODO: instead of dtypes.float, a base type
+      kk(lang.render_store(bufnames[args.i], bufs[args.i].dtype, vin[0].render(), vin[0].dtype if vin[0].offset is None else dtypes.float, args.idx, isinstance(bufs[args.i], LocalBuffer)))
     elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
       kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
-    elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
-      assert args.valid.min == 1, "store must be valid"
-      if isinstance(bufs[args[0]].dtype, ImageDType):
-        idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
-        kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
-      elif lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-        kk(f"vstore_half4({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
-      else:
-        kk(f"*(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*)({bufnames[args.i]}+{args.idx.render(render_cl)})) = ({bufs[args.i].dtype.name}4){vin[0].render()};")
     elif uop == UOps.DEFINE_LOCAL:
       kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
     else:
       raise RuntimeError(f"failed to render {uop}")

+  if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
   buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
     ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
     if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
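For illustration, a small sketch (assumed names, mirroring the valid handling in the unified LOAD path above): one load expression, guarded by a ternary only when validity is genuinely runtime-dependent.

def render_guarded_load(val, valid_expr, valid_min, valid_max, zero="0.0f"):
  if valid_max == 0: return zero           # statically invalid: constant zero
  if valid_min == 0 and valid_max == 1:    # data-dependent: emit a guard
    return f"({valid_expr}) ? ({val}) : {zero}"
  return val                               # statically valid: plain load

print(render_guarded_load("data1[gidx0]", "gidx0 < 9", 0, 1))
# -> (gidx0 < 9) ? (data1[gidx0]) : 0.0f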
@@ -3,18 +3,38 @@ import itertools, math
 from collections import defaultdict
 from enum import Enum, auto

-from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
+from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same, partition
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.runtime.lib import RawConst
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
 VariableOrNum = Union[Variable, NumNode]

 class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702

+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
+  idy = (idxy//(4*base_shape[1]))
+  if validhacks and valid.min == 0:
+    idx = (idxy//4) + (idy*-base_shape[1])
+    # find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
+    if isinstance(idx, SumNode):
+      unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
+      assert len(unfactored) <= 1
+      idx = Variable.sum(idx_nodes)
+      unfactored = (Variable.sum(unfactored) // base_shape[1])
+      idy += unfactored
+      # ugh really...handtuned garbage
+      if idx.min >= (base_shape[1]*3)//4:
+        idx -= base_shape[1]
+        idy += 1
+  else:
+    idx = (idxy//4)%base_shape[1]
+  if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
+  return idx, idy
+
 class LocalBuffer(NamedTuple):
   name: str
   size: int
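A worked example of the common (non-validhacks) branch of to_image_idx, with concrete integers instead of symbolic Nodes: for an image of shape (H, W) holding float4 texels, a flat float index idxy maps to texel column (idxy//4) % W and row idxy // (4*W).

# Concrete-integer sketch of the non-validhacks path of to_image_idx.
from typing import Tuple

def to_image_idx_concrete(W: int, idxy: int) -> Tuple[int, int]:
  idx = (idxy // 4) % W    # texel column within the row
  idy = idxy // (4 * W)    # texel row: each row holds 4*W floats
  return idx, idy

W = 8
assert to_image_idx_concrete(W, 0) == (0, 0)    # first float -> texel (0, 0)
assert to_image_idx_concrete(W, 28) == (7, 0)   # last texel of row 0
assert to_image_idx_concrete(W, 32) == (0, 1)   # wraps to row 1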
@@ -185,6 +205,7 @@ class Linearizer:
           localtype = dtypes.float
         key = f"{localtype}{idx.render()}{valid.render()}"
         if key not in cache:
+          if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
           cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else \
             self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
         ret.append(Token(cache[key].name, cache[key].dtype, _idx[upcast_dim[0]].b) if localtype == dtypes._float4 else cache[key])
@@ -212,7 +233,9 @@ class Linearizer:
         store_offset = store_offset_new

       for idx, var in store_offset.items():
-        self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idx)))
+        idx, valid = self.sts[i].expr_idxs(idx)
+        if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
+        self.uop(UOps.STORE, None, [var], MemOp(i, idx, valid))

   def linearize(self):
     # uops
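A sketch of the new store flow with stand-in types (not tinygrad's actual MemOp or ShapeTracker): the Linearizer now resolves image coordinates itself, so the MemOp it emits can carry either a flat index or an (idx, idy) pair that the C-style backend renders directly.

from typing import NamedTuple, Optional, Tuple, Union

class MemOp(NamedTuple):
  i: int
  idx: Union[int, Tuple[int, int]]
  valid: int

def emit_store(i: int, flat_idx: int, valid: int, image_shape: Optional[Tuple[int, int]] = None) -> MemOp:
  # for image buffers, convert the flat index to texel coordinates up front
  if image_shape is not None:
    W = image_shape[1]
    return MemOp(i, ((flat_idx // 4) % W, flat_idx // (4 * W)), valid)
  return MemOp(i, flat_idx, valid)

print(emit_store(0, 36, 1, image_shape=(4, 8)))  # MemOp(i=0, idx=(1, 1), valid=1)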
@@ -97,9 +97,9 @@ class dtypes:
   int8: Final[DType] = DType(0, 1, "char", np.int8)
   int32: Final[DType] = DType(1, 4, "int", np.int32)
   int64: Final[DType] = DType(2, 8, "long", np.int64)
-  uint8: Final[DType] = DType(0, 1, "uchar", np.uint8)
-  uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
-  uint64: Final[DType] = DType(2, 8, "ulong", np.uint64)
+  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
+  uint32: Final[DType] = DType(1, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(2, 8, "unsigned long", np.uint64)

   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
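A note on this rename (my reading, not stated in the diff): uchar, uint, and ulong are OpenCL spellings that plain CUDA only accepted because of the typedefs in its kernel_prefix; "unsigned char", "unsigned int", and "unsigned long" are valid in both dialects, which is what allows the CUDA typedef block to be dropped in the CUDACodegen hunk below.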
@@ -79,7 +79,7 @@ class CUDAProgram:

 class CUDACodegen(CStyleCodegen):
   lang = CStyleLanguage(
-    kernel_prefix = "typedef unsigned char uchar;\ntypedef unsigned int uint;\ntypedef unsigned long ulong;\n__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
+    kernel_prefix = "__global__", smem_prefix = "__shared__ ", barrier = "__syncthreads();", float4 = "make_float4",
     gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
     lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
     half_prekernel = """
@@ -87,7 +87,6 @@ class CLProgram:
 class CLCodegen(CStyleCodegen):
   lang = CStyleLanguage(
     kernel_prefix = "__kernel", buffer_prefix = "__global ", smem_prefix = "__local ",
-    double_prekernel="#ifdef cl_khr_fp64\n#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n#elif defined(cl_amd_fp64)\n#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n#endif",
     half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
     barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
-    gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)])
+    gid = [f'get_group_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)