small changes from double gemm work
@@ -12,7 +12,7 @@ from tinygrad.codegen.quantize import pm_quant
 from tinygrad.codegen.gpudims import pm_add_gpudims
 from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
 from tinygrad.uop.decompositions import get_late_rewrite_patterns
-from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander
+from tinygrad.codegen.late.expander import migrate_indexing, expander, pm_pre_expander, pm_group_for_reduce
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
   ReduceContext, correct_load_store, pm_render
 from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
@@ -75,7 +75,7 @@ def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _Q
   ret.append(RewriteStep(pm_postrange_opt, ctx=lambda _: opts, name="post optimize ast"))

   # ** expander (expand_rewrite) **
-  ret.append(RewriteStep(sym+migrate_indexing, name="postopt symbolic"))
+  ret.append(RewriteStep(sym+migrate_indexing+pm_group_for_reduce, name="postopt symbolic"))

   # expand
   ret.append(RewriteStep(sym+pm_pre_expander+expander, name="expander"))
@@ -157,6 +157,9 @@ pm_pre_expander = PatternMatcher([
   # fix REDUCEs with UNROLLs
   (UPat(Ops.REDUCE, name="x"), fix_reduce_unroll),
   (UPat(Ops.STORE, name="x"), fix_store_unroll),
 ])

+pm_group_for_reduce = PatternMatcher([
+  # fix group for reduce
+  (UPat(Ops.REDUCE, name="x"), fix_group_for_reduce),
+])

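Note on the hunks above: PatternMatcher rule lists compose with +, and a rule callback can return None to decline a match, which is why pm_group_for_reduce can simply be appended onto the existing "postopt symbolic" step instead of getting a pass of its own. A minimal standalone sketch of that composition pattern (toy classes for illustration, not tinygrad's real PatternMatcher or RewriteStep):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  srcs: tuple = ()

class Matcher:
  def __init__(self, rules): self.rules = list(rules)
  def __add__(self, other): return Matcher(self.rules + other.rules)
  def rewrite(self, node):
    # try each (op, fn) rule in order; the first non-None result wins
    for op, fn in self.rules:
      if node.op == op and (out := fn(node)) is not None: return out
    return node

sym = Matcher([("ADD", lambda n: None)])  # placeholder rule that declines every match
group_for_reduce = Matcher([("REDUCE", lambda n: Node("GROUPED_REDUCE", n.srcs))])

postopt = sym + group_for_reduce          # mirrors sym+migrate_indexing+pm_group_for_reduce
print(postopt.rewrite(Node("REDUCE")).op)  # GROUPED_REDUCE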
@@ -320,7 +320,8 @@ class MetalRenderer(CStyleLanguage):

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
     prefix = ["#include <metal_stdlib>","using namespace metal;"]
-    for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops): prefix.append(
+    deduped_wmma_args = dedup([(name, dtype_in, dtype_out) for name, _, dtype_in, dtype_out, _, _, _, _ in wmma_args(uops)])
+    for name, dtype_in, dtype_out in deduped_wmma_args: prefix.append(
 f"""{(dstr_out:=self.render_dtype(dtype_out.vec(2)))} __{name}({(dstr_in:=self.render_dtype(dtype_in.vec(2)))} a, {dstr_in} b, {dstr_out} c){{
   simdgroup_{self.render_dtype(dtype_in)}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(dtype_out)}8x8 mat_c;
   mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0];
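The Metal change guards against emitting the same simdgroup helper more than once when a kernel (for example the double gemm) contains several WMMA uops with identical signatures. A rough standalone sketch of the dedup step (the signature tuples below are made up; the helper is an order-preserving unique, like tinygrad's dedup):

def dedup(xs):
  # order-preserving unique
  return list(dict.fromkeys(xs))

# hypothetical (name, dtype_in, dtype_out) signatures pulled from a kernel's WMMA uops
wmma_sigs = [
  ("wmma_half_float", "half", "float"),
  ("wmma_half_float", "half", "float"),   # same WMMA used twice would emit a duplicate helper
  ("wmma_float_float", "float", "float"),
]
for name, dtype_in, dtype_out in dedup(wmma_sigs):
  print(f"emit __{name}({dtype_in}2 a, {dtype_in}2 b, {dtype_out}2 c)")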
@@ -519,7 +519,7 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
 # BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
 # NOTE: this has been fixed up a bit

-def bufferize_to_store(x:UOp):
+def bufferize_to_store(x:UOp, allow_locals=True):
   rngs = x.src[1:]
   shape = tuple([int(r.vmax+1) for r in rngs])
   size = prod(shape)
@@ -552,6 +552,7 @@ def bufferize_to_store(x:UOp):
     return ret.replace(tag=x.tag)

   # handle locals
+  if not allow_locals: return None
   tag = x.arg.device
   if tag is None: tag = UOp.unique().arg # TODO: hack
   buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
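The allow_locals flag relies on the convention that a rewrite callback returning None leaves the uop untouched: with allow_locals=False the global-buffer path still runs, but a BUFFERIZE that would need a DEFINE_LOCAL is skipped. A toy standalone sketch of that guard (the node layout and the local/global test are invented for illustration):

def bufferize_to_store_sketch(node: dict, allow_locals: bool = True):
  # global buffers are always lowered to a STORE
  if not node.get("is_local", False):
    return {"op": "STORE", "buffer": "DEFINE_GLOBAL"}
  # local buffers: decline the rewrite entirely when locals are not allowed yet
  if not allow_locals:
    return None
  return {"op": "STORE", "buffer": "DEFINE_LOCAL"}

print(bufferize_to_store_sketch({"is_local": True}, allow_locals=False))  # None, left as BUFFERIZE
print(bufferize_to_store_sketch({"is_local": True}))                      # lowered to a local STORE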
@@ -559,14 +560,21 @@ def bufferize_to_store(x:UOp):
   # TODO: how is this unified?
   return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)

-pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
-  (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
+_pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
   # move RESHAPEs through MSELECT/MSTACK
   (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
    lambda m: m.replace(src=tuple([x.src[0] for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
 ])

+pm_add_buffers_nolocals = PatternMatcher([
+  (UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, allow_locals=False)),
+])+_pm_add_buffers
+
+# all local bufferization should happen later
+pm_add_buffers = PatternMatcher([
+  (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
+])+_pm_add_buffers

 # *****************
 # 5. split into kernels
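Both matchers now share _pm_add_buffers and differ only in how the BUFFERIZE callback is bound: the lambda in pm_add_buffers_nolocals is bufferize_to_store with allow_locals pinned to False. A standalone sketch of that binding pattern (names and rule lists are illustrative, not the real ones):

from functools import partial

def bufferize_to_store_sketch(node, allow_locals=True):
  if node["is_local"] and not allow_locals: return None
  return ("STORE", "local" if node["is_local"] else "global")

shared_rules = [("MSELECT/MSTACK reshape", lambda n: None)]  # stand-in for _pm_add_buffers

# lambda x: bufferize_to_store(x, allow_locals=False) is equivalent to this partial
add_buffers_nolocals = [("BUFFERIZE", partial(bufferize_to_store_sketch, allow_locals=False))] + shared_rules
add_buffers          = [("BUFFERIZE", bufferize_to_store_sketch)] + shared_rules

local_bufferize = {"is_local": True}
print(add_buffers_nolocals[0][1](local_bufferize))  # None: local BUFFERIZE is deferred
print(add_buffers[0][1](local_bufferize))           # ('STORE', 'local')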
@@ -741,7 +749,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")

   # bufferize -> store
-  tsink = graph_rewrite(tsink, pm_add_buffers, bottom_up=True, name="bufferize to store")
+  tsink = graph_rewrite(tsink, pm_add_buffers_nolocals, bottom_up=True, name="bufferize to store")
   tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")

   # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
@@ -481,7 +481,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.MSTACK: return UOp(Ops.MSTACK, self.dtype, src=tuple(x.as_buf() for x in self.src))
     # TODO: this should be the only one of these. this is the one RANGEIFY uses
     s = self
-    while len(s.src) and s.op not in {Ops.BUFFER, Ops.MSTACK}: s = s.src[0]
+    while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
     return s

   @property
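Adding Ops.BUFFERIZE to the stop set lets as_buf terminate on a BUFFERIZE that has not been lowered to a BUFFER yet, instead of walking past it down src[0]. A small standalone sketch of that traversal (toy node class, not the real UOp):

from dataclasses import dataclass

@dataclass(frozen=True)
class U:
  op: str
  src: tuple = ()

STOP = {"BUFFER", "BUFFERIZE", "MSTACK"}

def as_buf(s: U) -> U:
  # follow src[0] until a terminal buffer-like op is reached
  while len(s.src) and s.op not in STOP:
    s = s.src[0]
  return s

chain = U("INDEX", (U("RESHAPE", (U("BUFFERIZE", (U("BUFFER"),)),)),))
print(as_buf(chain).op)  # BUFFERIZE (without the change, the walk would continue down to BUFFER)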