mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
delete ltypes (#984)
* delete ltypes
* only upcast float types
* test dtype on mac passes
* ugh, these upcasts
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
|
||||
import math, collections
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
|
||||
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
@@ -105,7 +105,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
assert newvar is not None
|
||||
if args == -math.inf:
|
||||
kk(f"{newvar.render(True)} = -INFINITY;")
|
||||
elif newvar.ltype == LocalTypes.float4:
|
||||
elif newvar.dtype == dtypes._float4:
|
||||
kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
|
||||
else:
|
||||
kk(f"{newvar.render(True)} = {args}f;")
|
||||
@@ -118,42 +118,42 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
elif uop == UOps.LOAD and newvar is not None:
|
||||
# TODO: merge with CONST?
|
||||
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
|
||||
assert newvar.ltype == LocalTypes.float, "const can't be float4"
|
||||
assert newvar.dtype == dtypes.float, "const can't be float4"
|
||||
x = bufs[args.i].realized._buf
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
|
||||
elif isinstance(bufs[args.i].dtype, ImageDType):
|
||||
assert newvar.ltype == LocalTypes.float4, "image must be float4"
|
||||
assert newvar.dtype == dtypes._float4, "image must be float4"
|
||||
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
|
||||
val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
|
||||
else:
|
||||
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
|
||||
if newvar.ltype == LocalTypes.float4:
|
||||
if newvar.dtype == dtypes._float4:
|
||||
val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
|
||||
else:
|
||||
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
|
||||
else:
|
||||
if newvar.ltype == LocalTypes.float4:
|
||||
val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
|
||||
if newvar.dtype == dtypes._float4:
|
||||
val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
|
||||
else:
|
||||
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
|
||||
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
|
||||
if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
|
||||
else:
|
||||
casts = {LocalTypes.float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), LocalTypes.half: ("(half)", "(half)(0.0f)"), LocalTypes.float: ("(float)", "0.0f")}[newvar.ltype]
|
||||
casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
|
||||
kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
|
||||
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
|
||||
elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
|
||||
assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
|
||||
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
|
||||
else:
|
||||
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
|
||||
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
|
||||
elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
|
||||
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
|
||||
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
|
||||
elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
|
||||
assert args.valid.min == 1, "store must be valid"
|
||||
if isinstance(bufs[args[0]].dtype, ImageDType):
|
||||
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
|
||||
@@ -172,7 +172,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
|
||||
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
|
||||
|
||||
|
||||
if lang.half_prekernel: prg =''.join([f"{lang.half_prekernel}", "\n", prg])
|
||||
if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
|
||||
return prg, global_size, local_size
|
||||
|
||||
Reference in New Issue
Block a user