mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-18 10:31:41 -05:00
delete ltypes (#984)

* delete ltypes
* only upcast float types
* test dtype on mac passes
* ugh, these upcasts
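The substance of the change: the codegen-local LocalTypes enum is deleted, Token.ltype becomes Token.dtype, and vector register types like float4 become internal DTypes that carry an element count. A minimal sketch of the before/after (field layout copied from the DType hunk below):

    from typing import NamedTuple, Optional

    # Before: a parallel type system that existed only inside codegen.
    #   class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); ...
    #   Token("acc0", LocalTypes.float4)

    # After: the same DType NamedTuple used everywhere else, extended with
    # sz (element count) and an Optional numpy type (None for internal dtypes).
    class DType(NamedTuple):
        priority: int   # this determines when things get upcasted
        itemsize: int
        name: str
        np: Optional[type]
        sz: int = 1
        def __repr__(self): return f"dtypes.{self.name}"

    _float4 = DType(4, 4*4, "float4", None, 4)  # Token("acc0", _float4) replaces the enum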
@@ -1,6 +1,6 @@
 from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
 import math, collections
-from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
+from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
 from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
 from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
 from tinygrad.runtime.lib import RawConst
@@ -105,7 +105,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
 assert newvar is not None
 if args == -math.inf:
 kk(f"{newvar.render(True)} = -INFINITY;")
-elif newvar.ltype == LocalTypes.float4:
+elif newvar.dtype == dtypes._float4:
 kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
 else:
 kk(f"{newvar.render(True)} = {args}f;")
@@ -118,42 +118,42 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
 elif uop == UOps.LOAD and newvar is not None:
 # TODO: merge with CONST?
 if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
-assert newvar.ltype == LocalTypes.float, "const can't be float4"
+assert newvar.dtype == dtypes.float, "const can't be float4"
 x = bufs[args.i].realized._buf
 if math.isnan(x): val = "NAN"
 elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
 else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
 elif isinstance(bufs[args.i].dtype, ImageDType):
-assert newvar.ltype == LocalTypes.float4, "image must be float4"
+assert newvar.dtype == dtypes._float4, "image must be float4"
 prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
 idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
 val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
 else:
 if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
-if newvar.ltype == LocalTypes.float4:
+if newvar.dtype == dtypes._float4:
 val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
 else:
 val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
 else:
-if newvar.ltype == LocalTypes.float4:
-val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
+if newvar.dtype == dtypes._float4:
+val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
 else:
 val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
 # NOTE: if min and max are both 0, it should be a CONST in the Linearizer
 if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
 else:
-casts = {LocalTypes.float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), LocalTypes.half: ("(half)", "(half)(0.0f)"), LocalTypes.float: ("(float)", "0.0f")}[newvar.ltype]
+casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
 kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
-elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
+elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
 assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
 assert args.valid.min == 1, "store must be valid"
 if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
 kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
 else:
 kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
-elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
+elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
 kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
-elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
+elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
 assert args.valid.min == 1, "store must be valid"
 if isinstance(bufs[args[0]].dtype, ImageDType):
 idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
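In the guarded-load path above, the cast prefix and the zero fallback are now keyed on the token's dtype instead of its ltype. A standalone sketch of that selection, with FLOAT4_CTOR standing in for lang.float4 (a language-config value, plausibly "(float4)" on OpenCL-style targets):

    FLOAT4_CTOR = "(float4)"  # hypothetical stand-in for lang.float4

    # dtype name -> (cast prefix for the loaded value, zero value when invalid)
    casts = {
        "float4": ("", f"{FLOAT4_CTOR}(0.0f, 0.0f, 0.0f, 0.0f)"),
        "half":   ("(half)", "(half)(0.0f)"),
        "float":  ("(float)", "0.0f"),
    }

    def render_guarded_load(dtype_name: str, valid: str, val: str) -> str:
        prefix, zero = casts[dtype_name]
        return f"({valid}) ? {prefix}({val}) : {zero}"

    print(render_guarded_load("float4", "gidx0 < 10", "data0[gidx0]"))
    # (gidx0 < 10) ? (data0[gidx0]) : (float4)(0.0f, 0.0f, 0.0f, 0.0f)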
@@ -172,7 +172,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
 [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
 [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
 
 if lang.half_prekernel: prg =''.join([f"{lang.half_prekernel}", "\n", prg])
-if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
 return prg, global_size, local_size
 
@@ -18,21 +18,18 @@ class LocalBuffer(NamedTuple):
 dtype: DType = dtypes.float32
 realized: None = None
 
-# NOTE: half and half4 are not actually used yet
-class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702
-
 class Token(NamedTuple):
 name: str
-ltype: LocalTypes
+dtype: DType
 offset: Optional[int] = None
 def render(self, with_type=False):
 if with_type:
 assert self.offset is None
-return f"{self.ltype.name} {self.name}"
+return f"{self.dtype.name} {self.name}"
 if self.offset is None: return self.name
-assert self.ltype == LocalTypes.float4
+assert self.dtype == dtypes._float4
 return self.name+"."+"xyzw"[int(self.offset)]
-def __repr__(self): return f"<{self.name}>" if self.offset is None and self.ltype == LocalTypes.float else f"<{self.name}:{self.ltype.name}:{self.offset}>"
+def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
 
 # TODO: the next three functions are poorly written
 def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
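With dtype on Token, offset indexes into the vector and render() picks an .xyzw component. A quick usage sketch (imports assume this commit's module layout):

    from tinygrad.codegen.linearizer import Token
    from tinygrad.helpers import dtypes

    t = Token("acc0", dtypes._float4)
    assert t.render(with_type=True) == "float4 acc0"  # declaration uses dtype.name
    x = Token("acc0", dtypes._float4, 2)
    assert x.render() == "acc0.z"                     # "xyzw"[2]
    assert repr(x) == "<acc0:float4:2>"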
@@ -40,13 +37,13 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
 for i,a in enumerate(acc):
 if idxs is None: break
 if i in idxs: continue
-if a.ltype == LocalTypes.float4 and a.offset == 0:
+if a.dtype.sz > 1 and a.offset == 0:
 idxs.append(i)
 friends: List[int] = []
 for j,b in enumerate(acc):
 if len(friends) == 3: break
 if j in idxs: continue
-if a.name == b.name and b.ltype == LocalTypes.float4 and b.offset == len(friends)+1:
+if a.name == b.name and b.dtype.sz > 1 and b.offset == len(friends)+1:
 friends.append(j)
 if len(friends) == 3: idxs += friends
 else: idxs = None
@@ -56,8 +53,8 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
 
 def to_float4(x:List[Token]) -> Optional[Token]:
 if all_same(x): return x[0]
-if all_same([y.name for y in x]) and all([y.ltype == LocalTypes.float4 and y.offset == i for i,y in enumerate(x)]):
-return Token(x[0].name, LocalTypes.float4)
+if all_same([y.name for y in x]) and all([y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)]):
+return Token(x[0].name, dtypes._float4)
 return None
 
 def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
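to_float4 now recognizes a vector by its dtype: four tokens fold back into one float4 token only when they are the x/y/z/w components of the same name, in order. A small sketch of both outcomes (same import assumptions as above):

    from tinygrad.codegen.linearizer import Token, to_float4
    from tinygrad.helpers import dtypes

    # x/y/z/w of one float4, in order -> collapses to the vector token.
    xs = [Token("val0", dtypes._float4, i) for i in range(4)]
    assert to_float4(xs) == Token("val0", dtypes._float4)

    # Any other arrangement -> None, so the caller falls back to scalar handling.
    assert to_float4(list(reversed(xs))) is None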
@@ -170,10 +167,10 @@ class Linearizer:
 return store_offset_float4
 
 def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
-load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(LocalTypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
+load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
 
 # float4 grouping (optional)
-should_upcast = self.supports_float4 and len(self.float4_axis(i)) == 1
+should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
 if should_upcast:
 load_offset_new = {}
 for k,out_tokens in self._group_float4(i, load_offset).items():
@@ -183,7 +180,7 @@ class Linearizer:
 # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
 for x in out_tokens: load_offset_new[x[1]] = x
 else:
-load_offset_new[k] = (LocalTypes.float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
+load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
 load_offset = load_offset_new
 
 # do loads
@@ -193,9 +190,9 @@ class Linearizer:
 key = f"{localtype}{idx.render()}{valid.render()}"
 if key not in cache:
 cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
-if localtype == LocalTypes.float4:
+if localtype == dtypes._float4:
 for j,uidx in enumerate(uidx_list):
-loaded[uidx] = Token(cache[key].name, LocalTypes.float4, j)
+loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
 else:
 loaded[uidxs] = cache[key]
 return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
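This hunk is the "only upcast float types" part of the commit: loads are grouped into float4 only for float32/float16 buffers and images. A paraphrase of the new predicate as a standalone function (argument names are illustrative, not the Linearizer's API):

    from tinygrad.helpers import dtypes, ImageDType

    def can_upcast_load(supports_float4, buf_dtype, float4_axes) -> bool:
        # int8/uint8 (and every other non-float dtype) no longer qualify.
        return (supports_float4
                and (buf_dtype in [dtypes.float32, dtypes.float16] or isinstance(buf_dtype, ImageDType))
                and len(float4_axes) == 1)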
@@ -204,14 +201,15 @@ class Linearizer:
 store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
 
 # float4 grouping (optional)
-should_upcast = self.supports_float4 and (self.bufs[i].dtype not in (dtypes.float16, dtypes.int8, dtypes.uint8)) and len(self.float4_axis(i)) == 1
+# TODO: why does this not work for float16?
+should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
 if should_upcast:
 store_offset_new = {}
 for k,out_tokens in self._group_float4(i, store_offset).items():
 if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
-store_offset_new[k] = Token(out_tokens[0].name, LocalTypes.float4)
+store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
 else:
-store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", LocalTypes.float4), out_tokens)
+store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
 store_offset = store_offset_new
 
 # do stores
@@ -242,7 +240,7 @@ class Linearizer:
 
 # ssa
 _ssa:DefaultDict[str,int] = defaultdict(int)
-def ssa(name, ltype=LocalTypes.float) -> Token:
+def ssa(name, ltype=dtypes.float) -> Token:
 _ssa[name] += 1
 return Token(f"{name}{_ssa[name]-1}", ltype)
 
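ssa is a per-name counter that mints fresh Tokens, now typed with a real dtype. A standalone copy of it (in the diff it is a helper local to the Linearizer, not importable):

    from collections import defaultdict
    from tinygrad.codegen.linearizer import Token
    from tinygrad.helpers import dtypes

    _ssa = defaultdict(int)
    def ssa(name, ltype=dtypes.float) -> Token:
        _ssa[name] += 1
        return Token(f"{name}{_ssa[name]-1}", ltype)

    print(ssa("alu"), ssa("alu"))      # <alu0> <alu1> (scalar floats)
    print(ssa("alu", dtypes._float4))  # <alu2:float4:None> (vector, no offset yet)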
@@ -345,12 +343,12 @@ class Linearizer:
 if isinstance(x.op, (ReduceOps, FusedOps)):
 ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
 else:
-ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
+ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
 ordered_ret: List[Optional[Token]] = [None]*len(values[0])
 # scatter
 for i,j in ret:
 for o,k in enumerate(i):
-ordered_ret[k] = Token(j.name, j.ltype, o) if j.ltype == LocalTypes.float4 else j
+ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
 assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
 return cast(List[Token], ordered_ret)
 
@@ -537,7 +535,7 @@ class Linearizer:
 
 # if nothing at all is upcasted and it's easy to, do an upcast
-# TODO: this is breaking the tests
-#for splits in [4]:
-# if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
-# self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
-# self.upcast()
+for splits in [4]:
+if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
+self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
+self.upcast()
 
@@ -56,7 +56,8 @@ class DType(NamedTuple):
 priority: int # this determines when things get upcasted
 itemsize: int
 name: str
-np: type # TODO: someday this will be removed with the "remove numpy" project
+np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
+sz: int = 1
 def __repr__(self): return f"dtypes.{self.name}"
 
 # dependent typing?
@@ -80,7 +81,9 @@ class dtypes:
 def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
 bool: Final[DType] = DType(0, 1, "bool", bool)
 float16: Final[DType] = DType(0, 2, "half", np.float16)
+half = float16
 float32: Final[DType] = DType(4, 4, "float", np.float32)
+float = float32
 float64: Final[DType] = DType(5, 8, "double", np.float64)
 int8: Final[DType] = DType(0, 1, "char", np.int8)
 int32: Final[DType] = DType(1, 4, "int", np.int32)
@@ -89,6 +92,9 @@ class dtypes:
 uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
 uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
 
+# NOTE: these are internal dtypes, should probably check for that
+_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
+_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
 
 class GlobalCounters:
 global_ops: ClassVar[int] = 0
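With sz in place, the internal vector dtypes encode their element count directly and itemsize covers the whole vector. What those two definitions imply:

    from tinygrad.helpers import dtypes

    assert dtypes._float4.itemsize == 16  # 4 bytes per element * 4 elements
    assert dtypes._float4.sz == 4         # vs. sz == 1 for every scalar dtype
    assert dtypes._float4.np is None      # internal dtype: no numpy equivalent
    assert dtypes.float32.sz == 1
    print(dtypes._float4)                 # dtypes.float4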
@@ -136,6 +136,7 @@ class LazyBuffer:
 self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
 elif self.op.op == LoadOps.RAND:
 rng = np.random.default_rng(self.op.arg)
+assert self.dtype.np is not None, "internal dtypes don't work with LoadOps.RAND"
 self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
 elif self.op.op == LoadOps.CONST:
 if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
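The new assert matters because internal dtypes carry np=None. A hedged sketch of the failure mode it guards against, assuming numpy's usual coercion of dtype=None to float64 (only the assert itself comes from the diff):

    import numpy as np
    from tinygrad.helpers import dtypes

    rng = np.random.default_rng(0)
    buf = rng.random(size=(4,), dtype=dtypes.float32.np)  # fine: np.float32

    assert dtypes._float4.np is None  # this is the case the new assert catches
    # rng.random(size=..., dtype=None) would fall back to float64 (np.dtype(None)
    # is float64), silently mismatching the buffer's declared dtype.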