delete ltypes (#984)

* delete ltypes

* only upcast float types

* test dtype on mac passes

* ugh, these upcasts
George Hotz
2023-06-15 16:24:45 -07:00
committed by GitHub
parent 804c45b5fc
commit 039f0d372f
5 changed files with 49 additions and 44 deletions
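
Overall, this commit swaps the codegen-local LocalTypes enum for ordinary DType values from tinygrad.helpers, so a token's vector width is carried by its dtype (the internal dtypes._float4 / dtypes._half4 added below) rather than by a parallel type system. A minimal sketch of the new comparison style, assuming the dtypes definitions introduced later in this diff:

from tinygrad.helpers import dtypes

# before: tok.ltype == LocalTypes.float4   (separate enum inside the codegen)
# after:  tok.dtype == dtypes._float4      (a regular DType with sz=4, np=None)
def is_float4(tok) -> bool:                # illustrative helper, not in the diff
    return tok.dtype == dtypes._float4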

View File

@@ -1,6 +1,6 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.runtime.lib import RawConst
@@ -105,7 +105,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
assert newvar is not None
if args == -math.inf:
kk(f"{newvar.render(True)} = -INFINITY;")
elif newvar.ltype == LocalTypes.float4:
elif newvar.dtype == dtypes._float4:
kk(f"{newvar.render(True)} = {{ {args}f, {args}f, {args}f, {args}f }};")
else:
kk(f"{newvar.render(True)} = {args}f;")
@@ -118,42 +118,42 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
elif uop == UOps.LOAD and newvar is not None:
# TODO: merge with CONST?
if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst):
assert newvar.ltype == LocalTypes.float, "const can't be float4"
assert newvar.dtype == dtypes.float, "const can't be float4"
x = bufs[args.i].realized._buf
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "")
elif isinstance(bufs[args.i].dtype, ImageDType):
assert newvar.ltype == LocalTypes.float4, "image must be float4"
assert newvar.dtype == dtypes._float4, "image must be float4"
prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args.idx, args.valid)
val = f"read_imagef({bufnames[args.i]}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}))"
else:
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
if newvar.ltype == LocalTypes.float4:
if newvar.dtype == dtypes._float4:
val = f"vload_half4({(args.idx//4).render(render_cl)}, {bufnames[args.i]})"
else:
val = f"vload_half({args.idx.render(render_cl)}, {bufnames[args.i]})"
else:
if newvar.ltype == LocalTypes.float4:
val = f"({newvar.ltype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
if newvar.dtype == dtypes._float4:
val = f"({newvar.dtype.name})((({lang.smem_prefix if isinstance(bufs[args.i], LocalBuffer) else lang.buffer_prefix}{bufs[args.i].dtype.name}4*){bufnames[args.i]})[{(args.idx//4).render(render_cl)}])"
else:
val = f"{bufnames[args.i]}[{args.idx.render(render_cl)}]"
# NOTE: if min and max are both 0, it should be a CONST in the Linearizer
if args.valid.min == 1: kk(f"{newvar.render(True)} = {val};")
else:
casts = {LocalTypes.float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), LocalTypes.half: ("(half)", "(half)(0.0f)"), LocalTypes.float: ("(float)", "0.0f")}[newvar.ltype]
casts = {dtypes._float4: ("", f"{lang.float4}(0.0f, 0.0f, 0.0f, 0.0f)"), dtypes.half: ("(half)", "(half)(0.0f)"), dtypes.float: ("(float)", "0.0f")}[newvar.dtype]
kk(f"{newvar.render(True)} = ({args.valid.render(render_cl)}) ? {casts[0]}({val}) : {casts[1]};")
elif uop == UOps.STORE and (vin[0].ltype == LocalTypes.float or (vin[0].ltype == LocalTypes.float4 and vin[0].offset is not None)):
elif uop == UOps.STORE and (vin[0].dtype == dtypes.float or (vin[0].dtype == dtypes._float4 and vin[0].offset is not None)):
assert not isinstance(bufs[args.i].dtype, ImageDType), "image store must be float4"
assert args.valid.min == 1, "store must be valid"
if lang.uses_vload and bufs[args.i].dtype == dtypes.float16:
kk(f"vstore_half({vin[0].render()}, {args.idx.render(render_cl)}, {bufnames[args.i]});")
else:
kk(f"{bufnames[args.i]}[{args.idx.render(render_cl)}] = {vin[0].render()};")
elif uop == UOps.CAST and newvar is not None and newvar.ltype == LocalTypes.float4:
elif uop == UOps.CAST and newvar is not None and newvar.dtype == dtypes._float4:
kk(f"{newvar.render(True)} = {lang.float4}({','.join([x.render() for x in vin])});")
elif uop == UOps.STORE and len(vin) != 0 and vin[0].ltype == LocalTypes.float4 and vin[0].offset is None:
elif uop == UOps.STORE and len(vin) != 0 and vin[0].dtype == dtypes._float4 and vin[0].offset is None:
assert args.valid.min == 1, "store must be valid"
if isinstance(bufs[args[0]].dtype, ImageDType):
idx, idy = to_image_idx(bufs[args.i].dtype.shape, args[1], args[2])
@@ -172,7 +172,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
if lang.half_prekernel: prg =''.join([f"{lang.half_prekernel}", "\n", prg])
if lang.double_prekernel: prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
return prg, global_size, local_size
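
For reference, a hedged sketch (not actual generated output) of the strings the CONST branch above emits once the vector width comes from the token's dtype; the token name and constant value are illustrative:

from tinygrad.helpers import dtypes
from tinygrad.codegen.linearizer import Token

newvar = Token("val0", dtypes._float4)
newvar.render(True)   # -> "float4 val0", since the DType's name is "float4"
# float4 branch:  "float4 val0 = { 1.0f, 1.0f, 1.0f, 1.0f };"
# scalar branch:  "float val0 = 1.0f;"   with Token("val0", dtypes.float)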

View File

@@ -18,21 +18,18 @@ class LocalBuffer(NamedTuple):
dtype: DType = dtypes.float32
realized: None = None
# NOTE: half and half4 are not actually used yet
class LocalTypes(Enum): float = auto(); float4 = auto(); half = auto(); half4 = auto(); simdgroup_float8x8 = auto() # noqa: E702
class Token(NamedTuple):
name: str
ltype: LocalTypes
dtype: DType
offset: Optional[int] = None
def render(self, with_type=False):
if with_type:
assert self.offset is None
return f"{self.ltype.name} {self.name}"
return f"{self.dtype.name} {self.name}"
if self.offset is None: return self.name
assert self.ltype == LocalTypes.float4
assert self.dtype == dtypes._float4
return self.name+"."+"xyzw"[int(self.offset)]
def __repr__(self): return f"<{self.name}>" if self.offset is None and self.ltype == LocalTypes.float else f"<{self.name}:{self.ltype.name}:{self.offset}>"
def __repr__(self): return f"<{self.name}>" if self.offset is None and self.dtype == dtypes.float32 else f"<{self.name}:{self.dtype.name}:{self.offset}>"
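
A minimal sketch of how a Token renders now that it carries a DType instead of an ltype (token names are illustrative):

from tinygrad.helpers import dtypes
from tinygrad.codegen.linearizer import Token

Token("acc0", dtypes._float4).render(with_type=True)  # -> "float4 acc0"
Token("acc0", dtypes._float4, 2).render()             # -> "acc0.z"  ("xyzw"[2])
Token("val1", dtypes.float32).render()                # -> "val1"
repr(Token("val1", dtypes.float32))                    # -> "<val1>"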
# TODO: the next three functions are poorly written
def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
@@ -40,13 +37,13 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
for i,a in enumerate(acc):
if idxs is None: break
if i in idxs: continue
if a.ltype == LocalTypes.float4 and a.offset == 0:
if a.dtype.sz > 1 and a.offset == 0:
idxs.append(i)
friends: List[int] = []
for j,b in enumerate(acc):
if len(friends) == 3: break
if j in idxs: continue
if a.name == b.name and b.ltype == LocalTypes.float4 and b.offset == len(friends)+1:
if a.name == b.name and b.dtype.sz > 1 and b.offset == len(friends)+1:
friends.append(j)
if len(friends) == 3: idxs += friends
else: idxs = None
@@ -56,8 +53,8 @@ def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
def to_float4(x:List[Token]) -> Optional[Token]:
if all_same(x): return x[0]
if all_same([y.name for y in x]) and all([y.ltype == LocalTypes.float4 and y.offset == i for i,y in enumerate(x)]):
return Token(x[0].name, LocalTypes.float4)
if all_same([y.name for y in x]) and all([y.dtype == dtypes._float4 and y.offset == i for i,y in enumerate(x)]):
return Token(x[0].name, dtypes._float4)
return None
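
A hedged sketch of what to_float4 above does with the new dtype field (token names are illustrative):

from tinygrad.helpers import dtypes
from tinygrad.codegen.linearizer import Token, to_float4

acc = [Token("acc0", dtypes._float4, i) for i in range(4)]
to_float4(acc)                          # -> Token("acc0", dtypes._float4)
to_float4([Token("a", dtypes.float32),
           Token("b", dtypes.float32)]) # -> None: different names, no grouping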
def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
@@ -170,10 +167,10 @@ class Linearizer:
return store_offset_float4
def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(LocalTypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}
# float4 grouping (optional)
should_upcast = self.supports_float4 and len(self.float4_axis(i)) == 1
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
if should_upcast:
load_offset_new = {}
for k,out_tokens in self._group_float4(i, load_offset).items():
@@ -183,7 +180,7 @@ class Linearizer:
# idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
for x in out_tokens: load_offset_new[x[1]] = x
else:
load_offset_new[k] = (LocalTypes.float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
load_offset = load_offset_new
# do loads
@@ -193,9 +190,9 @@ class Linearizer:
key = f"{localtype}{idx.render()}{valid.render()}"
if key not in cache:
cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
if localtype == LocalTypes.float4:
if localtype == dtypes._float4:
for j,uidx in enumerate(uidx_list):
loaded[uidx] = Token(cache[key].name, LocalTypes.float4, j)
loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
else:
loaded[uidxs] = cache[key]
return [loaded[uidxs] for uidxs in self.shape_offsets(i)]
@@ -204,14 +201,15 @@ class Linearizer:
store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))
# float4 grouping (optional)
should_upcast = self.supports_float4 and (self.bufs[i].dtype not in (dtypes.float16, dtypes.int8, dtypes.uint8)) and len(self.float4_axis(i)) == 1
# TODO: why does this not work for float16?
should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
if should_upcast:
store_offset_new = {}
for k,out_tokens in self._group_float4(i, store_offset).items():
if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
store_offset_new[k] = Token(out_tokens[0].name, LocalTypes.float4)
store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
else:
store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", LocalTypes.float4), out_tokens)
store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
store_offset = store_offset_new
# do stores
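
The "only upcast float types" part of the commit is the pair of should_upcast conditions above; restated outside the class as a hedged sketch (the helper names here are illustrative, not from the diff):

from tinygrad.helpers import dtypes, ImageDType

def can_upcast_load(buf_dtype, float4_axes, supports_float4):
    # loads may still group float16 buffers: they go through vload_half4
    return supports_float4 and (buf_dtype in [dtypes.float32, dtypes.float16]
            or isinstance(buf_dtype, ImageDType)) and len(float4_axes) == 1

def can_upcast_store(buf_dtype, float4_axes, supports_float4):
    # stores are stricter: only float32 and image buffers (float16 is a TODO)
    return supports_float4 and (buf_dtype == dtypes.float32
            or isinstance(buf_dtype, ImageDType)) and len(float4_axes) == 1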
@@ -242,7 +240,7 @@ class Linearizer:
# ssa
_ssa:DefaultDict[str,int] = defaultdict(int)
def ssa(name, ltype=LocalTypes.float) -> Token:
def ssa(name, ltype=dtypes.float) -> Token:
_ssa[name] += 1
return Token(f"{name}{_ssa[name]-1}", ltype)
@@ -345,12 +343,12 @@ class Linearizer:
if isinstance(x.op, (ReduceOps, FusedOps)):
ret = [(idx, self.uop(UOps.ALU, val[0], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(acc, *values, grouping_allowed=self.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', LocalTypes.float4) if any(x.ltype == LocalTypes.float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
for o,k in enumerate(i):
ordered_ret[k] = Token(j.name, j.ltype, o) if j.ltype == LocalTypes.float4 else j
ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
return cast(List[Token], ordered_ret)
@@ -537,7 +535,7 @@ class Linearizer:
# if nothing at all is upcasted and it's easy to, do an upcast
# TODO: this is breaking the tests
#for splits in [4]:
# if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
# self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
# self.upcast()
for splits in [4]:
if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
self.upcast()

View File

@@ -56,7 +56,8 @@ class DType(NamedTuple):
priority: int # this determines when things get upcasted
itemsize: int
name: str
np: type # TODO: someday this will be removed with the "remove numpy" project
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
sz: int = 1
def __repr__(self): return f"dtypes.{self.name}"
# dependent typing?
@@ -80,7 +81,9 @@ class dtypes:
def from_np(x) -> DType: return asdict(dtypes())[np.dtype(x).name]
bool: Final[DType] = DType(0, 1, "bool", bool)
float16: Final[DType] = DType(0, 2, "half", np.float16)
half = float16
float32: Final[DType] = DType(4, 4, "float", np.float32)
float = float32
float64: Final[DType] = DType(5, 8, "double", np.float64)
int8: Final[DType] = DType(0, 1, "char", np.int8)
int32: Final[DType] = DType(1, 4, "int", np.int32)
@@ -89,6 +92,9 @@ class dtypes:
uint32: Final[DType] = DType(1, 4, "uint", np.uint32)
uint64: Final[DType] = DType(2, 8, "uint64", np.uint64)
# NOTE: these are internal dtypes, should probably check for that
_half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
_float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
class GlobalCounters:
global_ops: ClassVar[int] = 0
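
A small sketch of the extended DType above: np is now Optional and the new sz field (vector width) defaults to 1, while the internal vector dtypes set sz=4 with np=None:

from tinygrad.helpers import dtypes

dtypes.float32.sz               # -> 1
dtypes._float4.sz               # -> 4
dtypes._float4.itemsize         # -> 16 (4 * 4 bytes)
dtypes._float4.np is None       # -> True: no numpy round-trip for internal dtypes
dtypes.half is dtypes.float16   # -> True, new alias (likewise dtypes.float / float32)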

View File

@@ -136,6 +136,7 @@ class LazyBuffer:
self.realized = Device[self.device].buffer(prod(self.shape), self.dtype, **self._device_extra_args())
elif self.op.op == LoadOps.RAND:
rng = np.random.default_rng(self.op.arg)
assert self.dtype.np is not None, "internal dtypes don't work with LoadOps.RAND"
self.realized = Device[self.device].buffer.fromCPU(rng.random(size=self.shape, dtype=self.dtype.np), **self._device_extra_args())
elif self.op.op == LoadOps.CONST:
if hasattr(Device[self.device].codegen, 'supports_constant_folding'):
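
The new assert above follows from np=None on the internal dtypes: numpy-backed realization needs a concrete numpy dtype. A minimal sketch of what it guards against, assuming this commit's dtypes:

import numpy as np
from tinygrad.helpers import dtypes

rng = np.random.default_rng(0)
rng.random(size=(4,), dtype=dtypes.float32.np)  # fine: .np is np.float32
dtypes._float4.np                               # None, which the assert rejects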