From 99b0287e4ef3f6155cdc2e6eedb97aa97636a9b3 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 13 Mar 2025 11:28:38 -0400 Subject: [PATCH 01/26] add GROUP and GROUPTOP to test_arange (#9432) it does not grow quadratically, but it's not 0 ops now --- test/test_arange.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/test/test_arange.py b/test/test_arange.py index ae5b6208a1..867145af43 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -25,23 +25,29 @@ class TestArange(unittest.TestCase): return p.estimates.ops def test_complexity(self, opts=None, limit=None): - # add 1 to avoid divide by 0. arange is 0 flops now! - f1 = self._get_flops(256, opts) + 1 - f2 = self._get_flops(2560, opts) + 1 + f1 = self._get_flops(256, opts) + f2 = self._get_flops(2560, opts) print(f"{f1=}, {f2=}") - assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X" + # add 1 to avoid divide by 0. arange is 0 flops now! + assert (f1 < 6000 and f2 < 6000) or ((f2+1) / (f1+1) < 16), f"bad complexity, flops {(f2+1) / (f1+1):.1f}X while inputs 10X" if limit is not None and not getenv("PTX"): # PTX counts index ALU in flops assert f1 <= limit, f"{f1=}, {limit=}" - def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1) - def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1) - def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1) - def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1) - def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1) + def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0) + def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0) + def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0) + def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0) + def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0) - @unittest.skip("doesn't work yet") - def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)]) + if Device.default.renderer.has_local: + # TODO: fix limit + def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920) + def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496) + + def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0) + @unittest.skip("doesn't work yet") + def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)]) def test_all_opts(self, opts=None, exclude=None): k = Kernel(Tensor.arange(256).schedule()[-1].ast) From 357e364ab898b159518bdd6dc9aec0a06ceee94b Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 13 Mar 2025 23:59:28 +0800 Subject: [PATCH 02/26] am: turn off unord dispatch (#9433) --- tinygrad/runtime/support/am/ip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 9af5920bde..6450175aa6 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -206,7 +206,7 @@ class AM_GFX(AM_IP): cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr), cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr), cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1), - cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=1, queue_size=(ring_size//4).bit_length()-2), + cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2), cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000, cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0, cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8), From 459d0cd14fa05ce4e7b811b29365caba9973e575 Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Thu, 13 Mar 2025 13:06:27 -0300 Subject: [PATCH 03/26] add arch to AMDRenderer and HIPRenderer (#9431) --- tinygrad/renderer/cstyle.py | 4 ++++ tinygrad/runtime/ops_amd.py | 2 +- tinygrad/runtime/ops_hip.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 1a805fbbca..a8c7a95205 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -402,6 +402,10 @@ class AMDRenderer(CStyleLanguage): opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8)))) for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]] + def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900 + self.tensor_cores, self.arch = AMDRenderer.tensor_cores, arch + def __reduce__(self): return self.__class__, (self.arch,) + # language options ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]] ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index c01b5147d4..cb1856554a 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -704,7 +704,7 @@ class AMDDevice(HCQCompiled): self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000) - super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self), + super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), AMDCompiler(self.arch), functools.partial(AMDProgram, self), AMDSignal, AMDComputeQueue, AMDCopyQueue) # Scratch setup diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index da957c3d95..412ddc0925 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -14,7 +14,7 @@ class HIPDevice(Compiled): self.device_id = int(device.split(":")[1]) if ":" in device else 0 self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode() self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)] - super().__init__(device, HIPAllocator(self), HIPRenderer(), AMDCompiler(self.arch), functools.partial(HIPProgram, self)) + super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), AMDCompiler(self.arch), functools.partial(HIPProgram, self)) def synchronize(self): check(hip.hipSetDevice(self.device_id)) check(hip.hipDeviceSynchronize()) From 5ff90cb2617dac0144e149cd0f7ec97daa85aca5 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:10:35 +0500 Subject: [PATCH 04/26] am: less magic values (#9440) --- tinygrad/runtime/support/am/ip.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 6450175aa6..065faee9bd 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -61,7 +61,14 @@ class AM_GMC(AM_IP): self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12) self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12) self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1) - self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv)) + self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1, + dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1, + range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1, + valid_protection_fault_enable_interrupt=1, valid_protection_fault_enable_default=1, + read_protection_fault_enable_interrupt=1, read_protection_fault_enable_default=1, + write_protection_fault_enable_interrupt=1, write_protection_fault_enable_default=1, + execute_protection_fault_enable_interrupt=1, execute_protection_fault_enable_default=1, + enable_context=1, page_table_depth=(3 - page_table.lv)) def init_hub(self, ip:Literal["MM", "GC"]): # Init system apertures @@ -290,7 +297,7 @@ class AM_IH(AM_IP): self.adev.reg(f"regIH_RB_WPTR{suf}").write(0) self.adev.reg(f"regIH_RB_RPTR{suf}").write(0) - self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(((am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2), enable=1) + self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(offset=(am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2, enable=1) self.adev.regIH_STORM_CLIENT_LIST_CNTL.update(client18_is_storm_client=1) self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1) From 77a8430616836e87b99b68da1f12f353c79b80c7 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 15 Mar 2025 02:10:45 +0800 Subject: [PATCH 05/26] am: use smu based on discovery (#9441) --- tinygrad/runtime/support/am/ip.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 065faee9bd..19fc33f1bc 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -1,6 +1,6 @@ import ctypes, time, contextlib from typing import Literal -from tinygrad.runtime.autogen.am import am, smu_v13_0_0 +from tinygrad.runtime.autogen.am import am from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG class AM_IP: @@ -113,37 +113,38 @@ class AM_GMC(AM_IP): class AM_SMU(AM_IP): def __init__(self, adev): super().__init__(adev) + self.smu_mod = self.adev._ip_module("smu", am.MP1_HWIP, prever_prefix='v') self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True) def init(self): - self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) - self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) - self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) def is_smu_alive(self): - with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100) + with contextlib.suppress(RuntimeError): self._send_msg(self.smu_mod.PPSMC_MSG_GetSmuVersion, 0, timeout=100) return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0 def mode1_reset(self): if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset") - self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms def read_table(self, table_t, cmd): - self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t))) - def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS) + def read_metrics(self): return self.read_table(self.smu_mod.SmuMetricsExternal_t, self.smu_mod.TABLE_SMU_METRICS) def set_clocks(self, level): if not hasattr(self, 'clcks'): self.clcks = {} - for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]: - cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff - self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)] + for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]: + cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff + self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)] for clck, vals in self.clcks.items(): - self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True) - self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True) + self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True) def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): From bd4ae5ac539e4b0f23a8d30f0e452df918d4405c Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 15 Mar 2025 03:10:18 +0800 Subject: [PATCH 06/26] am: hotfix: import modules (#9443) * am: hotfix: import modules * hmm --- tinygrad/runtime/support/am/amdev.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index d89bac3750..d79e0248cf 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -1,5 +1,5 @@ from __future__ import annotations -import ctypes, collections, time, dataclasses, pathlib, fcntl, os +import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp from tinygrad.runtime.autogen.am import am, mp_11_0 from tinygrad.runtime.support.allocator import TLSFAllocator @@ -391,10 +391,10 @@ class AMDev: gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset) self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr) - def _ip_module(self, prefix:str, hwip): + def _ip_module(self, prefix:str, hwip, prever_prefix:str=""): version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100] for ver in [version, version[:2]+[0], version[:1]+[0, 0]]: - try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"]) + try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{prefix}_{prever_prefix}{ver[0]}_{ver[1]}_{ver[2]}") except ImportError: pass raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}") From 3af7a08a06afe76e5fb727ec71981c94da660c39 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 14 Mar 2025 20:14:31 +0100 Subject: [PATCH 07/26] ast_fixup in one graph_rewrite pass [pr] (#9444) --- tinygrad/engine/schedule.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f80ebf7358..05844b2705 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -274,7 +274,7 @@ DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} add_buffer_ops = PatternMatcher([ # LOAD - (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))), + (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)), # STORE (except for COPY/BUFFER_VIEW) (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), # partial assign can store to a non-contiguous ShapeTracker @@ -342,12 +342,12 @@ view_right = merge_views+PatternMatcher([ # ** unbind variables -def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None: +def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp) -> UOp|None: st = unwrap(x.st).simplify() if any(x.op is Ops.BIND for x in st.vars()): st, var_vals = st.unbind() - ctx.update(var_vals) - return st.to_uop() if st != x.st else None + ctx[0].update(var_vals) + return x.replace(arg=st) if st != x.st else None def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): ctx[var.replace(src=())] = val.arg @@ -387,13 +387,11 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp: ast = k.arg.ast.substitute(parents_rep) # unbind_vars + push views to edges ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right) - # add buffer ops - ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True) + # add buffer ops + fix_kernel_ops + ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True) if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}") # create subbuffer (TODO: this does not belong here) if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) - # fix_kernel_ops - ast = graph_rewrite(ast, fix_kernel_ops, var_vals) return k.replace(arg=Kernel(ast, k.arg.metadata)) PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} From 2a50e6440d8ba3a6638b54ff9c3931023bda3ae0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 14 Mar 2025 21:27:46 +0100 Subject: [PATCH 08/26] filter sink by DONT_PUSH_VIEWS + remove extra base [pr] (#9446) --- tinygrad/engine/schedule.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 05844b2705..f29bdea5bb 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -113,6 +113,8 @@ sym = symbolic_simple+PatternMatcher([ # **** UOp realization +DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} + @dataclass(frozen=True) class GrouperContext: assigns: dict[UOp, UOp] # maps realized buffers to assigns @@ -133,7 +135,7 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None: do_realize = PatternMatcher([ # always realize SINK parents - (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})), + (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)), # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), # realize before expand or unsafe pad ops @@ -266,8 +268,6 @@ create_kernels = merge_views+PatternMatcher([ (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None), ]) -DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} - # **** fix kernel AST # ** create buffer ops + enumerate buffers From 14018050c1f8ef0e6713b2d2c42c965ea2ecc9ec Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Fri, 14 Mar 2025 17:36:50 -0300 Subject: [PATCH 09/26] `simple_matmul.py` uses np to generate random (#9438) * np generates randoms * hotfix: use generator for int dtype * float32 as default dtype for float generator * use np.float32 instead of stirng * add dtype= to integers generator * change import _to_np_dtype source --- extra/gemm/simple_matmul.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index fc06dad1f2..73e736ef9d 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -1,5 +1,6 @@ import numpy as np from tinygrad.helpers import getenv +from tinygrad.dtype import _to_np_dtype from tinygrad import dtypes, Tensor dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float @@ -13,12 +14,15 @@ K = getenv("K", N) CNT = getenv("CNT", 10) ATOL = getenv("ATOL", 1e-4) RTOL = getenv("RTOL", 3e-2) +INT_LOW = getenv("INT_LOW", 0) +INT_HIGH = getenv("INT_HIGH", 10) if __name__ == "__main__": def init_matrix(rows, cols): + rng = np.random.default_rng() if dtype_in in dtypes.ints: - return Tensor.randint((rows, cols), dtype=dtype_in).realize() - return Tensor.rand(rows, cols, dtype=dtype_in).realize() + return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=_to_np_dtype(dtype_in))).realize() + return Tensor(rng.random((rows, cols), dtype=np.float32).astype(_to_np_dtype(dtype_in))).realize() a, b = init_matrix(M, K), init_matrix(K, N) for i in range(CNT): From b0f63d3c040f4c94ef8898627d66e8139a2e17dd Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 14 Mar 2025 17:14:22 -0400 Subject: [PATCH 10/26] Revert "`simple_matmul.py` uses np to generate random (#9438)" (#9449) This reverts commit 14018050c1f8ef0e6713b2d2c42c965ea2ecc9ec. --- extra/gemm/simple_matmul.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 73e736ef9d..fc06dad1f2 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -1,6 +1,5 @@ import numpy as np from tinygrad.helpers import getenv -from tinygrad.dtype import _to_np_dtype from tinygrad import dtypes, Tensor dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float @@ -14,15 +13,12 @@ K = getenv("K", N) CNT = getenv("CNT", 10) ATOL = getenv("ATOL", 1e-4) RTOL = getenv("RTOL", 3e-2) -INT_LOW = getenv("INT_LOW", 0) -INT_HIGH = getenv("INT_HIGH", 10) if __name__ == "__main__": def init_matrix(rows, cols): - rng = np.random.default_rng() if dtype_in in dtypes.ints: - return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=_to_np_dtype(dtype_in))).realize() - return Tensor(rng.random((rows, cols), dtype=np.float32).astype(_to_np_dtype(dtype_in))).realize() + return Tensor.randint((rows, cols), dtype=dtype_in).realize() + return Tensor.rand(rows, cols, dtype=dtype_in).realize() a, b = init_matrix(M, K), init_matrix(K, N) for i in range(CNT): From 0e591baf434d2611dd7ee517cf2d466ba990a1a0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 14 Mar 2025 17:53:52 -0400 Subject: [PATCH 11/26] redo simple_matmul change (#9450) numpy does not support bfloat16 --- extra/gemm/simple_matmul.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index fc06dad1f2..1edad82c23 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -1,5 +1,6 @@ import numpy as np from tinygrad.helpers import getenv +from tinygrad.dtype import _to_np_dtype from tinygrad import dtypes, Tensor dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float @@ -13,12 +14,17 @@ K = getenv("K", N) CNT = getenv("CNT", 10) ATOL = getenv("ATOL", 1e-4) RTOL = getenv("RTOL", 3e-2) +INT_LOW = getenv("INT_LOW", 0) +INT_HIGH = getenv("INT_HIGH", 10) if __name__ == "__main__": def init_matrix(rows, cols): + rng = np.random.default_rng() + # NOTE: numpy does not support bfloat16 + if (np_dtype := _to_np_dtype(dtype_in)) is None: np_dtype = np.float32 if dtype_in in dtypes.ints: - return Tensor.randint((rows, cols), dtype=dtype_in).realize() - return Tensor.rand(rows, cols, dtype=dtype_in).realize() + return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=np_dtype)).realize() + return Tensor(rng.random((rows, cols), dtype=np.float32).astype(np_dtype)).cast(dtype_in).realize() a, b = init_matrix(M, K), init_matrix(K, N) for i in range(CNT): From ca5064a5b65fbbc4654d76c916bebef21e0e586c Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 14 Mar 2025 17:54:32 -0400 Subject: [PATCH 12/26] remove Kernel.float4_axis [pr] (#9448) --- tinygrad/codegen/kernel.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 20cf001dbb..f23abeb381 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -102,9 +102,6 @@ class Kernel: @property def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}]) - # TODO: these need more tests or it might silently be no-op - def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501 - def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]: upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:] assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}" @@ -461,7 +458,8 @@ class Kernel: if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]): # are we grouping? (requires local shape support) - if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501 + if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \ + self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # TODO: use 1024 if it's allowed in a smarter way for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]): if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts): From cb7a7f69c759206d8e4bf682f9619558d4e09306 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 15 Mar 2025 07:49:37 +0800 Subject: [PATCH 13/26] quantization preprocessor from DSP, should be universal (#9437) * quantization preprocessor from DSP, should be universal * touchups * fix tests --- .github/workflows/test.yml | 4 ++- extra/replay_pkl.py | 2 ++ test/test_quantize_onnx.py | 49 +++++++++++++++++++++-------- tinygrad/codegen/lowerer.py | 63 +++++++++++++++++++++++++++++++++++-- tinygrad/helpers.py | 1 + tinygrad/ops.py | 1 + tinygrad/runtime/ops_dsp.py | 6 ++-- 7 files changed, 106 insertions(+), 20 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a0b4038507..a6cb1ebb34 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -423,6 +423,8 @@ jobs: run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20 - name: Test Additional ONNX Ops (CPU) run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py + - name: Test Quantize ONNX + run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py - name: Run CLOUD=1 Test run: | CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py @@ -467,7 +469,7 @@ jobs: testdsp: name: Linux (DSP) runs-on: ubuntu-24.04 - timeout-minutes: 10 + timeout-minutes: 15 steps: - name: Checkout Code uses: actions/checkout@v4 diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py index 43272adc8c..fc40280d49 100644 --- a/extra/replay_pkl.py +++ b/extra/replay_pkl.py @@ -26,6 +26,8 @@ if __name__ == "__main__": k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0)) k.apply_opt(Opt(OptOps.PADTO, 2, 128)) k.apply_opt(Opt(OptOps.UPCAST, 2, 128)) + elif knum == 3: + k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128)) else: k.hand_coded_optimizations() p2 = k.to_program() diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index 1f6681b0cc..e1bce65950 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -2,8 +2,9 @@ import numpy as np import unittest from dataclasses import replace from tinygrad import Tensor, Context, Device, dtypes +from tinygrad.ops import Ops from tinygrad.codegen.kernel import Kernel, Opt, OptOps -from tinygrad.engine.realize import CompiledRunner, ExecItem +from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item N = 512 @@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3): ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) for _ in range(run_count): ei.run(wait=True) +def get_quantized_model(sz): + from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader + class FakeDataReader(CalibrationDataReader): + def __init__(self): self.cnt = 0 + def get_next(self) -> dict: + self.cnt += 1 + if self.cnt == 100: return None + return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)} + out_file = "/tmp/test_out.onnx" + quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file, + FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False, + activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": False}) + return out_file + +@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU") +class TestQuantizeOnnxCPU(unittest.TestCase): + def test_quant_128(self, sz=128): + try: + import onnx + except ImportError: + raise unittest.SkipTest() + from extra.onnx import OnnxRunner + out_file = get_quantized_model(sz) + onnx_model = onnx.load(out_file) + run_onnx = OnnxRunner(onnx_model) + inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)) + with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1): + sched = run_onnx({"input":inp})["output"].schedule() + ei = lower_schedule_item(sched[-2]) + daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC] + assert all(u.dtype.scalar() is dtypes.int for u in daccs) + @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP") class TestQuantizeOnnx(unittest.TestCase): def test_quant_128(self): self.test_quant(128) def test_quant(self, sz=512): - from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader from examples.benchmark_onnx import load_onnx_model - class FakeDataReader(CalibrationDataReader): - def __init__(self): self.cnt = 0 - def get_next(self) -> dict: - self.cnt += 1 - if self.cnt == 100: return None - return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)} - out_file = "/tmp/test_out.onnx" # divide is ~1500-2000 without reduce_range, 750-900 with it - quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file, - FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False, - activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8, - extra_options={"ActivationSymmetric": False}) + out_file = get_quantized_model(sz) run_onnx_jit, _ = load_onnx_model(out_file) with Context(DONT_REALIZE_EXPAND=1): run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 90861e3a9b..40b0bc96f0 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -2,11 +2,12 @@ import functools, itertools, operator, math from dataclasses import dataclass from typing import cast -from tinygrad.dtype import dtypes, PtrDType -from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop +from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype +from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp from tinygrad.renderer import Renderer -from tinygrad.helpers import all_int, prod, partition, flatten, unwrap +from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE from tinygrad.codegen.expander import expand_rewrite +from tinygrad.codegen.symbolic import symbolic # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None: @@ -156,9 +157,65 @@ pm_lowerer = PatternMatcher([ # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store), (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)), + (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]), +]) + +# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints **** + +def view_to_mask(x:UOp): + from tinygrad.shape.shapetracker import ShapeTracker, View + st = cast(ShapeTracker, x.st) + if len(st.views) > 1: return None + if st.views[-1].mask is None: return None + return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),)) + +FP = (1 << 16) +pm_quant = symbolic+PatternMatcher([ + # cast after add/mul + (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32), + lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)), + (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32), + lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)), + # MUL after reduce + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c), + # CAST after reduce (doesn't work if it's a size change) + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"), + lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None), + # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats) + (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats), + lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None), + # mul 0 * c1 is 0 + (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) * + UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1), + # mul (with plus) 0 * c1 is 0 + (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) * + (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \ + UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"), + lambda ld,v,c1: ld*c1), + # fixed point mult, replace (x.float()*c1+c2).int() with an int expression + ((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int), + lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP), + # where move + (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul: + (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None), + ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c), + (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid: + (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)), + ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) * + UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2: + x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))), + # don't care + (UPat(Ops.STORE, name="x"), lambda x: + x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \ + if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None), + (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"), + lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))), + (UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c), + (UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None), ]) def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp: + if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize") sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts)) # expand_rewrite turns this into a vectorized program return expand_rewrite(sink) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 2b51316083..3c87eba7d6 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -113,6 +113,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0) +QUANTIZE = ContextVar("QUANTIZE", 0) @dataclass(frozen=True) class Metadata: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 57b5959578..42e4ec929d 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -154,6 +154,7 @@ class Ops(FastEnum): # CUSTOMI is inline CUSTOM = auto(); CUSTOMI = auto() # noqa: E702 + IGNORE = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 2d23df950d..dbf604bbec 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -20,9 +20,9 @@ dsp_pm = PatternMatcher([ ]) dsp_pm_late = PatternMatcher([ - (UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), - (UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), - (UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")), + (UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), + (UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), + (UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None), (UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True), lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])), ]) From be2161652b29641b0d97d34728692f41530ff2fc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 15 Mar 2025 09:00:14 +0100 Subject: [PATCH 14/26] reorder into swizzler + ast_fixup [pr] (#9456) --- tinygrad/engine/schedule.py | 57 ++++++++++++++++++------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f29bdea5bb..8aa49f6112 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -223,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]: if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce] return ctx.realizes -# break the SINK into kernels +# **** create kernels @dataclass(frozen=True) class Kernel: @@ -243,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp): return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape) DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER} + def append_to_kernel(ctx:KernelContext, x:UOp): new_srcs: list[UOp] = [] metadata = dict.fromkeys(x.arg.metadata) @@ -268,30 +269,7 @@ create_kernels = merge_views+PatternMatcher([ (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None), ]) -# **** fix kernel AST - -# ** create buffer ops + enumerate buffers - -add_buffer_ops = PatternMatcher([ - # LOAD - (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)), - # STORE (except for COPY/BUFFER_VIEW) - (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), - # partial assign can store to a non-contiguous ShapeTracker - (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)), - lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()), - # otherwise the store is contiguous - (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), - lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), - # if the last child is a VIEW we merge the ShapeTrackers and store the base - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))), - lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)), - # remove CONTIGUOUS/DEVICE from kernel AST - (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), - (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), -]) - -# ** push views to buffer ops +# **** swizzler def apply_swizzle(u:UOp) -> UOp: with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) @@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp): assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}" return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape)) -def elementwise_view_right(root:UOp) -> UOp|None: +def elementwise_view_right(root:UOp): if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" # place view after applying the elementwise op @@ -323,7 +301,7 @@ def elementwise_view_right(root:UOp) -> UOp|None: # reshape to match downstream shapes return root.replace(src=tuple(new_src)).reshape(root.shape) -def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: +def merge_double_reduce(root:UOp, first_reduce:UOp): assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) @@ -340,9 +318,9 @@ view_right = merge_views+PatternMatcher([ (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -# ** unbind variables +# **** unbind variables -def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp) -> UOp|None: +def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp): st = unwrap(x.st).simplify() if any(x.op is Ops.BIND for x in st.vars()): st, var_vals = st.unbind() @@ -354,7 +332,26 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): return var unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),]) -# ** fix_kernel_ops +# **** fix kernel AST + +add_buffer_ops = PatternMatcher([ + # LOAD + (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)), + # STORE (except for COPY/BUFFER_VIEW) + (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), + # partial assign can store to a non-contiguous ShapeTracker + (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)), + lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()), + # otherwise the store is contiguous + (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), + lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), + # if the last child is a VIEW we merge the ShapeTrackers and store the base + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))), + lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)), + # remove CONTIGUOUS/DEVICE from kernel AST + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), + (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())), +]) def check_load_st(glbl:UOp, view:UOp): if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return From 09e7708b4973c9e44047b9e3850ab31e47a3f120 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 16 Mar 2025 13:39:24 +0800 Subject: [PATCH 15/26] minimum change for rdna4 [pr] (#9455) --- extra/hip_large_kernel.py | 25 +++++++++++++++++++++++++ tinygrad/renderer/cstyle.py | 5 ++++- tinygrad/runtime/ops_amd.py | 3 ++- 3 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 extra/hip_large_kernel.py diff --git a/extra/hip_large_kernel.py b/extra/hip_large_kernel.py new file mode 100644 index 0000000000..783e30f50f --- /dev/null +++ b/extra/hip_large_kernel.py @@ -0,0 +1,25 @@ +from tinygrad.device import Device, Buffer +from tinygrad.dtype import dtypes, _to_np_dtype + +dev = Device.default +mbin = dev.compiler.compile(""" +typedef long unsigned int size_t; +extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int); +extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, 1))) write_ones(signed char* data0) { + int gidx0 = __ockl_get_group_id(0); /* 16 */ + int gidx1 = __ockl_get_group_id(1); /* 1026048 */ + *(data0+(gidx0+gidx1*1)) = 1; +} +""") +dev.compiler.disassemble(mbin) +buf0 = Buffer(Device.DEFAULT, 1*65537, dtypes.uint8).ensure_allocated() + +prg = dev.runtime("write_ones", mbin) +prg(buf0._buf, global_size=(1,65537,1), local_size=(1,1,1), wait=True) + +import numpy as np +def to_np(buf): return np.frombuffer(buf.as_buffer().cast(buf.dtype.base.fmt), dtype=_to_np_dtype(buf.dtype.base)) + +big = to_np(buf0) +print(big) +print((big-1).nonzero()) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index a8c7a95205..37c38db1d7 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -397,13 +397,16 @@ def cast_float_to_bf16(x: UOp) -> UOp: class AMDRenderer(CStyleLanguage): device = "AMD" shared_max = 65536 + # NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation + global_max = (2147483647, 65535, 65535) # https://gpuopen.com/learn/wmma_on_rdna3/ tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do, opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8)))) for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]] def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900 - self.tensor_cores, self.arch = AMDRenderer.tensor_cores, arch + # TODO: fix tensor cores for gfx1201 + self.tensor_cores, self.arch = AMDRenderer.tensor_cores if arch != "gfx1201" else [], arch def __reduce__(self): return self.__class__, (self.arch,) # language options diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index cb1856554a..7eca9583ac 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -685,7 +685,8 @@ class AMDDevice(HCQCompiled): self.dev_iface = PCIIface(self, self.device_id) if AMDDevice.driverless else KFDIface(self, self.device_id) self.target = int(self.dev_iface.props['gfx_target_version']) self.arch = "gfx%d%x%x" % (self.target // 10000, (self.target // 100) % 100, self.target % 100) - if self.target < 100300 or self.target >= 120000: raise RuntimeError(f"Unsupported arch: {self.arch}") + if self.target < 100300 or self.target >= 130000: raise RuntimeError(f"Unsupported arch: {self.arch}") + if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}") self.max_cu_id = self.dev_iface.props['simd_count'] // self.dev_iface.props['simd_per_cu'] - 1 self.max_wave_id = self.dev_iface.props['max_waves_per_simd'] * self.dev_iface.props['simd_per_cu'] - 1 From d2cfbd8a4d2ebb82d6caf3ee8f03d12d39107445 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 16 Mar 2025 17:21:20 -0400 Subject: [PATCH 16/26] bert lower learning rate and total steps (#9466) closer to the other submission with BS=240. converged with 10% less epochs --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 83e1a2a107..d0f18bf487 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -658,9 +658,9 @@ def train_bert(): # ** hyperparameters ** BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS)) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS)) - max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96)) + max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00018 * math.sqrt(BS/96)) - train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS) + train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1) max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000 eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down From 15ee742afa3ae23370113fb1861a47779d97920d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 17 Mar 2025 14:36:13 +0800 Subject: [PATCH 17/26] add get_children_map to uop (#9470) * add get_children_map to uop * update_children * fix new children --- test/test_rewrite_tracked_childen.py | 51 ++++++++++++++++++++++++++++ tinygrad/ops.py | 33 ++++++++++++++---- 2 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 test/test_rewrite_tracked_childen.py diff --git a/test/test_rewrite_tracked_childen.py b/test/test_rewrite_tracked_childen.py new file mode 100644 index 0000000000..e9b64d82fe --- /dev/null +++ b/test/test_rewrite_tracked_childen.py @@ -0,0 +1,51 @@ +import unittest +from tinygrad import Tensor +from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp + +class TestRewriteTrackedChildren(unittest.TestCase): + def test_children_in_context(self): + def print_children(ctx:RewriteContext, sink:UOp): + view_w_child = sink.src[0].src[0].src[0] + assert view_w_child.op is Ops.VIEW + assert set([x.arg for x in ctx.children[view_w_child]]) == set((2,3)) + ctx.update_children() + assert set([x.arg for x in ctx.children[view_w_child]]) == set((3,4)) + # this is the 3 + assert len(ctx.children[sink.src[0].src[1]]) == 1 + assert next(iter(ctx.children[sink.src[0].src[1]])).op is Ops.ADD + # this is the 4 + assert len(ctx.children[sink.src[0].src[0]]) == 1 + assert next(iter(ctx.children[sink.src[0].src[0]])).op is Ops.ADD + rewrite = PatternMatcher([ + (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)), + (UPat(Ops.SINK, name="sink"), print_children) + ]) + a = Tensor(2) + b = Tensor(3) + c = a + b + sink = c.lazydata.sink() + sink = graph_rewrite(sink, rewrite, track_children=True) + + def test_simple_child(self): + rewrite = PatternMatcher([ + (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)), + ]) + a = Tensor(2) + b = Tensor(3) + c = a + b + sink = c.lazydata + view_w_child = a.lazydata.src[0] + print([x().arg for x in view_w_child.children]) + print([x.arg for x in sink.get_children_map()[view_w_child]]) + self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3))) + # children can either be added to or removed from the map with graph_rewrite + # added to is easy to detect, just hook the UOp constructor + # when are children removed? + # * if a rewrite rule returns a UOp, the matched node is removed from the graph + sink = graph_rewrite(sink, rewrite) + print([x().arg for x in view_w_child.children]) + print([x.arg for x in sink.get_children_map()[view_w_child]]) + self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4))) + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 42e4ec929d..023dd8cb71 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -281,6 +281,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return nodes return _toposort(self, cache=set()) + # returns map of UOps to their children in the graph rooted by self + def get_children_map(self) -> dict[UOp, dict[UOp, None]]: + ret: dict[UOp, dict[UOp, None]] = {} + for u in self.toposort: + for s in u.src: ret.setdefault(s, {})[u] = None + return ret + @functools.cached_property def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src)) @@ -896,10 +903,23 @@ def launch_viz(env_str:str, data:str): # *** simple graph rewrite engine *** class RewriteContext: - def __init__(self, pm, ctx=None): + def __init__(self, pm, ctx=None, children=None): self.pm: PatternMatcher = pm - self.ctx = ctx + self.ctx = self if children is not None else ctx self.replace: dict[UOp, UOp] = {} + self.children = children + # TODO: is this function always right? + def update_children(self): + # add any new children from UOps that were replaced + for u in self.replace.values(): + for s in u.src: self.children.setdefault(s, {})[u] = None + # find any children that were replaced and replace them + for k,v in self.children.items(): + new_child: dict[UOp, None] = {} + for tv in v: + while (nv:=self.replace.get(tv, None)) is not None and nv is not tv: tv = nv + new_child[tv] = None + self.children[k] = new_child def top_down_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn new_src = tuple([self.top_down_rewrite(x) for x in n.src]) @@ -914,15 +934,16 @@ class RewriteContext: self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg)) return ret -def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp: +def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> UOp: if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0: tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name)) - return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink) + rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None) + return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink) -def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]: +def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]: if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0: tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name)) - rewrite_ctx = RewriteContext(pm, ctx) + rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None) return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]} def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x From 52ae9af4dd3df3749b55540ccd4018519afbb44e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 17 Mar 2025 15:10:36 +0800 Subject: [PATCH 18/26] Fast DSP for MobileNetV2 (try 2) (#9467) * Fast DSP for MobileNetV2 (try 2) * enable fast path on uchar * fix tests --- extra/onnx.py | 6 +++++- extra/replay_pkl.py | 36 ++++++++++++++++++++++++-------- tinygrad/codegen/devectorizer.py | 11 +++++++--- tinygrad/codegen/kernel.py | 4 +++- tinygrad/renderer/cstyle.py | 4 ++-- tinygrad/runtime/ops_dsp.py | 6 ++++++ 6 files changed, 51 insertions(+), 16 deletions(-) diff --git a/extra/onnx.py b/extra/onnx.py index 430b26be92..59ec591c0e 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -728,7 +728,11 @@ def get_onnx_ops(): def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + if out_dtype == dtypes.uchar: + # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff + return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous() + else: + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() def DynamicQuantizeLinear(x: Tensor): # only support uint8 diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py index fc40280d49..752dd70b80 100644 --- a/extra/replay_pkl.py +++ b/extra/replay_pkl.py @@ -1,6 +1,7 @@ import pickle, sys from dataclasses import replace -from tinygrad import Device +from tinygrad import Device, Context +from tinygrad.device import Buffer from tinygrad.helpers import getenv from tinygrad.engine.jit import TinyJit from tinygrad.engine.realize import CompiledRunner @@ -8,10 +9,11 @@ from tinygrad.renderer import ProgramSpec from tinygrad.codegen.kernel import Kernel, Opt, OptOps if __name__ == "__main__": - with open(sys.argv[1], "rb") as f: - fxn: TinyJit = pickle.load(f) - print(f"{f.tell()/1e6:.2f}M loaded") - print(type(fxn)) + with Context(DEBUG=0): + with open(sys.argv[1], "rb") as f: + fxn: TinyJit = pickle.load(f) + print(f"{f.tell()/1e6:.2f}M loaded") + print(type(fxn)) knum = 1 for ei in fxn.captured.jit_cache: @@ -21,17 +23,33 @@ if __name__ == "__main__": p: ProgramSpec = ei.prg.p k = Kernel(p.ast, Device["DSP"].renderer) if not getenv("NOOPT"): - if knum == 2: + if knum in [6,7,9,11]: + k.apply_opt(Opt(OptOps.PADTO, 1, 128)) + k.apply_opt(Opt(OptOps.UPCAST, 1, 128)) + elif knum in [5,8]: k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0)) k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0)) k.apply_opt(Opt(OptOps.PADTO, 2, 128)) k.apply_opt(Opt(OptOps.UPCAST, 2, 128)) + elif knum == 2: + k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0)) + k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0)) + k.apply_opt(Opt(OptOps.PADTO, 2, 128)) + k.apply_opt(Opt(OptOps.UPCAST, 2, 128)) + #k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4)) + elif knum == 1: + k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0)) + k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0)) + #k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0)) + k.apply_opt(Opt(OptOps.PADTO, 2, 128)) + k.apply_opt(Opt(OptOps.UPCAST, 2, 128)) elif knum == 3: - k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=128)) + k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4)) + k.apply_opt(Opt(OptOps.UPCAST, 1, 128)) else: k.hand_coded_optimizations() + #if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2)) p2 = k.to_program() - new_ei = replace(ei, prg=CompiledRunner(p2)) + new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 128+b.size*2, b.dtype).view(b.size, b.dtype, 128) for b in ei.bufs]) new_ei.run() knum += 1 - diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index a0a9d4e87c..898923e669 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -45,7 +45,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): global_offset += len(grp) assert None not in idxs, f"some idxs are missing {idxs}" # this base thing is for image, we want the CAT to be a normal pointer - return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs))) + post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0] + return post_cat.gep(tuple(cast(list[int], idxs))) def cat_after_store(cat:UOp, data:UOp): # TODO: this is written in many places @@ -143,7 +144,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): if (sz:=ls.src[0].dtype.count) == 1: return None lengths = [] buf = idx.src[0] - if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): + must_divide = True + if ctx is not None and ctx.device == "DSP": + lengths = [128,64,32,16,8,4] + must_divide = False + elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): pass elif isinstance(buf.dtype, ImageDType): lengths = [4] @@ -158,7 +163,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): for fold_length in lengths: if global_offset+fold_length > sz: continue oidx = idx.src[1] + global_offset - if oidx.simplify().divides(fold_length) is None: continue + if must_divide and oidx.simplify().divides(fold_length) is None: continue lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None) if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local)) if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:])) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index f23abeb381..f9620f2e1f 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -501,10 +501,12 @@ class Kernel: for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0)) # potentially do more upcasts of non reduce axes based on a heuristic + is_dsp = self.opts is not None and self.opts.device == "DSP" upcasted_axis: set[int] = set() while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024): xb_choices = [] - for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce + # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP) + for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]): # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501 xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501 diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 37c38db1d7..01b8a85d56 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage): if sys.platform == 'win32': kernel_prefix = "__attribute__((ms_abi)) " def render_vector_prefix(self, dt:DType) -> str: - # round (down) to power of two - alignment = 2**int(math.log2(dt.itemsize)) + # round (down) to power of two (this is actually the default clang behavior) + alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1 return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));" def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index dbf604bbec..be76d29747 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -27,6 +27,11 @@ dsp_pm_late = PatternMatcher([ lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])), ]) +# NOTE: this just increases readability of the generated code +dsp_string = PatternMatcher([ + (UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)), +]) + class DSPRenderer(ClangRenderer): device = "DSP" supports_float4 = True @@ -34,6 +39,7 @@ class DSPRenderer(ClangRenderer): kernel_prefix = "__attribute__((noinline)) " pre_matcher = dsp_pm extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher + string_rewrite = dsp_string+ClangRenderer.string_rewrite type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})", Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", From 242daa4f9a02171576f737ad80acc508e6cb847a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:06:37 +0800 Subject: [PATCH 19/26] ptrcat (#9473) --- tinygrad/codegen/devectorizer.py | 10 +++++----- tinygrad/ops.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index 898923e669..430cb29aaa 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -45,7 +45,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): global_offset += len(grp) assert None not in idxs, f"some idxs are missing {idxs}" # this base thing is for image, we want the CAT to be a normal pointer - post_cat = UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) if len(ret) > 1 else ret[0] + post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)) return post_cat.gep(tuple(cast(list[int], idxs))) def cat_after_store(cat:UOp, data:UOp): @@ -74,11 +74,11 @@ load_store_folding = PatternMatcher([ lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)), # GEP on data of STORE (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store), - # put CAT after LOAD - (UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True), + # put PTRCAT after LOAD + (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True), lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))), - # put CAT after STORE - (UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store), + # put PTRCAT after STORE + (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store), ]) # ***** image load valid simplification ***** diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 023dd8cb71..7e11fba53c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -117,7 +117,7 @@ class Ops(FastEnum): REDUCE_AXIS = auto() # helper ops - GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702 + GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702 # UnaryOps CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 From 824c5f41ac11e3e96117ce3a56fa518e05d86cc3 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:42:12 +0800 Subject: [PATCH 20/26] dsp work try 3 (#9475) * dsp work try 3 * padding --- extra/replay_pkl.py | 2 +- tinygrad/codegen/devectorizer.py | 1 + tinygrad/runtime/ops_dsp.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py index 752dd70b80..b22a851725 100644 --- a/extra/replay_pkl.py +++ b/extra/replay_pkl.py @@ -50,6 +50,6 @@ if __name__ == "__main__": k.hand_coded_optimizations() #if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2)) p2 = k.to_program() - new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 128+b.size*2, b.dtype).view(b.size, b.dtype, 128) for b in ei.bufs]) + new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 1024+b.size*2, b.dtype).view(b.size, b.dtype, 512) for b in ei.bufs]) new_ei.run() knum += 1 diff --git a/tinygrad/codegen/devectorizer.py b/tinygrad/codegen/devectorizer.py index 430cb29aaa..87120089ab 100644 --- a/tinygrad/codegen/devectorizer.py +++ b/tinygrad/codegen/devectorizer.py @@ -12,6 +12,7 @@ from tinygrad.renderer import Renderer # ***** load/store grouping ***** def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None): + if getenv("UNSAFE_DISABLE_MASK", 0): mask = None # first, extract all the relevant offsets offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict) for i in range(vec.dtype.count): diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index be76d29747..d27c50f808 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -16,7 +16,7 @@ dsp_pm = PatternMatcher([ lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)), arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")), (UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src, - "__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 else None), + "__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None), ]) dsp_pm_late = PatternMatcher([ From e26caf4c3a9b30822b07f4b9ab9107540a5b7c43 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:47:48 +0800 Subject: [PATCH 21/26] hotfix: skip test_mean_half_precision_underflow on amd ci (#9476) The global size is very large (781250 gidx) and the emulated version takes more than 1 minute to execute the kernel. --- test/test_dtype.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_dtype.py b/test/test_dtype.py index c8d4b00230..6d06831c5d 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -820,6 +820,7 @@ class TestAutoCastType(unittest.TestCase): np.testing.assert_allclose(t.grad.numpy(), [1, 0]) @unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow") + @unittest.skipIf(CI and Device.DEFAULT == "AMD", "very slow") @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size") @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_mean_half_precision_underflow(self): From bd1f71c1e2c01d32e4c409da9891b54596c128dc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 17:02:40 +0800 Subject: [PATCH 22/26] simple failing test for extra ops in VALID [pr] (#9474) * simple failing test for extra valids [pr] * this has DEBUG=4 --- test/test_schedule.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_schedule.py b/test/test_schedule.py index 3ac485dab5..605ceced3e 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -98,6 +98,13 @@ class TestSchedule(unittest.TestCase): a.realize() assert not a.lazydata.is_realized + @unittest.expectedFailure + def test_simplify_padded_const(self): + a = Tensor.empty(1022).cummax(axis=0) + sched = check_schedule(a, 5) + ast = sched[0].ast + self.assertLessEqual(len([u for u in ast.toposort if u.op is Ops.WHERE]), 6) + def test_basic_binop_fusion(self): a = Tensor.empty(10) b = Tensor.empty(10) From 813f713edcfeacd5e74d6e58020fe9112fb4d966 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 17:15:44 +0800 Subject: [PATCH 23/26] merge_views for buffer ops + create valids last (#9472) * merge_views for buffer ops + create valids last * view.arg * pass --- test/test_schedule.py | 1 - tinygrad/engine/schedule.py | 6 +++--- tinygrad/ops.py | 10 ++++------ 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 605ceced3e..da866903c5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -98,7 +98,6 @@ class TestSchedule(unittest.TestCase): a.realize() assert not a.lazydata.is_realized - @unittest.expectedFailure def test_simplify_padded_const(self): a = Tensor.empty(1022).cummax(axis=0) sched = check_schedule(a, 5) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 8aa49f6112..ef19139859 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -345,6 +345,8 @@ add_buffer_ops = PatternMatcher([ # otherwise the store is contiguous (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()), + # VALID + (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)), # if the last child is a VIEW we merge the ShapeTrackers and store the base (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))), lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)), @@ -366,8 +368,6 @@ def check_load_st(glbl:UOp, view:UOp): fix_kernel_ops = PatternMatcher([ # BIND in shapetracker becomes DEFINE_VAR (UPat(Ops.VIEW, name="x"), unbind_shapetracker), - # remove unmasked valid - (UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None), # no ImageDType after load (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # if this kernel also assigns to the loaded buffer, ensure we can index it correctly @@ -385,7 +385,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp: # unbind_vars + push views to edges ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right) # add buffer ops + fix_kernel_ops - ast = graph_rewrite(ast, view_left+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True) + ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True) if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}") # create subbuffer (TODO: this does not belong here) if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7e11fba53c..43b745a619 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -980,6 +980,9 @@ merge_views = PatternMatcher([ # merge unmasked const views (UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)), lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None), + # merge view on load/store/valid + (UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)), + lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), # remove view if it's a contiguous and the shapes match (UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None), # remove mask if there's a zero in the masked dim @@ -989,13 +992,8 @@ merge_views = PatternMatcher([ (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)), ]) -# push VIEW to parents +# view before elementwise ops view_left = merge_views+PatternMatcher([ - # VIEW(CONST) becomes VALID - (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)), - # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), - (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)), - lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ]) From 3b00a778ba8b8b0f62375f6024411325d89c4f04 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 19:02:02 +0800 Subject: [PATCH 24/26] fix view_left for unsafe pad ops [pr] (#9478) --- test/test_schedule.py | 10 ++++++++++ tinygrad/engine/schedule.py | 2 +- tinygrad/ops.py | 5 ++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index da866903c5..e09d8e3bd0 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1971,6 +1971,16 @@ class TestSwizzle(unittest.TestCase): t = a_reduce+b_reduce with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1)) + def test_unsafe_pad(self): + x = Tensor.full((2,2), 1.0).contiguous() + y = x*x.sum((1,)).reciprocal() + t = y.pad(((0,1),None)).contiguous() + swizzled = swizzle_rewrite(t.lazydata) + sched = check_schedule(swizzled.sink(), 3) + output_buffer = sched[-1].bufs[0] + run_schedule(sched) + self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.]) + def store_val(si:ScheduleItem): return si.ast.src[0].src[2] zero_pm = UPat(Ops.CONST, arg=0) class TestView(unittest.TestCase): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ef19139859..4fa1222620 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([ # **** UOp realization -DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK} +DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS} @dataclass(frozen=True) class GrouperContext: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 43b745a619..a4545ba763 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -992,8 +992,11 @@ merge_views = PatternMatcher([ (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)), ]) -# view before elementwise ops view_left = merge_views+PatternMatcher([ + # do not push masked view before unsafe pad ops + (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)), + lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None), + # view before elementwise ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), ]) From e03c0aacf2180469c98e198d0f4463931876bac4 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 17 Mar 2025 20:43:21 +0800 Subject: [PATCH 25/26] more explicit DONT_PUSH_VIEWS [pr] (#9479) * more explicit DONT_PUSH_VIEWS [pr] * update tests to not handcode ast * lint * test_recursive_swizzle and test_simple_store_reshape --- test/test_schedule.py | 43 +++++++++++++------------------------ tinygrad/engine/schedule.py | 4 ++-- 2 files changed, 17 insertions(+), 30 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index e09d8e3bd0..d4116362db 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp from tinygrad.codegen.symbolic import symbolic_simple from tinygrad.spec import type_verify, shape_spec -from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp +from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -1857,44 +1857,31 @@ class TestIndexing(unittest.TestCase): def test_recursive_swizzle(self): a = Tensor([1,2,3,4]).realize() for _ in range(24): a = a + a - ast = a.schedule()[0].ast - swizzle = ast.src[0].src[2].reshape((4, 1)) - new_uop = swizzle_rewrite(swizzle) + new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1))) self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1))) self.assertEqual(swizzle_cnt(new_uop), 0) def test_no_rewrite_elementwise(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] - ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop())) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),)) - rsink = graph_rewrite(sink, view_right) - self.assertEqual(rsink.key, sink.key) + a = Tensor.empty(32, 32) + b = Tensor.empty(32, 32) + sink = (a+b).schedule()[0].ast + self.assertEqual(swizzle_cnt(sink), 0) def test_simple_store_reshape(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) - r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) - r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),)) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) - rsink = graph_rewrite(sink, view_right) - # this AST first needs to swizzle, but it doesn't have implicit movementops - self.assertEqual(swizzle_cnt(sink), 1) - verify_ast(rsink) + a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32) + ast = a.schedule()[0].ast + self.assertEqual(ast.shape, (32, 1)) + self.assertEqual(a.lazydata.shape, (1, 32)) def test_no_reshape_reduceop(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) - rsink = graph_rewrite(sink, view_right) - verify_ast(sink) - self.assertEqual(sink.key, rsink.key) + a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous() + ast = a.schedule()[0].ast + self.assertEqual(ast.shape, (32, 1)) + self.assertEqual(a.lazydata.shape, (32,)) @track_rewrites(named=True) def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right) -def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0]) +def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER]) class TestSwizzle(unittest.TestCase): def test_swizzle_simple(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4fa1222620..362f50a377 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([ # **** UOp realization -DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS} +DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY} @dataclass(frozen=True) class GrouperContext: @@ -139,7 +139,7 @@ do_realize = PatternMatcher([ # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), # realize before expand or unsafe pad ops - (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view), + (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),)), realize_before_view), # realize before COPY (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize), ]) From f53be010d7445c422f894fc25a6e2f2171aa8eda Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 17 Mar 2025 10:49:56 -0400 Subject: [PATCH 26/26] lower bert learning rate (#9481) slightly better. first sub 3hr run https://wandb.ai/chenyuxyz/MLPerf-BERT/runs/0or96ink/overview --- examples/mlperf/model_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index d0f18bf487..aaf4e8f219 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -658,7 +658,7 @@ def train_bert(): # ** hyperparameters ** BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS)) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS)) - max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00018 * math.sqrt(BS/96)) + max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(BS/96)) train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)