mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)

commit: tag TC
@@ -1,7 +1,8 @@
 import unittest
 from tinygrad import Tensor, nn
 from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
-from tinygrad.uop.ops import UOp
+from tinygrad.uop.ops import UOp, Ops
+from tinygrad.codegen.opt import OptOps, Opt
 
 @unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY")
 class TestRangeifyAssign(unittest.TestCase):
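
Hoisting OptOps and Opt to module scope lets the new test_double_gemm_tc below construct TC opts directly, and it is also why the local import inside test_flash_attention is deleted in a later hunk; Ops is pulled in for the Ops.BUFFER assertion.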
@@ -83,6 +84,20 @@ class TestRangeify(unittest.TestCase):
     C = Tensor.empty(N, N)
     (A@B@C).realize()
 
+  def test_double_gemm_tc(self):
+    with Context(DEBUG=0):
+      A, B, C = [Tensor.randn(N, N) for _ in range(3)]
+      Tensor.realize(A, B, C)
+    #args = (Opt(OptOps.TC, 0, (0,0,1,1)), Opt(OptOps.TC, 0, (0,0,1,0)))
+    #args = (Opt(OptOps.TC, 0, (0,0,1,0)),)
+    args = (Opt(OptOps.TC, 0, (0,0,1,1)),)
+    #args = ()
+    tst = (A@B@C).contiguous(arg=args).realize()
+    assert tst.uop.base.op is Ops.BUFFER, "buffer"
+    with Context(RANGEIFY=0, DEBUG=0):
+      mse = ((A@B@C)-tst).square().mean().item()
+      print(mse)
+
   def test_double_gemm_exp(self):
     A = Tensor.empty(N, N)
     B = Tensor.empty(N, N)
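
The new test pins the TC opt onto the kernel via contiguous(arg=args) instead of relying on BEAM search, then recomputes A@B@C with RANGEIFY=0 and prints the MSE as a correctness check. Read against the scheduler checks below, the 4-tuple (0,0,1,1) unpacks roughly as follows (names taken from the scheduler hunk; the comments are my reading, not the source's):

    arg = (0, 0, 1, 1)                             # the Opt(OptOps.TC, 0, ...) payload
    tc_select = arg[0]                             # index into opts.tensor_cores; -1 is also allowed by the check
    tc_opt = arg[1]                                # passed through as opt_level, 0..2
    use_tensor_cores = arg[2]                      # must be 1 or 2
    reduce_choice = arg[3] if len(arg) > 3 else 0  # new: which REDUCE in toposort order gets the TC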
@@ -192,10 +207,10 @@ class TestRangeify(unittest.TestCase):
     out.realize()
 
   def test_flash_attention(self):
-    #BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
+    BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
 
     # bigger
-    BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
+    #BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
 
     # llama 8B
     #BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
@@ -205,11 +220,9 @@ class TestRangeify(unittest.TestCase):
       with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
       return q.scaled_dot_product_attention(k, v)
 
-    from tinygrad.codegen.opt import OptOps, Opt
     with Context(DEBUG=4):
       GlobalCounters.reset()
-      opts = (Opt(OptOps.UPCAST,0,4),)
-      ret = fa().contiguous(arg=opts).realize()
+      ret = fa().realize()
     with Context(RANGEIFY=0):
       with Context(DEBUG=2):
         GlobalCounters.reset()
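
The remaining hunks are on the scheduler side (class Scheduler). They thread the optional fourth tuple element through apply_opt into _apply_tc_opt, and tag the selected reduce so it can be recovered after the optimization rewrites. In test_flash_attention itself, the hand-placed UPCAST opt and its local import are dropped, leaving the plain ret = fa().realize() path, and the shape hunk flips the active config to the small 4, 2, 16, 8 case.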
@@ -5,7 +5,7 @@ from typing import cast, Final
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, KernelInfo, graph_rewrite, AxisType, ssimplify, can_pad, GroupOp
 from tinygrad.device import Buffer
 from tinygrad.dtype import AddrSpace, dtypes, ImageDType
-from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts
+from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element
 from tinygrad.codegen.opt import axis_colors, Opt, OptOps, KernelOptError, check, axis_letters
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
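
get_single_element, newly imported from tinygrad.helpers, returns the only element of a sequence and fails if there is not exactly one; it backs the stricter tagged-reduce lookup in the final hunk.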
@@ -173,13 +173,13 @@ class Scheduler:
       check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
       ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
     elif opt.op is OptOps.TC:
-      check(len(self.applied_opts) == 0, "tensor core opts must be first")  # TODO: remove the need for this by having warps
+      #check(len(self.applied_opts) == 0, "tensor core opts must be first")  # TODO: remove the need for this by having warps
       check(opt.axis is not None, "tensor core opts must have an axis")
-      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
+      check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) >= 3, "tensor core opts must have valid arg")
       check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.opts.tensor_cores), "tensor core opts must have valid tc_select")
       check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")
       check(0 < (use_tensor_cores:=cast(tuple, opt.arg)[2]) <= 2, "use_tensor_cores value is not valid")
-      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt)
+      try: ret = self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), tc_select, tc_opt, opt.arg[3] if len(opt.arg) > 3 else 0)
       except ValueError as e: raise KernelOptError(str(e))
       check(ret is not None, "no tensor core available")
     elif opt.op is OptOps.PADTO:
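
Relaxing len(opt.arg) == 3 to >= 3 keeps every existing 3-tuple TC opt working, with apply_opt defaulting the new field to 0 (the old reduceops[0] behavior) when it is absent; commenting out the must-be-first check is what allows a TC opt to target the second reduce of an already-fused kernel at all. A sketch of the two accepted call shapes, using the Opt constructor exactly as the test above does:

    Opt(OptOps.TC, 0, (0, 0, 1))     # tc_select, tc_opt, use_tensor_cores -- old behavior, reduce_choice=0
    Opt(OptOps.TC, 0, (0, 0, 1, 1))  # same, plus reduce_choice=1: apply the TC to the second REDUCE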
@@ -213,10 +213,10 @@ class Scheduler:
     if append_opt: self.applied_opts.append(opt)
     return ret
 
-  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int) -> None|list[UOp]:
+  def _apply_tc_opt(self, use_tensor_cores:int, axis:int, tc_select:int, opt_level:int, reduce_choice:int) -> None|list[UOp]:
     reduceops = [x for x in self.ast.toposort() if x.op is Ops.REDUCE]
     if not len(reduceops): raise KernelOptError("no reduce ops for TensorCore")
-    reduceop = reduceops[0]
+    reduceop = reduceops[reduce_choice]
     if use_tensor_cores and reduceop is not None and reduceop.arg is Ops.ADD:
       mul = reduceop.src[0] if reduceop.src[0].op is not Ops.CAST else reduceop.src[0].src[0]
       if mul.op is not Ops.MUL: return None
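
reduce_choice indexes reduceops in toposort order, so for a fused A@B@C both matmul reduces are present and index 1 picks the second GEMM's reduce. One caveat: an out-of-range choice raises IndexError, and the call site above only converts ValueError into KernelOptError. A hypothetical illustration (the string stand-ins are mine, not real UOps):

    reduceops = ["reduce of A@B", "reduce of (A@B)@C"]  # toposort order: inner GEMM first
    reduce_choice = 1
    reduceop = reduceops[reduce_choice]                 # -> the second GEMM's reduce
    # reduceops[2] here would raise IndexError, which is not mapped to KernelOptError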
@@ -242,6 +242,9 @@ class Scheduler:
       if not (axis < len(axis_choices)): continue
       axes = list(axis_choices[axis])
 
+      # tag the reduceop
+      self.ast = self.ast.substitute({reduceop: reduceop.replace(tag="TC")})
+
       # do optimizations and save the ranges
       try:
         for i,a in enumerate(axes):
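
The tag is what makes a non-first TC opt safe: the rewrites applied between here and the final fix-up can reshape the AST, so a positional lookup like reduceops[0] may no longer point at the reduce the opt was applied to once a second reduce is in the kernel. Tagging the chosen REDUCE with "TC" before those rewrites lets the code re-find exactly that node afterwards.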
@@ -271,7 +274,7 @@ class Scheduler:
 
     if use_tensor_cores != 2:
       # fix the srcs
-      reduceop = [x for x in self.ast.toposort() if x.op is Ops.REDUCE][0]
+      reduceop = get_single_element([x for x in self.ast.toposort() if x.op is Ops.REDUCE and x.tag == "TC"])
       tne = [x.replace(tag=1) for x in ne]
       ret = reduceop.substitute(dict(zip(ne, tne)))
       srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
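
Fetching the tagged reduce with get_single_element turns a silent wrong pick into a loud failure: if a rewrite dropped the tag or duplicated the tagged node, the lookup raises instead of returning whatever happened to be first. A minimal restatement of the helper's contract (an assumption-level sketch; the real implementation lives in tinygrad.helpers):

    def get_single_element(x):
      # return the only element of x; fail loudly otherwise
      assert len(x) == 1, f"expected a single element, got {len(x)}"
      return x[0]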