diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1bc6ed1fd7..e1e4c1b7bd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -105,6 +105,8 @@ jobs: source venv/bin/activate pip install $GITHUB_WORKSPACE python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))" + - name: Test DEBUG + run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())" - name: Repo line count <6000 lines run: MAX_LINE_COUNT=6000 python sz.py diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 57f303fc06..cc9f19c88c 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -1,6 +1,6 @@ from __future__ import annotations import functools, math, operator -from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, cast +from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast from collections import defaultdict from tinygrad.helpers import DEBUG, flatten, all_same from tinygrad.dtype import dtypes, DType @@ -47,18 +47,13 @@ def exec_alu(arg, dtype, p): #return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted def uop_alu_resolve(u:UOp) -> sint: - if u.uop == UOps.CONST: return u.arg - elif u.uop == UOps.DEFINE_VAR: return u.arg - elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL: - return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1]) - elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD: - return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1]) - else: - raise RuntimeError(f"ALU resolve fail @ {u.uop}") + if u.uop is UOps.CONST: return u.arg + elif u.uop is UOps.DEFINE_VAR: return u.arg + elif u.uop is UOps.ALU and u.arg == BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1]) + elif u.uop is UOps.ALU and u.arg == BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1]) + else: raise RuntimeError(f"ALU resolve fail @ {u.uop}") -def phi_resolve_acc(u:UOp) -> UOp: - if u.uop == UOps.DEFINE_ACC: return u - return phi_resolve_acc(u.vin[0]) +def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0]) class UOpGraph: def __init__(self): @@ -78,9 +73,11 @@ class UOpGraph: def print(self): for u in self.uops: - print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501 + print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " + f"{str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") - def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, simplify=True) -> UOp: # noqa: E501 + def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None, cachable=True, insert_before=None, + simplify=True) -> UOp: if simplify: if uop is UOps.PHI and len(vin) == 2: return vin[1] # a phi without loops is a noop if uop is UOps.GEP and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=vin[0].arg, insert_before=insert_before) @@ -130,7 +127,7 @@ class UOpGraph: def type_verify(self): for u in self.uops: uop, arg, vin, dtype = u.uop, u.arg, u.vin, u.dtype - if uop == UOps.ALU: + if uop is UOps.ALU: if arg in UnaryOps: assert dtype == vin[0].dtype, f"{arg} dtype mismatch {dtype=} != {vin[0].dtype=}" elif arg in (BinaryOps.CMPLT, BinaryOps.CMPEQ): @@ -156,22 +153,23 @@ class UOpGraph: for u in self.uops: if u.uop is UOps.LOOP: # add END of loops after the last thing that (recursively) depends on them - self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1) # noqa: E501 + insert_before = self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1 + self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=insert_before) elif u.uop is UOps.IF: # END any if statements at the end of the uops self.add(UOps.ENDIF, None, (u,), cachable=False) - def fix_loop_scope(self, get_recursive_parents): + def fix_loop_scope(self, get_recursive_parents:Callable[..., Set[UOp]]): loop_stack: List[List[UOp]] = [[]] # push uops upward out of loop if it does not depend on the loop for u in self.uops: if not loop_stack[-1]: loop_stack[-1].append(u) - elif u.uop == UOps.LOOP: loop_stack.append([u]) + elif u.uop is UOps.LOOP: loop_stack.append([u]) elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST, UOps.LOAD]: loop_stack[-1].append(u) else: parents = get_recursive_parents(u, with_phi=True) # don't push any local buffer because there might have STORE and BARRIER (not considered as parent) between DEFINE_LOCAL and here - if any(u.uop == UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u) + if any(u.uop is UOps.DEFINE_LOCAL for u in parents): loop_stack[-1].append(u) else: for i in reversed(range(len(loop_stack))): # check backwards and put the uop in the first encounter with some dependency @@ -184,8 +182,7 @@ class UOpGraph: # get PHI node loop scope, link anything using a DEFINE_ACC to the loop as a "parent" acc_scope: DefaultDict[UOp, List[UOp]] = defaultdict(list) for u in self.uops: - if u.uop == UOps.PHI: - acc_scope[u.vin[0]] += u.vin[2:] + if u.uop is UOps.PHI: acc_scope[u.vin[0]] += u.vin[2:] # graph helper functions @functools.lru_cache(None) @@ -245,17 +242,17 @@ class UOpGraph: if u.uop is UOps.LOOP: mult_stack.append(mults) mults *= uop_alu_resolve(u.vin[1]) - if u.uop is UOps.ENDLOOP: + elif u.uop is UOps.ENDLOOP: mults = mult_stack.pop(-1) - if u.uop is UOps.ALU: + elif u.uop is UOps.ALU: flops += mults - if u.uop is UOps.LOAD: + elif u.uop is UOps.LOAD: assert u.dtype is not None mem += u.dtype.itemsize * mults - if u.uop is UOps.STORE: + elif u.uop is UOps.STORE: assert u.vin[2].dtype is not None mem += u.vin[2].dtype.itemsize * mults - if u.uop is UOps.WMMA: + elif u.uop is UOps.WMMA: if u.arg.startswith("__metal_wmma"): flops += 2*(8*8*8)//32 * mults elif u.arg == "__hip_wmma_f16_f16" or u.arg == "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32": flops += 2*(16*16*16)//32 * mults elif u.arg == "__cuda_mma_m16n8k16_f16_f32": flops += 2*(8*16*16)//32 * mults