mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -1,5 +1,5 @@
|
||||
from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.schedule.kernelize import merge_views
|
||||
@@ -40,7 +40,11 @@ def hl_spec_kernel3():
|
||||
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
|
||||
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
|
||||
out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
|
||||
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm"))
|
||||
axis_types = [
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.REDUCE, AxisType.UNROLL]
|
||||
sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
|
||||
sink = graph_rewrite(sink, merge_views)
|
||||
return sink
|
||||
|
||||
|
||||
@@ -122,8 +122,8 @@ def gep_on_store(gep:UOp, st:UOp, sto:UOp):
|
||||
return gep.src[0].store(st.gep(new_arg), *sto.src[2:])
|
||||
|
||||
load_store_folding = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"))), expand_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL), name="buf")), UPat.var("vec"),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"))), expand_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.VECTORIZE, src=UPat(GroupOp.Defines, name="buf")), UPat.var("vec"),
|
||||
UPat.var("mask"))), expand_index),
|
||||
# GEP after LOAD
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.GEP, name="gep"),), name="ld", allow_any_len=True),
|
||||
@@ -154,6 +154,8 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
must_divide = False
|
||||
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
|
||||
pass
|
||||
elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
|
||||
pass
|
||||
elif isinstance(buf.dtype, ImageDType):
|
||||
lengths = [4]
|
||||
elif ctx is not None and ctx.supports_float4:
|
||||
|
||||
@@ -86,9 +86,6 @@ expander = PatternMatcher([
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX,
|
||||
Ops.VECTORIZE, Ops.IF, Ops.REDUCE), name="root", custom_early_reject=set([Ops.UNROLL])), do_expand),
|
||||
(UPat(Ops.CONTRACT, name="con"), do_contract),
|
||||
# vectorize DEFINE_ACC
|
||||
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_REG, name="acc"), name="v"),
|
||||
lambda acc,v: acc.replace(dtype=v.dtype, src=(acc.src[0].broadcast(v.dtype.count),)+acc.src[1:])),
|
||||
# BARRIERs aren't actually expanded
|
||||
(UPat(Ops.BARRIER, src=(UPat(Ops.UNROLL, name="ex"),)),
|
||||
lambda ex: UOp(Ops.UNROLL, src=(UOp(Ops.BARRIER, src=ex.src),)*len(ex.src), arg=ex.arg)),
|
||||
|
||||
@@ -48,7 +48,7 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
|
||||
def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if buf.op is Ops.DEFINE_LOCAL else ctx.idxs)
|
||||
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[1],)),) if buf.op is Ops.DEFINE_LOCAL else ()
|
||||
barrier = tuple([y.barrier() if buf.op is Ops.DEFINE_LOCAL else y for y in x.src[1:]])
|
||||
return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
|
||||
|
||||
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
|
||||
|
||||
@@ -118,7 +118,9 @@ class CStyleLanguage(Renderer):
|
||||
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||
if isinstance(dt, PtrDType):
|
||||
prefix = self.smem_prefix if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
prefix = ""
|
||||
if dt.addrspace == AddrSpace.LOCAL and self.smem_prefix_for_cast: prefix = self.smem_prefix
|
||||
if dt.addrspace == AddrSpace.GLOBAL: prefix = self.buffer_prefix
|
||||
return prefix + self.render_dtype(dt.base) + "*"
|
||||
if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
|
||||
return self.type_map.get(scalar:=dt.scalar(), scalar.name)
|
||||
|
||||
@@ -83,6 +83,8 @@ class GroupOp:
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
Defines = {Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG}
|
||||
|
||||
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
|
||||
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
|
||||
|
||||
|
||||
@@ -238,6 +238,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, dtype=kwargs.pop("dtype", self.dtype.base), src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self, x))
|
||||
def barrier(self): return UOp(Ops.BARRIER, src=(self,))
|
||||
def alu(self, op, *src:UOp, **kwargs):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if op in {Ops.CMPLT, Ops.CMPNE}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
|
||||
@@ -144,18 +144,18 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||
|
||||
# early LOAD has a <bufview, store?>
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat(Ops.STORE))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# early STORE has a <bufview, val>
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),)), UPat())), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
|
||||
|
||||
# **** new style load/store ****
|
||||
|
||||
# INDEX is used in new style load/store
|
||||
# INDEX takes a <buf, alu, gate?>
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
|
||||
# LOAD on STORE
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),), allow_any_len=True), lambda: True),
|
||||
@@ -165,8 +165,8 @@ spec = PatternMatcher([
|
||||
(UPat(Ops.LOAD, src=(index_pat,), allow_any_len=True), validate_index),
|
||||
|
||||
# STORE takes a <bufidx, val, gate?>
|
||||
(UPat(Ops.STORE, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
|
||||
(UPat(Ops.STORE, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val"), UPat(Ops.IF, name="gate")), allow_any_len=True), validate_store),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(index_pat, UPat(name="val")), allow_any_len=True), validate_store),
|
||||
|
||||
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||
(UPat(Ops.WHERE, name="w", src=(UPat(dtype=dtypes.bool), UPat.var("x"), UPat.var("y"))), lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||
|
||||
Reference in New Issue
Block a user