use AddrSpace instead of local (#11314)

* use AddrSpace instead of local

* addrspace in test
George Hotz authored 2025-07-21 14:00:06 -07:00, committed by GitHub
parent d3a93185a6, commit 108aac8af4
9 changed files with 41 additions and 34 deletions
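
The change throughout is mechanical: the boolean local=True flag on pointer dtypes is replaced by an explicit AddrSpace enum (GLOBAL, LOCAL, REG). A minimal before/after sketch of the call-site migration, using the same constructor arguments that appear in the hunks below:

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import Ops, UOp

# before this commit: shared memory was requested with a boolean flag
# smem = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, local=True), (), "smem")

# after this commit: the address space is named explicitly
smem = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
assert smem.dtype.addrspace == AddrSpace.LOCAL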

View File

@@ -7,6 +7,7 @@ from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, Grou
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.schedule.kernelize import merge_views
from tinygrad.shape.view import View
+from tinygrad.dtype import AddrSpace
N = 4096
run_count = 5
@@ -62,8 +63,8 @@ def hand_spec():
bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2) # input B
# TODO: this should not be a string, just a number
-lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, local=True), arg="As")
-lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, local=True), arg="Bs")
+lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, addrspace=AddrSpace.LOCAL), arg="As")
+lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, addrspace=AddrSpace.LOCAL), arg="Bs")
s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))

View File

@@ -1,6 +1,7 @@
from typing import List
import unittest, pytest
from tinygrad import dtypes, Variable
+from tinygrad.dtype import AddrSpace
from tinygrad.helpers import DEBUG, Context
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp
from tinygrad.uop.symbolic import sym
@@ -453,7 +454,7 @@ class TestUOpGraph(unittest.TestCase):
with Context(IGNORE_OOB=0):
# Define buffers
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
-sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, local=True), (), "temp0")
+sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 416))
@@ -528,7 +529,7 @@ class TestUOpGraph(unittest.TestCase):
def test_fold_gated_load_local(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, local=True), (), "temp")
+smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
@@ -726,7 +727,7 @@ class TestExpander(unittest.TestCase):
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, local=True), (), "smem")
+sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, addrspace=AddrSpace.LOCAL), (), "smem")
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
@@ -745,7 +746,7 @@ class TestIFUOps(unittest.TestCase):
def test_expand_ifs_one_gate(self):
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, local=True), (), "smem")
+sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2))

View File

@@ -5,7 +5,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View # noqa F401
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Timing
-from tinygrad.dtype import dtypes, DType
+from tinygrad.dtype import dtypes, DType, AddrSpace
from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
from tinygrad.uop.spec import spec
@@ -303,7 +303,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self):
uops = []
-smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, local=True), (), 'smem')
+smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
@@ -313,7 +313,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
def test_local_packed(self):
uops = []
-smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, local=True), (), 'smem')
+smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
@@ -325,7 +325,7 @@ class TestLocalAccess(unittest.TestCase):
_dtypes = [dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.half]
size = 16
for dtype in _dtypes:
-temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, local=True), (), 'smem')
+temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, addrspace=AddrSpace.LOCAL), (), 'smem')
uops = to_uops_list([temp], opts=Device[Device.DEFAULT].renderer)
out = Device[Device.DEFAULT].renderer.render(uops)
# half is supported in wgsl, so it doesn't have to be packed
@@ -336,7 +336,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skip("tinygrad doesn't support this behavior")
def test_local_indirect(self):
uops = []
-smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, local=True), (), 'smem')
+smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))

View File

@@ -94,7 +94,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
-if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
+if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]: idxs[oo] = global_offset+i
@@ -103,7 +103,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
-post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
+post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
@@ -224,7 +224,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
for fold_length in lengths:
if global_offset+fold_length > sz: continue
lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
-if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
+if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
global_offset += fold_length
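
In the devectorizer, pointer dtypes are rebuilt around a vectorized base, so the address space now has to be carried over from the original pointer rather than defaulted. A small sketch of that invariant, assuming only the ptr/vec API changed in this commit:

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace

ptrdtype = dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL)
# rebuild a 4-wide pointer the way expand_index/split_load_store do above
vec_ptr = ptrdtype.base.vec(4).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace)
assert vec_ptr.addrspace == AddrSpace.LOCAL and vec_ptr.base.count == 4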

View File

@@ -1,7 +1,7 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from typing import cast
-from tinygrad.dtype import dtypes, PtrDType
+from tinygrad.dtype import dtypes, PtrDType, AddrSpace
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType
from tinygrad.helpers import prod, partition, flatten
@@ -53,7 +53,7 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
-if not cast(PtrDType, buf.dtype).local:
+if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
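
The store-lowering guard flips from a negated boolean (not ...local) to a positive check against AddrSpace.GLOBAL; with only global and local pointers the two are equivalent, but the enum form stays well-defined once REG pointers exist. A tiny check of the new predicate, assuming nothing beyond the ptr defaults introduced in this commit:

from tinygrad.dtype import dtypes, AddrSpace

gbuf_dt = dtypes.float.ptr(400)                            # defaults to AddrSpace.GLOBAL
lbuf_dt = dtypes.float.ptr(16, addrspace=AddrSpace.LOCAL)
assert gbuf_dt.addrspace == AddrSpace.GLOBAL               # global stores get the reduce-thread valid masking above
assert lbuf_dt.addrspace != AddrSpace.GLOBAL               # local stores skip it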

View File

@@ -3,6 +3,7 @@ from typing import Final, ClassVar, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
+from enum import Enum, auto
ConstType = float|int|bool
@@ -16,6 +17,8 @@ class DTypeMetaClass(type):
DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
return ret
+class AddrSpace(Enum): GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
priority: int # this determines when things get upcasted
@@ -38,8 +41,8 @@ class DType(metaclass=DTypeMetaClass):
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
-def ptr(self, size=-1, local=False) -> PtrDType:
-return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
+def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
def nbytes(self): raise RuntimeError("only ptr types have nbytes")
@property
@@ -50,7 +53,7 @@ class DType(metaclass=DTypeMetaClass):
@dataclass(frozen=True, eq=False)
class PtrDType(DType):
_base: DType
-local: bool
+addrspace: AddrSpace
v: int
size: int = -1 # -1 is unlimited size
@property
@@ -60,22 +63,23 @@ class PtrDType(DType):
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
if isinstance(self, ImageDType):
-return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size, self.shape)
-return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
-def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
+return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
+return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
+def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL): raise RuntimeError("can't make a pointer from a pointer")
def nbytes(self) -> int:
if self.size == -1: return 0 # TODO: this should be an exception
return self.size*self.itemsize
@property
def vcount(self): return self.v
def __repr__(self):
return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
return f"{self.base.__repr__()}.ptr({self.size}{', addrspace='+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \
(f'.vec({self.v})' if self.v != 1 else '')
@dataclass(frozen=True, eq=False)
class ImageDType(PtrDType):
shape: tuple[int, ...] = () # shape of the Image
-def ptr(self, size=-1, local=False) -> PtrDType:
-assert not local, "images can't be local"
+def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
+assert addrspace == AddrSpace.GLOBAL, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@@ -149,9 +153,9 @@ class dtypes:
# NOTE: these are image dtypes
@staticmethod
-def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
+def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
@staticmethod
-def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
+def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32
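
Taken together, the dtype.py changes make the address space part of the pointer type itself rather than a flag threaded through call sites. A short sketch of the new surface; the reprs in the comments are illustrative, following the __repr__ above:

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace

p_glob = dtypes.float.ptr(64)                              # AddrSpace.GLOBAL is the default
p_loc  = dtypes.float.ptr(64, addrspace=AddrSpace.LOCAL)
print(p_glob)   # dtypes.float.ptr(64)
print(p_loc)    # dtypes.float.ptr(64, addrspace=AddrSpace.LOCAL)
assert p_loc.nbytes() == 64 * p_loc.itemsize               # nbytes is defined for sized pointers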

View File

@@ -10,7 +10,7 @@ from tinygrad.uop.spec import type_verify, ast_spec
from tinygrad.device import Device
from tinygrad.opt.tc import TensorCore
from tinygrad.renderer import Renderer
-from tinygrad.dtype import ImageDType
+from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
@@ -495,7 +495,7 @@ class Kernel:
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
local_size = st.real_size()
-local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
+local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
if op is self.reduceops[-1]: return grouped_reduce
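
The temporary buffer for a grouped reduce is now typed with an explicit LOCAL address space. A sketch of just the dtype it receives (local_size is a made-up value here, not taken from a real kernel):

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace

local_size = 32  # hypothetical; the kernel computes this from st.real_size() above
tmp_dt = dtypes.float.ptr(local_size, addrspace=AddrSpace.LOCAL)
assert tmp_dt.size == local_size and tmp_dt.addrspace == AddrSpace.LOCAL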

View File

@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
from tinygrad.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
-from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
+from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace
from tinygrad.renderer import Renderer
from tinygrad.codegen.devectorizer import no_vectorized_alu
@@ -119,7 +119,8 @@ class CStyleLanguage(Renderer):
def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType):
-return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
+prefix = self.smem_prefix if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast else self.buffer_prefix
+return prefix + self.render_dtype(dt.base) + "*"
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
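
The renderer now picks its shared-memory prefix by comparing the pointer's addrspace against AddrSpace.LOCAL. A standalone sketch of that selection; the prefix strings are placeholders, not any particular backend's values:

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace, PtrDType

def render_ptr(dt: PtrDType, smem_prefix="__local ", buffer_prefix="__global ", smem_prefix_for_cast=True) -> str:
  # mirrors the prefix selection in CStyleLanguage.render_dtype above
  prefix = smem_prefix if dt.addrspace == AddrSpace.LOCAL and smem_prefix_for_cast else buffer_prefix
  return prefix + dt.base.name + "*"

print(render_ptr(dtypes.float.ptr(16, addrspace=AddrSpace.LOCAL)))  # __local float*
print(render_ptr(dtypes.float.ptr(16)))                             # __global float*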

View File

@@ -1,6 +1,6 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, resolve
-from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
+from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
try:
import z3
@@ -128,8 +128,8 @@ index_pat = UPat(Ops.INDEX, name="idx").or_casted()
# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
(UPat(Ops.DEFINE_REG, src=(UPat.var("c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
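
The updated DEFINE_GLOBAL/DEFINE_LOCAL rules enforce the invariant end to end: a DEFINE_GLOBAL must carry a GLOBAL pointer and a DEFINE_LOCAL a LOCAL one. A small sketch of what the lambdas above accept and reject, with the UOps constructed by hand rather than through the scheduler:

from tinygrad import dtypes
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.uop.ops import Ops, UOp

ok_local  = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(16, addrspace=AddrSpace.LOCAL), (), "smem")
bad_local = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(16), (), "smem")  # GLOBAL ptr under DEFINE_LOCAL

is_local_buf = lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL
assert is_local_buf(ok_local) and not is_local_buf(bad_local)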