delete ltypes (#984)

* delete ltypes

* only upcast float types

* test dtype on mac passes

* ugh, these upcasts
This commit is contained in:
George Hotz
2023-06-15 16:24:45 -07:00
committed by GitHub
parent 804c45b5fc
commit 039f0d372f
5 changed files with 49 additions and 44 deletions

View File

@@ -1,6 +1,6 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
@@ -105,7 +105,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
assert newvar is not None
if args == -math.inf:
kk(f"{newvar.render(True)} = -INFINITY;")
elif newvar.ltype == LocalTypes.float4:
elif newvar.dtype == dtypes._float4:
kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
else:
kk(f"{newvar.render(True)} = {args}f;")
@@ -118,42 +118,42 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
elif uop == UOps.LOAD and newvar is not None:
# TODO: merge with CONST?
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
assert newvar.ltype == LocalTypes.float, "const can't be float4"
assert newvar.dtype == dtypes.float, "const can't be float4"
x = bufs[args.i].realized._buf
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
elif isinstance(bufs[args.i].dtype, ImageDType):
assert newvar.ltype == LocalTypes.float4, "image must be float4"
assert newvar.dtype == dtypes._float4, "image must be float4"
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
else:
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
if newvar.ltype == LocalTypes.float4:
if newvar.dtype == dtypes._float4:
val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
else:
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
else:
if newvar.ltype == LocalTypes.float4:
val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
if newvar.dtype == dtypes._float4:
val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
else:
casts = {LocalTypes.float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), LocalTypes.half: ("(half)", "(half)(0.0f)"), LocalTypes.float: ("(float)", "0.0f")}[newvar.ltype]
casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
assert args.valid.min == 1, "store must be valid"
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
else:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
assert args.valid.min == 1, "store must be valid"
if isinstance(bufs[args[0]].dtype, ImageDType):
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
@@ -172,7 +172,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
if lang.half_prekernel: prg =''.join([f"{lang.half_prekernel}", "\n", prg])
if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
return prg, global_size, local_size