factor uops transformers into functions

George Hotz
2023-03-20 08:19:48 -07:00
parent 623fb1ef28
commit 9b314c6342
2 changed files with 214 additions and 200 deletions


@@ -1,10 +1,11 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Any
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, all_same, partition, ImageDType, DEBUG, dtypes
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
# div is different in CL than in Python
render_cl = render_python.copy()
@@ -46,6 +47,132 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=F
if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
return idx, idy
code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})",
  UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})",
  BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
  BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
  BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
}
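# Aside: a worked example of the table above (operand strings illustrative).
# Each entry maps already-rendered operand strings to one C expression string,
# so nested ALU uops compose by plain string splicing:
#   code_for_op[BinaryOps.MUL]("val0", code_for_op[BinaryOps.ADD]("val1", "2.0f"))
#   -> "(val0*(val1+2.0f))"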
def uops_to_cstyle(uops:List[Tuple[UOps, Optional[str], Any]], bufs:List[LazyBuffer], bufnames:List[str], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
  def group_float4(grp:List[str]) -> str:
    if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
    else: return f"{lang.float4}({','.join(g for g in grp)})"
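  # Aside: two concrete cases of the regrouping above (names illustrative,
  # assuming lang.float4 == "float4" as in OpenCL):
  #   group_float4(["acc0.x","acc0.y","acc0.z","acc0.w"]) -> "acc0"
  #   group_float4(["val1.x","0.0f","0.0f","0.0f"])       -> "float4(val1.x,0.0f,0.0f,0.0f)"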
  prekernel: Set[str] = set()
  kernel = []
  global_size = []
  local_size = []
  pend_close = None
  depth = 0
  def kk(s): kernel.append(" "*depth+s)
  for uop,newvar,args in uops:
    if uop == UOps.LOOP:
      root = None
      for i,var in enumerate(args[0]):
        if isinstance(var, NumNode):
          if args[1] == "global" and lang.gid: global_size.append(1)
          if args[1] == "local" and lang.lid: local_size.append(1)
          # one number, not an index
          kk("{")
        else:
          if args[1] == "global" and lang.gid:
            if len(args[0]) >= 4 and len(args[0])-i > 2:
              # sometimes there are more dimensions than the device exposes. compact the extra leading dimensions into the last CL dimension
              # TODO: these compactions should be searchable (they sort of are with reshapes and permutes)
              if i == 0:
                kk(f"{{ int {var.expr} = {lang.gid[-1]}; /* {var.max+1} */")
                root = var.expr
                global_size.append(var.max+1)
              else:
                kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};")
                global_size[-1] *= var.max+1
            else:
              kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
              global_size.append(var.max+1)
          elif args[1] == "local" and lang.lid:
            assert len(args[0]) <= len(lang.lid)
            kk(f"{{ int {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
            local_size.append(var.max+1)
          else:
            kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
      depth += 1
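    # Aside: a worked example of the compaction above (gid and variable names
    # illustrative). With four global loop vars of extents [2,3,4,5] and three
    # device gids, the emitted C is:
    #   { int idx0 = get_global_id(2); /* 2 */
    #   { int idx1 = idx0 % 3; idx0 /= 3;
    #   { int idx2 = get_global_id(1); /* 4 */
    #   { int idx3 = get_global_id(0); /* 5 */
    # so global_size ends up [6, 4, 5]; after the divide, idx0 is back in [0, 2).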
    if uop == UOps.ENDLOOP:
      if args[1] == "local" and len(lang.lid):
        # TODO: this is a bit of a hack. the local loop isn't real on the GPU
        kk(lang.barrier)
        kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
        pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
      else:
        if args[1] == "global" and pend_close:
          depth -= 1
          kk(pend_close)
          pend_close = None
        depth -= 1
        kk("}"*len(args[0]) + f" /* {args[1]} */")
    if uop == UOps.CONST:
      if args[0] == -math.inf:
        kk(f"float {newvar} = -INFINITY;")
      else:
        kk(f"float {newvar} = {args[0]}f;")
    if uop == UOps.ALU:
      if newvar is None:
        kk(f"{args[2]} = {code_for_op[args[0]](*args[1])};")
      else:
        kk(f"float {newvar} = {code_for_op[args[0]](*args[1])};")
    # TODO: refactor the next 14 lines
    if uop == UOps.LOAD:
      # TODO: merge with CONST?
      if bufs[args[0]] is not None and isinstance(bufs[args[0]].realized, RawConst):
        # nan? inf?
        val = f"{bufs[args[0]].realized._buf}f"
      else:
        if lang.uses_vload and bufs[args[0]] is not None and bufs[args[0]].dtype == dtypes.float16:
          val = f"vload_half({args[1].render(render_cl)}, {bufnames[args[0]]})"
        else:
          val = f"{bufnames[args[0]]}[{args[1].render(render_cl)}]"
      # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
      if args[2].min == 1: kk(f"float {newvar} = {val};")
      else: kk(f"float {newvar} = ({args[2].render(render_cl)}) ? ({val}) : 0.0f;")
    if uop == UOps.LOAD4:
      if bufs[args[0]] is not None and isinstance(bufs[args[0]].dtype, ImageDType):
        prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
        idx, idy = to_image_idx(bufs[args[0]].dtype.shape, args[1], args[2])
        val = f"read_imagef({bufnames[args[0]]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
      else:
        val = f"(({lang.buffer_prefix if bufs[args[0]] is not None else lang.smem_prefix}float4*){bufnames[args[0]]})[{(args[1]//4).render(render_cl)}]"
      # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
      if args[2].min == 1: kk(f"float4 {newvar} = {val};")
      else: kk(f"float4 {newvar} = ({args[2].render(render_cl)}) ? ({val}) : {group_float4(['0.0f']*4)};")
    if uop == UOps.STORE:
      assert args[2].min == 1, "store must be valid"
      if lang.uses_vload and bufs[args[0]] is not None and bufs[args[0]].dtype == dtypes.float16:
        kk(f"vstore_half({args[3]}, {args[1].render(render_cl)}, {bufnames[args[0]]});")
      else:
        kk(f"{bufnames[args[0]]}[{args[1].render(render_cl)}] = {args[3]};")
    if uop == UOps.STORE4:
      assert args[2].min == 1, "store must be valid"
      if bufs[args[0]] is not None and isinstance(bufs[args[0]].dtype, ImageDType):
        idx, idy = to_image_idx(bufs[args[0]].dtype.shape, args[1], args[2])
        kk(f"write_imagef({bufnames[args[0]]}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {group_float4(args[3])});")
      else:
        kk(f"(({lang.buffer_prefix if bufs[args[0]] is not None else lang.smem_prefix}float4*){bufnames[args[0]]})[{(args[1]//4).render(render_cl)}] = {group_float4(args[3])};")
    if uop == UOps.DEFINE_LOCAL:
      kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
  buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
               ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
              if x is not None and not isinstance(x.realized, RawConst)]
  prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
    [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
    [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
  return prg, global_size, local_size
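# Aside: the point of the refactor is that emission is now a pure function of
# the uop list; a backend supplies only a CStyleLanguage description and calls
#   prg, global_size, local_size = uops_to_cstyle(uops, bufs, bufnames, lang)
# with no Linearizer state involved (see codegen() below, which now just
# linearizes and makes this one call).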
class CStyleCodegen(Linearizer):
  lang: ClassVar[CStyleLanguage] = CStyleLanguage()
  supports_constant_folding: bool = True
@@ -55,19 +182,6 @@ class CStyleCodegen(Linearizer):
  kernel_cnt: Final[DefaultDict[str, int]] = collections.defaultdict(int)
  kernel_name_cache: Final[Dict[str, str]] = {}
  code_for_op: Final[Dict[Op, Callable]] = {
    UnaryOps.EXP: lambda x: f"native_exp({x})" if NATIVE_EXPLOG else f"exp({x})",
    UnaryOps.LOG: lambda x: f"native_log({x})" if NATIVE_EXPLOG else f"log({x})",
    BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
    BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
    BinaryOps.POW: lambda a,b: f"pow({a},{b})", BinaryOps.MAX: lambda a,b: f"max({a},{b})",
    BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({b}*{c})+{a})"
  }
  def group_float4(self, grp:List[str]) -> str:
    if all(g.endswith(e) for g,e in zip(grp, [".x", ".y", ".z", ".w"])) and all_same([g.split(".")[0] for g in grp]): return grp[0].split(".")[0]
    else: return f"{self.lang.float4}({','.join(g for g in grp)})"
  def codegen(self):
    self.process()
@@ -86,113 +200,7 @@ class CStyleCodegen(Linearizer):
    self.hand_coded_optimizations()
    self.linearize()
    prekernel: Set[str] = set()
    kernel = []
    global_size = []
    local_size = []
    pend_close = None
    depth = 0
    def kk(s): kernel.append(" "*depth+s)
    for uop,newvar,args in self.uops:
      if uop == UOps.LOOP:
        root = None
        for i,var in enumerate(args[0]):
          if isinstance(var, NumNode):
            if args[1] == "global" and self.lang.gid: global_size.append(1)
            if args[1] == "local" and self.lang.lid: local_size.append(1)
            # one number, not an index
            kk("{")
          else:
            if args[1] == "global" and self.lang.gid:
              if len(args[0]) >= 4 and len(args[0])-i > 2:
                # sometimes there are more dimensions than the device exposes. compact the extra leading dimensions into the last CL dimension
                # TODO: these compactions should be searchable (they sort of are with reshapes and permutes)
                if i == 0:
                  kk(f"{{ int {var.expr} = {self.lang.gid[-1]}; /* {var.max+1} */")
                  root = var.expr
                  global_size.append(var.max+1)
                else:
                  kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};")
                  global_size[-1] *= var.max+1
              else:
                kk(f"{{ int {var.expr} = {self.lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
                global_size.append(var.max+1)
            elif args[1] == "local" and self.lang.lid:
              assert len(args[0]) <= len(self.lang.lid)
              kk(f"{{ int {var.expr} = {self.lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
              local_size.append(var.max+1)
            else:
              kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
        depth += 1
      if uop == UOps.ENDLOOP:
        if args[1] == "local" and len(self.lang.lid):
          # TODO: this is a bit of a hack. the local loop isn't real on the GPU
          kk(self.lang.barrier)
          kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
          pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
        else:
          if args[1] == "global" and pend_close:
            depth -= 1
            kk(pend_close)
            pend_close = None
          depth -= 1
          kk("}"*len(args[0]) + f" /* {args[1]} */")
      if uop == UOps.CONST:
        if args[0] == -math.inf:
          kk(f"float {newvar} = -INFINITY;")
        else:
          kk(f"float {newvar} = {args[0]}f;")
      if uop == UOps.ALU:
        if newvar is None:
          kk(f"{args[2]} = {self.code_for_op[args[0]](*args[1])};")
        else:
          kk(f"float {newvar} = {self.code_for_op[args[0]](*args[1])};")
      # TODO: refactor the next 14 lines
      if uop == UOps.LOAD:
        # TODO: merge with CONST?
        if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].realized, RawConst):
          # nan? inf?
          val = f"{self.bufs[args[0]].realized._buf}f"
        else:
          if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16:
            val = f"vload_half({args[1].render(render_cl)}, {self.registers[args[0]].name})"
          else:
            val = f"{self.registers[args[0]].name}[{args[1].render(render_cl)}]"
        # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
        if args[2].min == 1: kk(f"float {newvar} = {val};")
        else: kk(f"float {newvar} = ({args[2].render(render_cl)}) ? ({val}) : 0.0f;")
      if uop == UOps.LOAD4:
        if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType):
          prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
          idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2])
          val = f"read_imagef({self.registers[args[0]].name}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
        else:
          val = f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}]"
        # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
        if args[2].min == 1: kk(f"float4 {newvar} = {val};")
        else: kk(f"float4 {newvar} = ({args[2].render(render_cl)}) ? ({val}) : {self.group_float4(['0.0f']*4)};")
      if uop == UOps.STORE:
        assert args[2].min == 1, "store must be valid"
        if self.lang.uses_vload and self.bufs[args[0]] is not None and self.bufs[args[0]].dtype == dtypes.float16:
          kk(f"vstore_half({args[3]}, {args[1].render(render_cl)}, {self.registers[args[0]].name});")
        else:
          kk(f"{self.registers[args[0]].name}[{args[1].render(render_cl)}] = {args[3]};")
      if uop == UOps.STORE4:
        assert args[2].min == 1, "store must be valid"
        if self.bufs[args[0]] is not None and isinstance(self.bufs[args[0]].dtype, ImageDType):
          idx, idy = to_image_idx(self.bufs[args[0]].dtype.shape, args[1], args[2])
          kk(f"write_imagef({self.registers[args[0]].name}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {self.group_float4(args[3])});")
        else:
          kk(f"(({self.lang.buffer_prefix if self.bufs[args[0]] is not None else self.lang.smem_prefix}float4*){self.registers[args[0]].name})[{(args[1]//4).render(render_cl)}] = {self.group_float4(args[3])};")
      if uop == UOps.DEFINE_LOCAL:
        kk(self.lang.smem_prefix + f"float {args[0]}[{args[1]}];")
    buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else self.lang.buffer_prefix+x.dtype.name+"*"+self.lang.buffer_suffix) for i,x in enumerate(self.bufs) if x is not None and not isinstance(x.realized, RawConst)]
    prg = ''.join([f"{self.lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
      [', '.join([f'{"const" if i > 0 else ""} {t} data{i}' for i,t in buftypes] + self.lang.extra_args)] +
      [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
    prg, global_size, local_size = uops_to_cstyle(self.uops, self.bufs, [x.name for x in self.registers], self.lang)
    # if we have local_sizes, we have to correct the global_size
    for i,s in enumerate(local_size): global_size[i] *= s
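
One note on the last two lines: OpenCL's clEnqueueNDRangeKernel takes global work sizes measured in work-items, while the global loops above produce workgroup counts, so each dimension has to be scaled by its local extent. A minimal standalone sketch (values illustrative):

local_size = [8, 4]
global_size = [16, 16]  # workgroup counts from the global loops
for i,s in enumerate(local_size): global_size[i] *= s
print(global_size)  # [128, 64] work-items per dimension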