mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
use AddrSpace instead of local (#11314)
* use AddrSpace instead of local * addrspace in test
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import graph_rewrite, PatternMatcher, UPat, Ops, UOp, Grou
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
|
||||
from tinygrad.schedule.kernelize import merge_views
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.dtype import AddrSpace
|
||||
|
||||
N = 4096
|
||||
run_count = 5
|
||||
@@ -62,8 +63,8 @@ def hand_spec():
|
||||
bB = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2) # input B
|
||||
|
||||
# TODO: this should not be a string, just a number
|
||||
lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, local=True), arg="As")
|
||||
lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, local=True), arg="Bs")
|
||||
lAs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_A_SZ, addrspace=AddrSpace.LOCAL), arg="As")
|
||||
lBs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(LDS_B_SZ, addrspace=AddrSpace.LOCAL), arg="Bs")
|
||||
|
||||
s0 = ShapeTracker.from_shape((N, N, N), (N, 0, 1))
|
||||
s1 = ShapeTracker.from_shape((N, N, N), (0, 1, N))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List
|
||||
import unittest, pytest
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp
|
||||
from tinygrad.uop.symbolic import sym
|
||||
@@ -453,7 +454,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
with Context(IGNORE_OOB=0):
|
||||
# Define buffers
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, local=True), (), "temp0")
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
|
||||
|
||||
# Define indices, valids and barrier
|
||||
gidx = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 416))
|
||||
@@ -528,7 +529,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, local=True), (), "temp")
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(size=18, addrspace=AddrSpace.LOCAL), (), "temp")
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem.index(lidx), UOp.load(glbl0.index(lidx), dtype=dtypes.int)))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
@@ -726,7 +727,7 @@ class TestExpander(unittest.TestCase):
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, local=True), (), "smem")
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=4, addrspace=AddrSpace.LOCAL), (), "smem")
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10))<5
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
@@ -745,7 +746,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, local=True), (), "smem")
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(size=16, addrspace=AddrSpace.LOCAL), (), "smem")
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4))<1
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid&(lidx.ne(2))
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View # noqa F401
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Timing
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.dtype import dtypes, DType, AddrSpace
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
|
||||
from tinygrad.uop.spec import spec
|
||||
@@ -303,7 +303,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, local=True), (), 'smem')
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
@@ -313,7 +313,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Test local access with packed data type")
|
||||
def test_local_packed(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, local=True), (), 'smem')
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), uop(uops, Ops.CONST, dtypes.uint8, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.uint8, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 0)), barr))
|
||||
@@ -325,7 +325,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
_dtypes = [dtypes.char, dtypes.uchar, dtypes.short, dtypes.ushort, dtypes.half]
|
||||
size = 16
|
||||
for dtype in _dtypes:
|
||||
temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, local=True), (), 'smem')
|
||||
temp = UOp(Ops.DEFINE_LOCAL, dtype.ptr(size=size, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
uops = to_uops_list([temp], opts=Device[Device.DEFAULT].renderer)
|
||||
out = Device[Device.DEFAULT].renderer.render(uops)
|
||||
# half is supported in wgsl, so it doesn't have to be packed
|
||||
@@ -336,7 +336,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skip("tinygrad doesn't support this behavior")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, local=True), (), 'smem')
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(size=16, addrspace=AddrSpace.LOCAL), (), 'smem')
|
||||
st1 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 1)), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, Ops.STORE, dtypes.void, (smem.index(uop(uops, Ops.CONST, dtypes.int32, (), 2)), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))
|
||||
|
||||
@@ -94,7 +94,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
for grp in grouped_offsets:
|
||||
# get the index offset for this element. using [0] is okay, because they are the same
|
||||
lidx = midx.src[offsets[grp[0]][0]]
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
|
||||
# set the idxs of the output
|
||||
for i,g in enumerate(grp):
|
||||
for oo in offsets[g]: idxs[oo] = global_offset+i
|
||||
@@ -103,7 +103,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
global_offset += len(grp)
|
||||
assert None not in idxs, f"some idxs are missing {idxs}"
|
||||
# this base thing is for image, we want the CAT to be a normal pointer
|
||||
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
|
||||
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
|
||||
return post_cat.gep(tuple(cast(list[int], idxs)))
|
||||
|
||||
def cat_after_store(cat:UOp, data:UOp):
|
||||
@@ -224,7 +224,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
for fold_length in lengths:
|
||||
if global_offset+fold_length > sz: continue
|
||||
lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
|
||||
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
|
||||
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
|
||||
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
|
||||
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
|
||||
global_offset += fold_length
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.dtype import dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType
|
||||
from tinygrad.helpers import prod, partition, flatten
|
||||
|
||||
@@ -53,7 +53,7 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
|
||||
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
|
||||
if not cast(PtrDType, buf.dtype).local:
|
||||
if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx is not ridx: valid = valid * oidx.eq(0)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Final, ClassVar, Callable, Literal
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass, fields
|
||||
from tinygrad.helpers import getenv, prod
|
||||
from enum import Enum, auto
|
||||
|
||||
ConstType = float|int|bool
|
||||
|
||||
@@ -16,6 +17,8 @@ class DTypeMetaClass(type):
|
||||
DTypeMetaClass.dcache[args] = ret = super().__call__(*args)
|
||||
return ret
|
||||
|
||||
class AddrSpace(Enum): GLOBAL = auto(); LOCAL = auto(); REG = auto() # noqa: E702
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class DType(metaclass=DTypeMetaClass):
|
||||
priority: int # this determines when things get upcasted
|
||||
@@ -38,8 +41,8 @@ class DType(metaclass=DTypeMetaClass):
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
def ptr(self, size=-1, local=False) -> PtrDType:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)
|
||||
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
|
||||
def nbytes(self): raise RuntimeError("only ptr types have nbytes")
|
||||
@property
|
||||
@@ -50,7 +53,7 @@ class DType(metaclass=DTypeMetaClass):
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class PtrDType(DType):
|
||||
_base: DType
|
||||
local: bool
|
||||
addrspace: AddrSpace
|
||||
v: int
|
||||
size: int = -1 # -1 is unlimited size
|
||||
@property
|
||||
@@ -60,22 +63,23 @@ class PtrDType(DType):
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
if isinstance(self, ImageDType):
|
||||
return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size, self.shape)
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz, self.size)
|
||||
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
return ImageDType(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size, self.shape)
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.addrspace, sz, self.size)
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL): raise RuntimeError("can't make a pointer from a pointer")
|
||||
def nbytes(self) -> int:
|
||||
if self.size == -1: return 0 # TODO: this should be an exception
|
||||
return self.size*self.itemsize
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
def __repr__(self):
|
||||
return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
return f"{self.base.__repr__()}.ptr({self.size}{', addrspace='+str(self.addrspace) if self.addrspace != AddrSpace.GLOBAL else ''})" + \
|
||||
(f'.vec({self.v})' if self.v != 1 else '')
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class ImageDType(PtrDType):
|
||||
shape: tuple[int, ...] = () # shape of the Image
|
||||
def ptr(self, size=-1, local=False) -> PtrDType:
|
||||
assert not local, "images can't be local"
|
||||
def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
|
||||
assert addrspace == AddrSpace.GLOBAL, "images can't be local"
|
||||
return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
|
||||
@@ -149,9 +153,9 @@ class dtypes:
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
|
||||
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
@@ -10,7 +10,7 @@ from tinygrad.uop.spec import type_verify, ast_spec
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.opt.tc import TensorCore
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.dtype import ImageDType, AddrSpace
|
||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import strides_for_shape, get_contraction
|
||||
@@ -495,7 +495,7 @@ class Kernel:
|
||||
local_src_shape = tuple([self.full_shape[i] if i in self.axes_of(AxisType.GLOBAL) else s for i,s in enumerate(local_shape)])
|
||||
st = ShapeTracker.from_shape(base_shape).permute(permute_axes).reshape(local_shape).expand(local_src_shape)
|
||||
local_size = st.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, addrspace=AddrSpace.LOCAL), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = local_buffer.view(st).load(local_buffer.view(st).store(ret))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict, Counter
|
||||
from tinygrad.opt import tc
|
||||
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat
|
||||
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.devectorizer import no_vectorized_alu
|
||||
|
||||
@@ -119,7 +119,8 @@ class CStyleLanguage(Renderer):
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||
if isinstance(dt, PtrDType):
|
||||
return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
|
||||
prefix = self.smem_prefix if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
return prefix + self.render_dtype(dt.base) + "*"
|
||||
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
|
||||
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import cast, Callable
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, resolve
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
|
||||
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.helpers import all_same, prod, DEBUG, ContextVar, Context
|
||||
try:
|
||||
import z3
|
||||
@@ -128,8 +128,8 @@ index_pat = UPat(Ops.INDEX, name="idx").or_casted()
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions returns True or False (or None to not match)
|
||||
spec = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and x.dtype.addrspace == AddrSpace.GLOBAL),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.addrspace == AddrSpace.LOCAL),
|
||||
(UPat(Ops.DEFINE_REG, src=(UPat.var("c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
Reference in New Issue
Block a user