Rename .sz to .count on DType (#3413)

* rename .sz for .count on dtype (and ANETensor for completeness)

* revert the changes to extra, as per review

* try to make linter happier

* remove the change to extra
This commit is contained in:
Maciej Fijalkowski
2024-02-15 16:03:49 +02:00
committed by GitHub
parent 7919a1e6ec
commit 736c74b010
3 changed files with 25 additions and 24 deletions

View File

@@ -66,12 +66,12 @@ class PythonProgram:
assert len(inp) <= 3, "gated stores not supported yet"
if isinstance(dtp[0], ImageDType):
# image store
assert dtp[2].sz == 4
assert dtp[2].count == 4
for j,val in enumerate(inp[2]):
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
elif dtp[2].sz > 1:
elif dtp[2].count > 1:
for j,val in enumerate(inp[2]):
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
else:
@@ -102,8 +102,8 @@ class PythonProgram:
ul[i] = [x[2-arg[0]] for x in warp]
elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size
elif uop is UOps.DEFINE_ACC:
if dtype.sz > 1:
ul[i] = [[arg] * warp_size for _ in range(dtype.sz)]
if dtype.count > 1:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
else:
ul[i] = [arg] * warp_size
elif uop is UOps.LOOP:
@@ -116,7 +116,7 @@ class PythonProgram:
i = loop_ends[i] + 1
continue
elif uop is UOps.CAST:
if dtype.sz > 1:
if dtype.count > 1:
ul[i] = inp
else:
# TODO: add real cast
@@ -128,16 +128,16 @@ class PythonProgram:
ul[i] = inp[0]
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):
assert dtype.sz == 4
assert dtype.count == 4
ul[i] = []
for j in range(dtype.sz):
for j in range(dtype.count):
ret = []
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
ul[i].append(ret)
elif dtype.sz > 1:
ul[i] = [load(inp, j) for j in range(dtype.sz)]
elif dtype.count > 1:
ul[i] = [load(inp, j) for j in range(dtype.count)]
else:
ul[i] = load(inp)
elif uop is UOps.PHI: