mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Rename .sz to .count on DType (#3413)
* rename .sz for .count on dtype (and ANETensor for completeness)
* revert the changes to extra, as per review
* try to make linter happier
* remove the change to extra
This commit is contained in:
committed by
GitHub
parent
7919a1e6ec
commit
736c74b010
@@ -11,12 +11,12 @@ class DType:
|
||||
itemsize: int
|
||||
name: str
|
||||
fmt: Optional[str]
|
||||
sz: int
|
||||
def __repr__(self): return f"dtypes.{'_'*(c:=self.sz!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.sz)*c}"
|
||||
count: int
|
||||
def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
|
||||
def vec(self, sz:int):
|
||||
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
||||
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
|
||||
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
# TODO: someday this will be removed with the "remove numpy" project
|
||||
@property
|
||||
def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None
|
||||
@@ -32,10 +32,10 @@ class ImageDType(DType):
|
||||
|
||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
||||
class PtrDType(DType):
|
||||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.sz)
|
||||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __repr__(self): return f"ptr.{super().__repr__()}"
|
||||
def __hash__(self): return super().__hash__()
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.sz==dt.sz
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
|
||||
def cast_scalar(scalar: Scalar, dtype:DType):
|
||||
|
||||
@@ -37,7 +37,7 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
|
||||
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
|
||||
if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
|
||||
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
|
||||
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
|
||||
assert self.float4 is not None, "vectorized cast is not supported on this platform"
|
||||
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
|
||||
|
||||
@@ -46,7 +46,8 @@ class CStyleLanguage(NamedTuple):
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
|
||||
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val
|
||||
return (self.render_cast([val]*var_dtype.count, var_dtype)
|
||||
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
@@ -54,9 +55,9 @@ class CStyleLanguage(NamedTuple):
|
||||
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.sz > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.count > 1:
|
||||
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
|
||||
else:
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
@@ -77,9 +78,9 @@ class CStyleLanguage(NamedTuple):
|
||||
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.sz > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501
|
||||
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.count > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};" # noqa: E501
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
|
||||
@@ -164,7 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).sz > 4 else f"({r[vin[0]]}).{'xyzw'[args]}"
|
||||
elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).count > 4 else f"({r[vin[0]]}).{'xyzw'[args]}"
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
|
||||
return lang.render_kernel(function_name, kernel, bufs, local_size, uops)
|
||||
|
||||
@@ -66,12 +66,12 @@ class PythonProgram:
|
||||
assert len(inp) <= 3, "gated stores not supported yet"
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
# image store
|
||||
assert dtp[2].sz == 4
|
||||
assert dtp[2].count == 4
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
|
||||
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
|
||||
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
||||
elif dtp[2].sz > 1:
|
||||
elif dtp[2].count > 1:
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
|
||||
else:
|
||||
@@ -102,8 +102,8 @@ class PythonProgram:
|
||||
ul[i] = [x[2-arg[0]] for x in warp]
|
||||
elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.sz > 1:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.sz)]
|
||||
if dtype.count > 1:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = [arg] * warp_size
|
||||
elif uop is UOps.LOOP:
|
||||
@@ -116,7 +116,7 @@ class PythonProgram:
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop is UOps.CAST:
|
||||
if dtype.sz > 1:
|
||||
if dtype.count > 1:
|
||||
ul[i] = inp
|
||||
else:
|
||||
# TODO: add real cast
|
||||
@@ -128,16 +128,16 @@ class PythonProgram:
|
||||
ul[i] = inp[0]
|
||||
elif uop is UOps.LOAD:
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
assert dtype.sz == 4
|
||||
assert dtype.count == 4
|
||||
ul[i] = []
|
||||
for j in range(dtype.sz):
|
||||
for j in range(dtype.count):
|
||||
ret = []
|
||||
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
|
||||
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
|
||||
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
|
||||
ul[i].append(ret)
|
||||
elif dtype.sz > 1:
|
||||
ul[i] = [load(inp, j) for j in range(dtype.sz)]
|
||||
elif dtype.count > 1:
|
||||
ul[i] = [load(inp, j) for j in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = load(inp)
|
||||
elif uop is UOps.PHI:
|
||||
|
||||
Reference in New Issue
Block a user