mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-16 01:26:29 -05:00
prepare mypy==1.13.0: legacy cast (#7866)
* use helper to narrow literal type * narrow with asserts instead of cast * remove parantheses * tensor.item() calls tensor.data() * no copy * proper indexing
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# a python uops emulator
|
||||
# works to test the tensor cores, and all the uops in general
|
||||
# this is the (living) definition of uops
|
||||
from typing import Tuple, List, Optional, Any, Dict
|
||||
import sys
|
||||
from typing import Tuple, List, Optional, Any, Dict, TYPE_CHECKING
|
||||
import pickle, base64, itertools, time, struct
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
@@ -66,13 +67,11 @@ class PythonProgram:
|
||||
continue
|
||||
assert dtype is not None, f"{uop} is missing a dtype"
|
||||
dl[i] = dtype
|
||||
if uop is Ops.DEFINE_GLOBAL:
|
||||
if uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL}:
|
||||
assert dtype.fmt is not None
|
||||
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_LOCAL:
|
||||
assert dtype.fmt is not None
|
||||
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
|
||||
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
|
||||
if TYPE_CHECKING or sys.version_info < (3, 12): assert dtype.fmt != "e"
|
||||
buf = memoryview(bytearray(arg[1]*dtype.itemsize)) if uop is Ops.DEFINE_LOCAL else pbufs.pop(0)
|
||||
ul[i] = [buf.cast(dtype.fmt)] * warp_size
|
||||
elif uop is Ops.DEFINE_VAR:
|
||||
ul[i] = [pvals.pop(0)] * warp_size
|
||||
elif uop is Ops.SPECIAL:
|
||||
|
||||
Reference in New Issue
Block a user