@@ -6,11 +6,11 @@ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callab
from enum import Enum, auto
from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \
graph_rewrite, track_rewrites, UPat
graph_rewrite, track_rewrites
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, ProgramSpec
from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put, unwrap
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape
@@ -607,83 +607,97 @@ class Kernel:
return name + colored(num, 'BLACK')

def get_optimized_ast(self) -> UOp:
# set the shapetrackers to the optimized ones, fixup reduceop
# transformed to the final UOp
@functools.lru_cache(None)
def fixup_ast(op:UOp) -> UOp:
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
if op.op in GroupOp.Buffer and op in self.bufs:
st_uop = self.sts[self.bufs.index(op)].to_uop()
return ret.replace(src=(st_uop,)) if op.op is Ops.VALID else ret.replace(src=(ret.src[0], st_uop, *ret.src[2:]))
if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals))
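# the KernelInfo arg attached to the SINK is how the chosen local/upcast split (and the
# dont_use_locals flag) is handed to the later compilation stages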
def fixup_ast(op:UOp, apply_to_st=None) -> UOp:
arg = op.arg
if op.op in GroupOp.Buffer:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
if op.op is Ops.VALID: return op.replace(src=(st_uop,))
if op.op is Ops.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
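# when apply_to_st is set, it rewrites the ShapeTracker of every load/store visited; the tensor
# core path below passes fix_st1/fix_st2 through it to adjust the operand views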
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op) * 2
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
alu_op: Ops = op.arg[0]
axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
rsrc = op.src[0]
if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
assert rsrc.op is Ops.MUL

def reduced_axes(start, stop):
return tuple(i for i in range(start, stop) if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx + 1].shape[i]))
axes = reduced_axes(self.first_reduce + self.group_for_reduces, self.shape_len)
grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
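# a dim counts as reduced iff its size differs between the reduce's input ShapeTracker
# (sts[reduce_idx]) and its output ShapeTracker (sts[reduce_idx+1]); axes are the late reduce
# dims, grouped_axes the ones handled by group_for_reduces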
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
wd, tcd = self.global_dims, self.first_upcast
assert st1.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st1.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
assert st1.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st1.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
new_shape = st1.shape[:tcd] + tcd_expand + st1.shape[tcd+len(tcd_dims):] # expand the tcd
permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in pattern_1] + list(range(wd+len(warp_dims), tcd)) + \
[y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()

if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
warp_dims = tuple(sz for _, sz in tc.threads)
tcd_dims = tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
fix_st1 = functools.partial(fix_st, warp_dims, tcd_dims, tc.expanded_shape, *tc.st1_pattern) if tc.st1_pattern else None
fix_st2 = functools.partial(fix_st, warp_dims, tcd_dims, tc.expanded_shape, *tc.st2_pattern) if tc.st2_pattern else None

assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
assert tc.expanded_shape is not None

new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
[y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
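# both fix_st variants do the same trick: expand the tcd (upcast) dims into tc.expanded_shape,
# permute the warp and tcd axes per the pattern (an (x, y) entry picks axis y from the warp
# block when x == 0, else from the tcd block), then reshape back. the shape is unchanged and
# only the strides move, so the view matches the per-thread fragment layout the tensor core expects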
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
if tc_pattern: srcs[i] = srcs[i].view(fix_st(unwrap(srcs[i].st), *tc_pattern))

if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
st = store_st = ShapeTracker.from_shape(local_shape)
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
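# for TC=3 each operand is first stored to a DEFINE_LOCAL buffer (with the store view fixed up
# when a pattern applies) and loaded back, so the emulated path goes through the same per-thread
# addressing a real WMMA would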
tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/EXPAND to get the vectorization right
upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
assert apply_to_st is None, "double tensor core? not supported"
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(t[1] for t in tc.threads),
tuple(tuple((self.first_upcast+ax,sz) for ax,sz in up) for up in tc.upcast_axes), tuple(self.first_upcast+ax for ax,_ in tc.reduce_axes))
if self.use_tensor_cores >= 2:
if self.use_tensor_cores == 3:
# TC=3, emulate the warp addressing with locals
ex_shape = tuple(1 if i < self.global_dims or (i >= self.first_reduce and i < self.first_upcast) else s \
for i,s in enumerate(self.full_shape))
srcs = []
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is Ops.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(Ops.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(Ops.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else:
# for TC=2, we can't do the shapetracker fixup
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
# MUL/SUM instead of WMMA
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
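# with use_tensor_cores >= 2 no real WMMA is emitted: the fragment is computed as a plain MUL
# followed by a REDUCE_AXIS over the tensor-core reduce axes (wmma_arg[-1])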
else:
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
wmma_upcast_axes = wmma_arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
UOp(Ops.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
UOp(Ops.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
tc_uop = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=upcast_axes[2])
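# CONTRACT packs each operand's upcast axes into the vector width the WMMA expects
# (wmma_sz[0]/wmma_sz[1]); EXPAND unpacks the wmma_sz[2]-wide accumulator back onto the
# output upcast axes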
else: # for TC=3 MUL/SUM instead of WMMA
tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))

new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
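# any reduce axes the tensor core didn't consume stay as a normal REDUCE_AXIS wrapped around
# the WMMA (or emulated) result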
ret = ret.replace(arg = (op.arg[0], axes))
if self.group_for_reduces and grouped_axes:
ret = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in [ax for ax, _ in tc.reduce_axes])
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
start = UOp(Ops.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
# NOTE: if there's a grouped reduce, but no reduce axes for this reduce, we can skip it
if len(second_axis) == 0: return start
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims] + \
tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce
st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
st_uop = ShapeTracker.from_shape(tuple([1 if i in second_axis else a for i,a in enumerate(local_shape)])).to_uop()
return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
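# grouped reduces go through shared memory: partially reduce into a DEFINE_LOCAL buffer shaped
# by local_shape, REDUCE_AXIS over the grouped axes, and (unless this is the last reduceop) load
# the result back with those axes collapsed to 1 for the next stage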
return ret

return graph_rewrite(fixup_ast(self.ast), PatternMatcher([
(UPat({*GroupOp.ALU,Ops.CAST,Ops.BITCAST,Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
(UPat(Ops.LOAD, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src)))]))
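# the first pattern pushes a VIEW of an ALU/CAST/BITCAST/ASSIGN down onto that op's sources; the
# second replaces a viewed LOAD's ShapeTracker source with the view's ShapeTracker, so views
# introduced above (e.g. by the tensor core fixups) get folded into the loads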
arg = (alu_op, axis)
elif op.op is Ops.SINK:
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
# NOTE: rewrite with an empty PatternMatcher to dedup UOps
return graph_rewrite(fixup_ast(self.ast), PatternMatcher([]))
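# even with no patterns, graph_rewrite rebuilds the graph bottom-up, so identical subtrees that
# fixup_ast created separately collapse into a single shared UOp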
# **** this is the lowerer ****