good changes from the M1 Tensor Core project (#730)

* good changes

* working except llvm

* llvm types

* nice acc

* archprobe

* lang.float4

* use self.acc for late acc

* fix store bug
This commit is contained in:
George Hotz
2023-03-29 05:11:02 +04:00
committed by GitHub
parent 156640e90d
commit 20894991ed
8 changed files with 231 additions and 83 deletions

View File

@@ -1,8 +1,8 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -57,10 +57,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
}
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
def group_float4(grp:List[str]) -> str:
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
else: return f"{lang.float4}({','.join(g for g in grp)})"
prekernel: Set[str] = set()
kernel = []
global_size = []
@@ -103,7 +99,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
depth += 1
if uop == UOps.ENDLOOP:
elif uop == UOps.ENDLOOP:
if args[1] == "local" and len(lang.lid):
# TODO: this is a bit of a hack. the local loop isn't real on the GPU
kk(lang.barrier)
@@ -116,18 +112,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
pend_close = None
depth -= 1
kk("}"*len(args[0]) + f" /* {args[1]} */")
if uop == UOps.CONST:
elif uop == UOps.CONST:
assert newvar is not None
if args == -math.inf:
kk(f"float {newvar} = -INFINITY;")
kk(f"{newvar.render(True)} = -INFINITY;")
else:
kk(f"float {newvar} = {args}f;")
if uop == UOps.ALU:
kk(f"{newvar.render(True)} = {args}f;")
elif uop == UOps.ALU:
assert newvar is not None
if newvar in vin:
kk(f"{newvar} = {code_for_op[args](*vin)};")
kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
else:
kk(f"float {newvar} = {code_for_op[args](*vin)};")
# TODO: refactor the next 14 lines
if uop == UOps.LOAD:
kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float:
# TODO: merge with CONST?
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
# nan? inf?
@@ -138,9 +135,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args.valid.min == 1: kk(f"float {newvar} = {val};")
else: kk(f"float {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
if uop == UOps.LOAD4:
if args.valid.min == 1: kk(f"float {newvar.name} = {val};")
else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4:
assert newvar.offset is None, "load can't have an offset"
if isinstance(bufs[args.i].dtype, ImageDType):
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
@@ -148,23 +146,27 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
else:
val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args[2].min == 1: kk(f"float4 {newvar} = {val};")
else: kk(f"float4 {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};")
if uop == UOps.STORE:
if args[2].min == 1: kk(f"{newvar.render(True)} = {val};")
else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);")
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
assert args.valid.min == 1, "store must be valid"
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
kk(f"vstore_half({vin[0]}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
else:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0]};")
if uop == UOps.STORE4:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
assert args.valid.min == 1, "store must be valid"
if isinstance(bufs[args[0]].dtype, ImageDType):
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(vin)});")
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
else:
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {group_float4(vin)};")
if uop == UOps.DEFINE_LOCAL:
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.DEFINE_LOCAL:
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
else:
raise RuntimeError(f"failed to render {uop}")
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)