prepare mypy==1.13.0: legacy cast (#7866)

* use helper to narrow literal type

* narrow with asserts instead of cast

* remove parantheses

* tensor.item() calls tensor.data()

* no copy

* proper indexing
This commit is contained in:
JaSpa99
2024-11-27 16:33:35 +01:00
committed by GitHub
parent 753f07e193
commit 38f34ca0cb
3 changed files with 14 additions and 13 deletions

View File

@@ -2,7 +2,8 @@
# a python uops emulator
# works to test the tensor cores, and all the uops in general
# this is the (living) definition of uops
from typing import Tuple, List, Optional, Any, Dict
import sys
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
@@ -66,13 +67,11 @@ class PythonProgram:
continue
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is Ops.DEFINE_GLOBAL:
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
assert dtype.fmt is not None
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
elif uop is Ops.DEFINE_LOCAL:
assert dtype.fmt is not None
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
buf = memoryview(bytearray(arg[1]*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
ul[i] = [buf.cast(dtype.fmt)] * warp_size
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is Ops.SPECIAL: