diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e5c0fc9632..ee1b94524d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -49,6 +49,8 @@ jobs: run: DEBUG=2 PYTHON=1 python3 test/test_dtype.py - name: Test ops with Python emulator run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal)" + - name: Test symbolic with Python emulator + run: PYTHONPATH=. PYTHON=1 python3 test/test_symbolic_ops.py linter: name: Linters diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 78edf9982d..d53b137ec1 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -190,20 +190,15 @@ class Linearizer(Kernel): self.loop_uops: Dict[str, UOp] = {} # add global buffers - buf_count = 0 - buf_index = {} for i,buf in enumerate(self.bufs): if isinstance(buf, MemBuffer): - if buf.idx not in buf_index: - buf_index[buf.idx] = buf_count - buf_count += 1 self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), - (buf_index[buf.idx], f"data{buf.idx}")) + (buf.idx, f"data{buf.idx}")) # add var vals for i,var in enumerate(self.ast.vars()): assert var.expr is not None - self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (len(buf_index)+i, var.expr)) + self.loop_uops[var.expr] = self.uop(UOps.DEFINE_VAR, dtypes.int32, (), var) # define local buffers for lb in self.local_alias.values(): self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size)) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 5513bd774f..fff74403d9 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -1,16 +1,16 @@ from __future__ import annotations -from typing import List, Set, Optional, Tuple, Any, Dict +from typing import List, Set, Optional, Tuple, Any from tinygrad.helpers import DEBUG, flatten from tinygrad.dtype import dtypes, DType from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps -from tinygrad.shape.symbolic import Variable, sint +from tinygrad.shape.symbolic import sint from enum import Enum, auto from dataclasses import dataclass # bottom ones are asm only class UOps(Enum): LOOP = auto(); IF = auto(); ENDLOOP = auto(); ENDIF = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702 - DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702 + DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702 LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto(); PHI = auto() # noqa: E702 ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702 @@ -33,7 +33,7 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]: deps.add(u) return deps -UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL} +UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR} def remove_childless_uops(uops:List[UOp]) -> List[UOp]: # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that while 1: @@ -83,17 +83,17 @@ def uops_type_verify(uops:List[UOp]): assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}" assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}" -def uops_alu_resolve(u:UOp, vars:Dict[str, Variable]) -> sint: +def uops_alu_resolve(u:UOp) -> sint: if u.uop == UOps.CONST: return u.arg - elif u.uop == UOps.DEFINE_GLOBAL: return vars[u.arg[1]] + elif u.uop == UOps.DEFINE_VAR: return u.arg elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL: - return uops_alu_resolve(u.vin[0], vars) * uops_alu_resolve(u.vin[1], vars) + return uops_alu_resolve(u.vin[0]) * uops_alu_resolve(u.vin[1]) elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD: - return uops_alu_resolve(u.vin[0], vars) + uops_alu_resolve(u.vin[1], vars) + return uops_alu_resolve(u.vin[0]) + uops_alu_resolve(u.vin[1]) else: raise RuntimeError(f"ALU resolve fail @ {u.uop}") -def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint]: +def uops_flops_mem(uops:List[UOp]) -> Tuple[sint, sint]: flops: sint = 0 mem: sint = 0 mults: sint = 1 @@ -101,7 +101,7 @@ def uops_flops_mem(uops:List[UOp], vars:Dict[str, Variable]) -> Tuple[sint, sint for u in uops: if u.uop is UOps.LOOP: mult_stack.append(mults) - mults *= uops_alu_resolve(u.vin[1], vars) + mults *= uops_alu_resolve(u.vin[1]) if u.uop is UOps.ENDLOOP: mults = mult_stack.pop(-1) if u.uop is UOps.ALU: diff --git a/tinygrad/device.py b/tinygrad/device.py index 9ab3885986..a8cdd3b7fa 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -238,7 +238,7 @@ class Compiled: ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size) from tinygrad.codegen.uops import uops_flops_mem run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else [])) - ops, mem = uops_flops_mem(k.uops, {x.expr:x for x in ret.vars}) + ops, mem = uops_flops_mem(k.uops) # NOTE: we use min here to ignore the indexing FLOPS ret.op_estimate = min(ret.op_estimate, ops * run_count) ret.mem_estimate = min(ret.mem_estimate, mem * run_count) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index e3c32e7838..0fdd948bc4 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -141,7 +141,8 @@ class LazyBuffer: # *** movement ops *** def _view(self, new_st:ShapeTracker) -> LazyBuffer: - if self.st.size == 0: return self.const(0, new_st.shape) + if self.st.size == 0 or (new_st.views[-1].mask is not None and all((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): + return self.const(0, new_st.shape) if new_st.contiguous and self.base.shape == new_st.shape: return self.base return create_lazybuffer(self.device, new_st, self.dtype, base=self.base) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index cc4bd3331f..c4eb74b053 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -161,6 +161,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st elif uop is UOps.DEFINE_LOCAL: kk(lang.render_local(args[0], dtype, args[1])) r[u] = args[0] + elif uop is UOps.DEFINE_VAR: + bufs.append((args.expr, dtype)) + r[u] = args.expr elif uop is UOps.DEFINE_GLOBAL: assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}" bufs.append((args[1], dtype)) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index db5f09f424..c4eaaf0793 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -70,8 +70,8 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str: # all llvm stuff goes into a module module = ir.Module(name=__file__) - # extract global buffers - buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop == UOps.DEFINE_GLOBAL} + # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order) + buf_to_dtype = {u.arg:u.dtype for u in uops if u.uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}} buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} # create llvm function @@ -144,7 +144,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str: elif uop is UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype) elif uop is UOps.CAST: lvars[u] = cast(bb, lvars[vin[0]], vin[0].dtype, dtype, bitcast=isinstance(args, tuple) and args[1]) - elif uop is UOps.DEFINE_GLOBAL: lvars[u] = func.args[buf_index[args]] + elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]] elif uop is UOps.SPECIAL: lvars[u] = lvars[args.expr] elif uop is UOps.CONST: lvars[u] = const(args, dtype) else: raise RuntimeError(f"failed to render {uop}") diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index ef8f3d03de..320f17e319 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -56,6 +56,7 @@ class PythonProgram: ul: Dict[int, Any] = {} dl: Dict[int, DType] = {} pbufs: List[memoryview] = list(bufs) + pvals: List[int] = list(vals) i = 0 loop_ends: Dict[int, int] = {} while i < len(self.uops): @@ -97,6 +98,8 @@ class PythonProgram: assert dtype.fmt is not None lbuf = memoryview(bytearray(arg[1]*dtype.itemsize)) ul[i] = [lbuf.cast(dtype.fmt)] * warp_size + elif uop is UOps.DEFINE_VAR: + ul[i] = [pvals.pop(0)] * warp_size elif uop is UOps.SPECIAL: if arg[1][0] == 'g': ul[i] = [idxs[2-arg[0]]] * warp_size