Rename .sz to .count on DType (#3413)

* rename .sz for .count on dtype (and ANETensor for completeness)

* revert the changes to extra, as per review

* try to make linter happier

* remove the change to extra
This commit is contained in:
Maciej Fijalkowski
2024-02-15 16:03:49 +02:00
committed by GitHub
parent 7919a1e6ec
commit 736c74b010
3 changed files with 25 additions and 24 deletions

View File

@@ -11,12 +11,12 @@ class DType:
itemsize: int
name: str
fmt: Optional[str]
sz: int
def __repr__(self): return f"dtypes.{'_'*(c:=self.sz!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.sz)*c}"
count: int
def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# TODO: someday this will be removed with the "remove numpy" project
@property
def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None
@@ -32,10 +32,10 @@ class ImageDType(DType):
# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType):
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.sz)
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __repr__(self): return f"ptr.{super().__repr__()}"
def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.sz==dt.sz
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def cast_scalar(scalar: Scalar, dtype:DType):

View File

@@ -37,7 +37,7 @@ class CStyleLanguage(NamedTuple):
def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))"
if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})"
@@ -46,7 +46,8 @@ class CStyleLanguage(NamedTuple):
if math.isnan(x): val = "NAN"
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val
return (self.render_cast([val]*var_dtype.count, var_dtype)
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
@@ -54,9 +55,9 @@ class CStyleLanguage(NamedTuple):
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
if output_dtype.sz > 1:
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
if output_dtype.count > 1:
out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501
else:
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
@@ -77,9 +78,9 @@ class CStyleLanguage(NamedTuple):
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.sz > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.count > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};" # noqa: E501
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
@@ -164,7 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).sz > 4 else f"({r[vin[0]]}).{'xyzw'[args]}"
elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).count > 4 else f"({r[vin[0]]}).{'xyzw'[args]}"
else: raise RuntimeError(f"failed to render {uop}")
return lang.render_kernel(function_name, kernel, bufs, local_size, uops)

View File

@@ -66,12 +66,12 @@ class PythonProgram:
assert len(inp) <= 3, "gated stores not supported yet"
if isinstance(dtp[0], ImageDType):
# image store
assert dtp[2].sz == 4
assert dtp[2].count == 4
for j,val in enumerate(inp[2]):
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
elif dtp[2].sz > 1:
elif dtp[2].count > 1:
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
else:
@@ -102,8 +102,8 @@ class PythonProgram:
ul[i] = [x[2-arg[0]] for x in warp]
elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size
elif uop is UOps.DEFINE_ACC:
if dtype.sz > 1:
ul[i] = [[arg] * warp_size for _ in range(dtype.sz)]
if dtype.count > 1:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
else:
ul[i] = [arg] * warp_size
elif uop is UOps.LOOP:
@@ -116,7 +116,7 @@ class PythonProgram:
i = loop_ends[i] + 1
continue
elif uop is UOps.CAST:
if dtype.sz > 1:
if dtype.count > 1:
ul[i] = inp
else:
# TODO: add real cast
@@ -128,16 +128,16 @@ class PythonProgram:
ul[i] = inp[0]
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):
assert dtype.sz == 4
assert dtype.count == 4
ul[i] = []
for j in range(dtype.sz):
for j in range(dtype.count):
ret = []
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
ul[i].append(ret)
elif dtype.sz > 1:
ul[i] = [load(inp, j) for j in range(dtype.sz)]
elif dtype.count > 1:
ul[i] = [load(inp, j) for j in range(dtype.count)]
else:
ul[i] = load(inp)
elif uop is UOps.PHI: