mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
good changes from the M1 Tensor Core project (#730)
* good changes
* working except llvm
* llvm types
* nice acc
* archprobe
* lang.float4
* use self.acc for late acc
* fix store bug
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
|
||||
import math, collections
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
|
||||
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -57,10 +57,6 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
||||
}
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
def group_float4(grp:List[str]) -> str:
|
||||
if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
|
||||
else: return f"{lang.float4}({','.join(g for g in grp)})"
|
||||
|
||||
prekernel: Set[str] = set()
|
||||
kernel = []
|
||||
global_size = []
|
||||
@@ -103,7 +99,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
else:
|
||||
kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
|
||||
depth += 1
|
||||
if uop == UOps.ENDLOOP:
|
||||
elif uop == UOps.ENDLOOP:
|
||||
if args[1] == "local" and len(lang.lid):
|
||||
# TODO: this is a bit of a hack. the local loop isn't real on the GPU
|
||||
kk(lang.barrier)
|
||||
@@ -116,18 +112,19 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
pend_close = None
|
||||
depth -= 1
|
||||
kk("}"*len(args[0]) + f" /* {args[1]} */")
|
||||
if uop == UOps.CONST:
|
||||
elif uop == UOps.CONST:
|
||||
assert newvar is not None
|
||||
if args == -math.inf:
|
||||
kk(f"float {newvar} = -INFINITY;")
|
||||
kk(f"{newvar.render(True)} = -INFINITY;")
|
||||
else:
|
||||
kk(f"float {newvar} = {args}f;")
|
||||
if uop == UOps.ALU:
|
||||
kk(f"{newvar.render(True)} = {args}f;")
|
||||
elif uop == UOps.ALU:
|
||||
assert newvar is not None
|
||||
if newvar in vin:
|
||||
kk(f"{newvar} = {code_for_op[args](*vin)};")
|
||||
kk(f"{newvar.render()} = {code_for_op[args](*[x.render() for x in vin])};")
|
||||
else:
|
||||
kk(f"float {newvar} = {code_for_op[args](*vin)};")
|
||||
# TODO: refactor the next 14 lines
|
||||
if uop == UOps.LOAD:
|
||||
kk(f"{newvar.render(True)} = {code_for_op[args](*[x.render() for x in vin])};")
|
||||
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float:
|
||||
# TODO: merge with CONST?
|
||||
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
|
||||
# nan? inf?
|
||||
@@ -138,9 +135,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
else:
|
||||
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args.valid.min == 1: kk(f"float {newvar} = {val};")
|
||||
else: kk(f"float {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
|
||||
if uop == UOps.LOAD4:
|
||||
if args.valid.min == 1: kk(f"float {newvar.name} = {val};")
|
||||
else: kk(f"float {newvar.name} = ({args.valid.render(render_cl)}) ? ({val}) : 0.0f;")
|
||||
elif uop == UOps.LOAD and newvar is not None and newvar.ltype == LocalTypes.float4:
|
||||
assert newvar.offset is None, "load can't have an offset"
|
||||
if isinstance(bufs[args.i].dtype, ImageDType):
|
||||
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
|
||||
@@ -148,23 +146,27 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
else:
|
||||
val = f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args[2].min == 1: kk(f"float4 {newvar} = {val};")
|
||||
else: kk(f"float4 {newvar} = ({args.valid.render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};")
|
||||
if uop == UOps.STORE:
|
||||
if args[2].min == 1: kk(f"{newvar.render(True)} = {val};")
|
||||
else: kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? ({val}) : {lang.float4}(0.0f, 0.0f, 0.0f, 0.0f);")
|
||||
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
|
||||
kk(f"vstore_half({vin[0]}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
|
||||
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
|
||||
else:
|
||||
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0]};")
|
||||
if uop == UOps.STORE4:
|
||||
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
|
||||
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
|
||||
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
|
||||
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if isinstance(bufs[args[0]].dtype, ImageDType):
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
|
||||
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(vin)});")
|
||||
kk(f"write_imagef({bufnames[args.i]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {vin[0].render()});")
|
||||
else:
|
||||
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {group_float4(vin)};")
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
kk(f"(({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}float4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}] = {vin[0].render()};")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
|
||||
else:
|
||||
raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
|
||||
("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
|
||||
|
||||
Reference in New Issue
Block a user