From 82833f1b3cb0880b08836a78e21b93cd5db5a6bb Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 19 Dec 2024 22:09:52 -0800
Subject: [PATCH] a little more typing [pr] (#8346)

* a little more typing [pr]

* few more
---
 tinygrad/codegen/kernel.py    |  2 +-
 tinygrad/codegen/linearize.py |  9 ++++-----
 tinygrad/codegen/lowerer.py   |  6 +++---
 tinygrad/codegen/uopgraph.py  |  8 ++++----
 tinygrad/device.py            |  2 +-
 tinygrad/function.py          | 11 +++++------
 tinygrad/nn/state.py          |  4 ++--
 tinygrad/ops.py               | 18 +++++++++---------
 tinygrad/renderer/ptx.py      | 10 +++++-----
 tinygrad/runtime/graph/hcq.py |  4 ++--
 tinygrad/runtime/ops_nv.py    |  4 ++--
 11 files changed, 38 insertions(+), 40 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 1b0da8a4d7..7b9426d99c 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -139,7 +139,7 @@ class Kernel:
   def first_upcast(self) -> int: return self.shape_len-self.upcasted
 
   @property
-  def reduceop(self) -> Optional[UOp]: return self.reduceops[0] if len(self.reduceops) > 0 else None
+  def reduceop(self) -> UOp|None: return self.reduceops[0] if len(self.reduceops) > 0 else None
 
   @property
   def output_shape(self) -> tuple[sint, ...]: return self.sts[0].shape
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index de5f836c7e..6b6e9fab9d 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -1,5 +1,4 @@
 from __future__ import annotations
-from typing import Tuple, Optional, DefaultDict
 import collections, heapq
 from dataclasses import dataclass
 from tinygrad.ops import type_verify, UOp, Ops, PatternMatcher, UPat, graph_rewrite, GroupOp
@@ -18,7 +17,7 @@ def disp(y:UOp) -> str:
 class BasicBlock:
   ctx: tuple[UOp, ...]
   lst: tuple[UOp, ...]
-  end: Optional[UOp] = None
+  end: UOp|None = None
   def __lt__(self, o:BasicBlock): return tuple(x.tuplize for x in self.ctx+self.lst) < tuple(x.tuplize for x in o.ctx+o.lst)
   def __repr__(self):
     return f"{(str(disp(self.end))+' ') if self.end is not None else ''}"+\
@@ -115,8 +114,8 @@ pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), blo
 # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
 def block_reorder(in_block:UOp):
   in_this_block = set(in_block.arg.lst)
-  local_children: DefaultDict[UOp, list[UOp]] = collections.defaultdict(list)
-  in_degree: DefaultDict[UOp, int] = collections.defaultdict(int)
+  local_children: collections.defaultdict[UOp, list[UOp]] = collections.defaultdict(list)
+  in_degree: collections.defaultdict[UOp, int] = collections.defaultdict(int)
   priorities:dict[UOp, int] = {}
 
   # get local children and assign priorities
@@ -129,7 +128,7 @@ def block_reorder(in_block:UOp):
     priorities[u] = min([-1000 if u.op is Ops.LOAD else 0] + [priorities[x] for x in local_children[u]])
 
   # placement queue
-  queue:list[tuple[int, Tuple, UOp]] = []
+  queue:list[tuple[int, tuple, UOp]] = []
   def push(u:UOp): heapq.heappush(queue, (priorities[u], u.tuplize, u))
 
   # place the first ones that don't have deps
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 9c80fa3739..429769fd1d 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -1,14 +1,14 @@
 # the job of the lowerer is to do indexing
 import functools, itertools, operator
 from dataclasses import dataclass
-from typing import cast, Optional
+from typing import cast
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten
 
 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> Optional[list[list[int]]]:
+def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
   acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
   try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
   except ValueError: return None
@@ -27,7 +27,7 @@ def _limit_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
     else: raise RuntimeError(f"cannot limit dim {dims=}, {max_sizes=}")
   return dims
 
-def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:Optional[tuple[int, ...]], reverse=False) -> list[UOp]:
+def get_grouped_dims(prefix, dims:tuple[sint, ...], max_sizes:tuple[int, ...]|None, reverse=False) -> list[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
   ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 8c0832cc1e..a536db0f04 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -84,7 +84,7 @@ float4_folding = PatternMatcher([
 
 # ***** image load valid simplification *****
 
-def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
+def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
   if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
   if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
 
@@ -398,7 +398,7 @@ def no_vectorized_alu(alu):
   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
   return UOp(Ops.VECTORIZE, alu.dtype, alus)
 
-def create_gate(root:UOp) -> Optional[UOp]:
+def create_gate(root:UOp) -> UOp|None:
   @functools.lru_cache(None)
   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
     if u.op is Ops.BARRIER: return u
@@ -450,7 +450,7 @@ devectorize = PatternMatcher([
   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
 ])
 
-def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
+def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
   if store_gate not in [gate.src[0] for gate in val.toposort if gate.op is Ops.IF]: return None
   # remove the gate from the index
   return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
@@ -472,7 +472,7 @@ migrate_indexing = PatternMatcher([
   (UPat(Ops.STORE, name="root"), create_gate),
 ])
 
-def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
+def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:UOp|None=None) -> UOp:
   # this moves the mask from the indexing to the load/store op for rendering
   nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
   return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
diff --git a/tinygrad/device.py b/tinygrad/device.py
index a92940668e..023095f1d6 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -86,7 +86,7 @@ class BufferSpec:
 
 class Buffer:
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
-               lb_refcount=0, uop_ref:Optional[UOp]=None, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
+               lb_refcount=0, uop_ref:UOp|None=None, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
     if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
     else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
     self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 155c77e017..96b53d10a3 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -1,6 +1,5 @@
 """This is where the forwards and backwards passes live."""
 import math
-from typing import Optional
 from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import Ops, resolve, sint, UOp
@@ -77,11 +76,11 @@ class Sign(Function):
 
 class Less(Function):
   def forward(self, x:UOp, y:UOp) -> UOp: return x<y
-  def backward(self, grad_output:UOp) -> tuple[Optional[UOp], Optional[UOp]]: return None, None
+  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
 
 class Neq(Function):
   def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
-  def backward(self, grad_output:UOp) -> tuple[Optional[UOp], Optional[UOp]]: return None, None
+  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
 
 class Xor(Function):
   def forward(self, x:UOp, y:UOp) -> UOp: return x^y
@@ -98,7 +97,7 @@ class Threefry(Function):
 class Add(Function):
   def forward(self, x:UOp, y:UOp) -> UOp: return x+y
 
-  def backward(self, grad_output:UOp) -> tuple[Optional[UOp], Optional[UOp]]:
+  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
     return grad_output if self.needs_input_grad[0] else None, \
            grad_output if self.needs_input_grad[1] else None
 
@@ -107,7 +106,7 @@ class Mul(Function):
     self.x, self.y = x, y
     return x * y
 
-  def backward(self, grad_output:UOp) -> tuple[Optional[UOp], Optional[UOp]]:
+  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
     return (self.y * grad_output) if self.needs_input_grad[0] else None, \
            (self.x * grad_output) if self.needs_input_grad[1] else None
 
@@ -121,7 +120,7 @@ class Where(Function):
     self.x = x
     return self.x.where(y, z)
 
-  def backward(self, grad_output:UOp) -> tuple[None, Optional[UOp], Optional[UOp]]:
+  def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]:
     return None, \
       self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
       self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 4f9e05febe..e3f2ef237c 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -1,5 +1,5 @@
 import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
-from typing import Dict, Union, Optional, Any, Callable, BinaryIO, Iterable, TypeVar
+from typing import Union, Optional, Any, Callable, BinaryIO, Iterable, TypeVar
 from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up
@@ -291,7 +291,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
   raise ValueError(f"GGML type '{ggml_type}' is not supported!")
 
 @accept_filename
-def gguf_load(tensor: Tensor) -> tuple[Dict, dict[str, Tensor]]:
+def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
   """
   Loads a gguf file from a tensor.
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index b9e4a12da5..d4c1b3a806 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -338,7 +338,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     return ret
   def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
   def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
-  def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+  def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
   def const_like(self, b:ConstLike):
     if self._device is not None: return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
     return UOp.const(self.dtype, b) if self.st is None else UOp.const_with_shape(self.dtype, b, self.shape)
@@ -575,7 +575,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
     if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
     return 1
-  def divides(self, v) -> Optional[UOp]:
+  def divides(self, v) -> UOp|None:
    if v==1: return self
    if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
    if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
@@ -800,7 +800,7 @@ class PatternMatcher:
   @functools.lru_cache(None)  # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
-  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
     ler = {u.op for u in uop.src}
     for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
       if not early_reject.issubset(ler): continue
@@ -832,7 +832,7 @@ def track_rewrites(named=False):
   return _decorator
 
 class TrackedPatternMatcher(PatternMatcher):
-  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
+  def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
     ret = None
     ler = {u.op for u in uop.src}
     for p,fxn,early_reject,has_ctx in self.pdict.get(uop.op, []):
@@ -1000,7 +1000,7 @@ def split_uop(x:UOp, sep:Ops):
     for s in x.src: yield from split_uop(s, sep)
   else: yield x
 
-def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> Optional[UOp]:
+def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split_rem: bool=False) -> UOp|None:
   # simplify x // c or x % c, None means no change, c must be > 0
   assert c > 0
   if x.dtype.count > 1: return None
@@ -1059,7 +1059,7 @@ def div_and_mod_folding(x: UOp, c: int, which: Literal[Ops.MOD, Ops.IDIV], split
     if which is Ops.MOD: return gcd*(rem % (c//gcd)) + const%gcd
     return rem//(c//gcd)+quo
 
-def lt_folding(x:UOp, c:int) -> Optional[UOp]:
+def lt_folding(x:UOp, c:int) -> UOp|None:
   p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
   if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
     return cast(UOp, functools.reduce(operator.add, np).divides(d))<(c//d)
@@ -1086,7 +1086,7 @@ def fold_unrolled_divs(divs:UOp):
     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
   return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
 
-def canonicalize_simplex(X:UOp) -> Optional[UOp]:
+def canonicalize_simplex(X:UOp) -> UOp|None:
   # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
   # returns x0 + x1 + ... in such case, or None if not
   changed, ret = False, []
@@ -1117,7 +1117,7 @@ def parse_valid(valid:UOp) -> tuple[UOp, bool, int]:
   if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
   raise ValueError(f"not able to parse {valid=}")
 
-def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
+def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None:
   # return None if valid is always False, otherwise the simplified uop (might be the same as input)
 
   # first, parse valid into {expr: (lower_bound, upper_bound)}
@@ -1156,7 +1156,7 @@ def _valid_priority(v: UOp, valids:list[UOp]):
   try: return sum(-1 if parse_valid(v)[0] in other.toposort else 0 for other in valids)
   except ValueError: return 0
 
-def simplify_valid(valid:UOp) -> Optional[UOp]:
+def simplify_valid(valid:UOp) -> UOp|None:
   ret:list[UOp] = []
   something_changed = False
   valids = list(split_uop(valid, Ops.AND))
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index d620f6a6e3..0503c0b386 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -1,4 +1,4 @@
-from typing import DefaultDict, Union, Optional, cast, Callable
+from typing import cast, Callable
 import struct
 from collections import defaultdict
 from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
@@ -156,12 +156,12 @@ class PTXRenderer(Renderer):
     kernel:list[str] = []
     bufs = []
 
-    c: DefaultDict[str, int] = defaultdict(int)
-    r: dict[UOp, Union[list[str], str]] = {}
+    c: defaultdict[str, int] = defaultdict(int)
+    r: dict[UOp, list[str]|str] = {}
     self.r = r
     self.uops = uops
 
-    def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
+    def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
       nonlocal c, r
       prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_"
       c[prefix] += 1
@@ -192,7 +192,7 @@ class PTXRenderer(Renderer):
         Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
       if prefix: r[u] = ssa(prefix, u, dtype)
 
-      if (l:=cast(Union[str, list[str]], string_rewrite.rewrite(u, ctx=self))) is None:
+      if (l:=cast(str|list[str], string_rewrite.rewrite(u, ctx=self))) is None:
         raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
       kernel.extend([l] if isinstance(l, str) else l)
 
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 507d86d410..352c5476da 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -1,5 +1,5 @@
 import collections, time
-from typing import List, Any, cast, Optional
+from typing import Any, cast, Optional
 from tinygrad.helpers import round_up, PROFILE
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
 from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
@@ -42,7 +42,7 @@ class HCQGraph(MultiGraphRunner):
     # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
     # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
     # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
-    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, List, List, HCQSignal, Optional[int]]] = {}
+    self.ji_schedule: dict[int, tuple[HCQCompiled, HWQueue, list, list, HCQSignal, Optional[int]]] = {}
 
     self.comp_queues: dict[HCQCompiled, HWQueue] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
     self.copy_queues: dict[HCQCompiled, HWQueue] = {} # lazy allocation
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index b4912e68c8..c503b56258 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys
 assert sys.platform != 'win32'
-from typing import List, Any, cast, Union, Type, Optional
+from typing import Any, cast, Union, Type, Optional
 from dataclasses import dataclass
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
 from tinygrad.ops import sint
@@ -285,7 +285,7 @@ class NVDevice(HCQCompiled[NVSignal]):
   root = None
   fd_ctl: int = -1
   fd_uvm: int = -1
-  gpus_info: Union[List, ctypes.Array] = []
+  gpus_info: Union[list, ctypes.Array] = []
   signals_page: Any = None
   signals_pool: list[int] = []
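
The patch above standardizes on PEP 604 unions (UOp|None instead of typing.Optional[UOp]) and subscripted builtins/collections (dict, list, tuple, collections.defaultdict instead of Dict, List, Tuple, DefaultDict). A minimal sketch of that style follows; the names in it are hypothetical illustrations, not tinygrad code:

    from __future__ import annotations  # defers annotation evaluation, so union syntax in annotations also parses pre-3.10
    import collections

    def first_or_none(xs: list[int]) -> int | None:
      # same pattern as Kernel.reduceop: first element if present, else None
      return xs[0] if len(xs) > 0 else None

    counts: collections.defaultdict[str, int] = collections.defaultdict(int)
    counts["load"] += 1

Unions evaluated at runtime rather than in annotations (for example the cast(str|list[str], ...) in ptx.py) still require Python 3.10+; the future import only covers annotations.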