From ba118abfec400cd4b78b40a2ea57569cb4656205 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com>
Date: Thu, 4 Apr 2024 16:33:48 +0200
Subject: [PATCH] improved caching for pointer arithmetics in ptx (#3922)

* improved caching for pointer arithmetics

* Add test for pointer arithmetics caching

* Refactor test
---
 test/test_uops.py             | 25 +++++++++++++++++++++++++
 tinygrad/renderer/assembly.py | 40 ++++++++++++++++++++++++++++------------
 2 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/test/test_uops.py b/test/test_uops.py
index ef5f0f27d7..ea9d96aa59 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Any, List
 import unittest, math
 import numpy as np
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
 from tinygrad.dtype import dtypes, DType, PtrDType
 from tinygrad.device import Buffer, Device, CompiledASTRunner
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
@@ -232,5 +233,29 @@ class TestLocalAccess(unittest.TestCase):
     sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
     self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
 
+@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
+class TestAssembly(unittest.TestCase):
+  def test_pointer_arithmetics_caching(self):
+    """Loads at gidx0*42+0 and gidx0*42+1 from one buffer must end up sharing a pointer and differ only by a CONST offset."""
+    uops = UOpGraph()
+    u1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple(), (0, 'data0', True))
+    u2 = uops.add(UOps.SPECIAL, dtypes.int, tuple(), (0, 'gidx0', 9))
+    u3 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=42)
+    u4 = uops.add(UOps.ALU, dtypes.int, (u2, u3), BinaryOps.MUL)
+    u5 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0)
+    u6 = uops.add(UOps.CONST, dtypes.int, tuple(), arg=1)
+    u7 = uops.add(UOps.ALU, dtypes.int, (u4, u5), BinaryOps.ADD)
+    u8 = uops.add(UOps.ALU, dtypes.int, (u4, u6), BinaryOps.ADD)
+    u9 = uops.add(UOps.LOAD, dtypes.int, (u1, u7))
+    u10 = uops.add(UOps.LOAD, dtypes.int, (u1, u8))
+    _uops_to_prg(uops)
+    # once _uops_to_prg has run, both loads must reference the very same pointer uop ...
+    self.assertEqual(u9.vin[0], u10.vin[0])
+    # ... and carry the constant part of the index as a byte offset, i.e. scaled by the element size of the int buffer
+    self.assertEqual(u9.vin[1].uop, UOps.CONST)
+    self.assertEqual(u9.vin[1].arg, u5.arg*dtypes.int.itemsize)
+    self.assertEqual(u10.vin[1].uop, UOps.CONST)
+    self.assertEqual(u10.vin[1].arg, u6.arg*dtypes.int.itemsize)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index d70007096a..883dbe7492 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -13,6 +13,34 @@ def render_val(x, dtype):
     return "0f%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
   return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
 
+def ptr_ar(root, uops):
+  """Rewrite the address operands of memory uop `root` in place as (pointer, byte offset).
+
+  `root.vin[0]` is the buffer and `root.vin[1]` the element index; the index is scaled by the buffer dtype's itemsize.
+  When the index is `x ADD/SUB CONST`, `buffer + x*itemsize` becomes the pointer and `CONST*itemsize` the offset, so
+  accesses that differ only in the constant can share the pointer uop. Also fills in `root.arg` when it is None.
+  """
+  assert root.arg in {'.shared', '.global', None}
+  if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global'  # move this to the argL
+  val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
+  if root.vin[1].uop is UOps.ALU and root.vin[1].arg in [BinaryOps.ADD, BinaryOps.SUB] and root.vin[1].vin[1].uop is UOps.CONST:
+    # keep the constant term out of the 64-bit pointer add so the pointer depends only on the non-constant operand
+    offset = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[0], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
+    offset = uops.add(UOps.CAST, dtypes.uint64, (offset,), insert_before=uops.uops.index(root))
+    cache = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], offset), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
+    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1].vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
+    if root.vin[1].arg == BinaryOps.SUB: ptr = uops.add(UOps.ALU, dtypes.int, (ptr,), arg=UnaryOps.NEG, insert_before=uops.uops.index(root))
+    root.vin = (cache, ptr) + root.vin[2:]
+  else:
+    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
+    # if the scaled index came back as a CONST it is used directly as the offset, otherwise it is added into the pointer
+    if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
+    else:
+      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
+      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
+      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
+      root.vin = (fptr, zero) + root.vin[2:]
+
 class AssemblyLanguage(NamedTuple):
   kernel_prefix: str = ""
   barrier: str = ""
@@ -43,18 +71,6 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
   kernel:List[str] = []
   bufs = []
 
-  def ptr_ar(root, uops):
-    assert root.arg in {'.shared', '.global', None}
-    if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global'  # move this to the argL
-    val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
-    ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
-    if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
-    else:
-      zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
-      bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))
-      fptr = uops.add(UOps.ALU, dtypes.uint64, (root.vin[0], bptr), arg=BinaryOps.ADD, insert_before=uops.uops.index(root))
-      root.vin = (fptr, zero) + root.vin[2:]
-
   matcher = PatternMatcher([
     ({"__name__": "root", "uop": UOps.ALU, "arg": BinaryOps.CMPEQ, "vin": ({"dtype": dtypes.bool},{})},
      lambda root: UOp(UOps.ALU, dtypes.bool, (UOp(root.uop, root.dtype, root.vin, BinaryOps.XOR),), UnaryOps.NEG)),