diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index e857023161..d108e496ca 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -11,12 +11,12 @@ class DType: itemsize: int name: str fmt: Optional[str] - sz: int - def __repr__(self): return f"dtypes.{'_'*(c:=self.sz!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.sz)*c}" + count: int + def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}" def vec(self, sz:int): - assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}" + assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}" return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) - def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self + def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self # TODO: someday this will be removed with the "remove numpy" project @property def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None @@ -32,10 +32,10 @@ class ImageDType(DType): # @dataclass(frozen=True, init=False, repr=False, eq=False) class PtrDType(DType): - def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.sz) + def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) def __repr__(self): return f"ptr.{super().__repr__()}" def __hash__(self): return super().__hash__() - def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.sz==dt.sz + def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __ne__(self, dt): return not (self == dt) def cast_scalar(scalar: Scalar, dtype:DType): diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1b7fda4ad3..7b0c70266f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -37,7 +37,7 @@ class CStyleLanguage(NamedTuple): def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str: if bitcast: return f"(*(({self.buffer_prefix}{self.render_dtype(var_dtype)}*)&{x[0]}))" if len(x) == 1: return f"({self.render_dtype(var_dtype)})({x[0]})" - assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" + assert len(x) == var_dtype.count, f"cast is wrong size {len(x)} != {var_dtype.count}" assert self.float4 is not None, "vectorized cast is not supported on this platform" return f"{self.float4.replace('float4', self.render_dtype(var_dtype))}({','.join(x)})" @@ -46,7 +46,8 @@ class CStyleLanguage(NamedTuple): if math.isnan(x): val = "NAN" elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower() - return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val + return (self.render_cast([val]*var_dtype.count, var_dtype) + if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val) # returns a str expression of the loaded value with the output type def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: @@ -54,9 +55,9 @@ class CStyleLanguage(NamedTuple): assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}" return f"read_imagef({buf_name}, smp, {idx})" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: - return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})" - if output_dtype.sz > 1: - out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))" # noqa: E501 + return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})" + if output_dtype.count > 1: + out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.count}*)({buf_name}+{idx}))" # noqa: E501 else: out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val @@ -77,9 +78,9 @@ class CStyleLanguage(NamedTuple): assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}" return f"write_imagef({buf_name}, {idx}, {var_name});" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: - return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});" - if var_dtype.sz > 1: - return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" # noqa: E501 + return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});" + if var_dtype.count > 1: + return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.count}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.count}){var_name};" # noqa: E501 return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];" @@ -164,7 +165,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st elif uop is UOps.WMMA: kk(f"{dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") elif uop is UOps.DEFINE_ACC: kk(f"{dtype.name} {ssa(u,'acc')} = {lang.render_const(args, dtype)};") elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})" - elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).sz > 4 else f"({r[vin[0]]}).{'xyzw'[args]}" + elif uop is UOps.GEP: r[u] = f"({r[vin[0]]})[{args}]" if cast(DType, vin[0].dtype).count > 4 else f"({r[vin[0]]}).{'xyzw'[args]}" else: raise RuntimeError(f"failed to render {uop}") return lang.render_kernel(function_name, kernel, bufs, local_size, uops) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 0587536e6e..8202e9fb46 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -66,12 +66,12 @@ class PythonProgram: assert len(inp) <= 3, "gated stores not supported yet" if isinstance(dtp[0], ImageDType): # image store - assert dtp[2].sz == 4 + assert dtp[2].count == 4 for j,val in enumerate(inp[2]): for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val): assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0] _store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v) - elif dtp[2].sz > 1: + elif dtp[2].count > 1: for j,val in enumerate(inp[2]): for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v) else: @@ -102,8 +102,8 @@ class PythonProgram: ul[i] = [x[2-arg[0]] for x in warp] elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size elif uop is UOps.DEFINE_ACC: - if dtype.sz > 1: - ul[i] = [[arg] * warp_size for _ in range(dtype.sz)] + if dtype.count > 1: + ul[i] = [[arg] * warp_size for _ in range(dtype.count)] else: ul[i] = [arg] * warp_size elif uop is UOps.LOOP: @@ -116,7 +116,7 @@ class PythonProgram: i = loop_ends[i] + 1 continue elif uop is UOps.CAST: - if dtype.sz > 1: + if dtype.count > 1: ul[i] = inp else: # TODO: add real cast @@ -128,16 +128,16 @@ class PythonProgram: ul[i] = inp[0] elif uop is UOps.LOAD: if isinstance(dtp[0], ImageDType): - assert dtype.sz == 4 + assert dtype.count == 4 ul[i] = [] - for j in range(dtype.sz): + for j in range(dtype.count): ret = [] for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]): if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0) else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j)) ul[i].append(ret) - elif dtype.sz > 1: - ul[i] = [load(inp, j) for j in range(dtype.sz)] + elif dtype.count > 1: + ul[i] = [load(inp, j) for j in range(dtype.count)] else: ul[i] = load(inp) elif uop is UOps.PHI: