mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-15 00:55:11 -05:00
Rename .sz to .count on DType (#3413)
* rename .sz for .count on dtype (and ANETensor for completeness) * revert the changes to extra, as per review * try to make linter happier * remove the change to extra
This commit is contained in:
committed by
GitHub
parent
7919a1e6ec
commit
736c74b010
@@ -66,12 +66,12 @@ class PythonProgram:
|
||||
assert len(inp) <= 3, "gated stores not supported yet"
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
# image store
|
||||
assert dtp[2].sz == 4
|
||||
assert dtp[2].count == 4
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,ox,oy,v in zip(inp[0], inp[1][0], inp[1][1], val):
|
||||
assert ox >= 0 and ox < dtp[0].shape[1] and oy >= 0 and oy < dtp[0].shape[0]
|
||||
_store(m, ox*4 + oy*dtp[0].shape[1]*4 + j, v)
|
||||
elif dtp[2].sz > 1:
|
||||
elif dtp[2].count > 1:
|
||||
for j,val in enumerate(inp[2]):
|
||||
for m,o,v in zip(inp[0], inp[1], val): _store(m, o+j, v)
|
||||
else:
|
||||
@@ -102,8 +102,8 @@ class PythonProgram:
|
||||
ul[i] = [x[2-arg[0]] for x in warp]
|
||||
elif uop is UOps.CONST: ul[i] = [int(arg) if dtypes.is_int(dtype) else float(arg)] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.sz > 1:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.sz)]
|
||||
if dtype.count > 1:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = [arg] * warp_size
|
||||
elif uop is UOps.LOOP:
|
||||
@@ -116,7 +116,7 @@ class PythonProgram:
|
||||
i = loop_ends[i] + 1
|
||||
continue
|
||||
elif uop is UOps.CAST:
|
||||
if dtype.sz > 1:
|
||||
if dtype.count > 1:
|
||||
ul[i] = inp
|
||||
else:
|
||||
# TODO: add real cast
|
||||
@@ -128,16 +128,16 @@ class PythonProgram:
|
||||
ul[i] = inp[0]
|
||||
elif uop is UOps.LOAD:
|
||||
if isinstance(dtp[0], ImageDType):
|
||||
assert dtype.sz == 4
|
||||
assert dtype.count == 4
|
||||
ul[i] = []
|
||||
for j in range(dtype.sz):
|
||||
for j in range(dtype.count):
|
||||
ret = []
|
||||
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
|
||||
if ox < 0 or ox >= dtp[0].shape[1] or oy < 0 or oy >= dtp[0].shape[0]: ret.append(0)
|
||||
else: ret.append(_load(m, ox*4 + oy*dtp[0].shape[1]*4 + j))
|
||||
ul[i].append(ret)
|
||||
elif dtype.sz > 1:
|
||||
ul[i] = [load(inp, j) for j in range(dtype.sz)]
|
||||
elif dtype.count > 1:
|
||||
ul[i] = [load(inp, j) for j in range(dtype.count)]
|
||||
else:
|
||||
ul[i] = load(inp)
|
||||
elif uop is UOps.PHI:
|
||||
|
||||
Reference in New Issue
Block a user