mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-11 07:58:08 -05:00
small changes from postopt (#11854)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import RANGEIFY
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
|
||||
N = 256
|
||||
|
||||
@@ -96,14 +96,30 @@ class TestRangeify(unittest.TestCase):
|
||||
out.realize()
|
||||
|
||||
def test_flash_attention(self):
|
||||
BS = 4
|
||||
HEADS = 2
|
||||
MATDIM = 16
|
||||
EMB = 8
|
||||
q = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
k = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
v = Tensor.empty(BS, HEADS, MATDIM, EMB)
|
||||
q.scaled_dot_product_attention(k, v).realize()
|
||||
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
|
||||
# bigger
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 16, 128, 64
|
||||
|
||||
# llama 8B
|
||||
#BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
|
||||
|
||||
def fa():
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
|
||||
return q.scaled_dot_product_attention(k, v).realize()
|
||||
|
||||
with Context(DEBUG=4):
|
||||
GlobalCounters.reset()
|
||||
ret = fa()
|
||||
with Context(RANGEIFY=0):
|
||||
with Context(DEBUG=2):
|
||||
GlobalCounters.reset()
|
||||
cmp = fa()
|
||||
with Context(DEBUG=0):
|
||||
mse = ((cmp-ret)**2).sum().item()
|
||||
print(f"mse: {mse}")
|
||||
self.assertLessEqual(mse, 1e-6)
|
||||
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
@@ -30,7 +30,10 @@ class TestTiny(unittest.TestCase):
|
||||
def test_gemm(self, N=64, out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
b = Tensor.eye(N).contiguous()
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
lst = (out:=a@b).tolist()
|
||||
for y in range(N):
|
||||
for x in range(N):
|
||||
self.assertEqual(lst[y][x], 1.0, msg=f"mismatch at ({y},{x})")
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
@@ -18,7 +18,7 @@ from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_in
|
||||
from tinygrad.codegen.late.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.codegen.opt import pm_get_optimization, pm_do_optimize
|
||||
from tinygrad.codegen.opt.swizzler import view_left, view_right, fix_kernel_ops
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers, rangeify_codegen
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
@@ -72,7 +72,7 @@ def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVEC
|
||||
ret.append(RewriteStep(sym+expander, name="expander"))
|
||||
|
||||
# add locals
|
||||
ret.append(RewriteStep(pm_add_buffers+rangeify_codegen, name="add local buffers"))
|
||||
ret.append(RewriteStep(pm_add_buffers_local+rangeify_codegen, name="add local buffers"))
|
||||
|
||||
# ** devectorizer (full_graph_rewrite) **
|
||||
# remove reduce
|
||||
|
||||
@@ -50,7 +50,7 @@ def do_expand(root:UOp):
|
||||
if root.op is Ops.IF or src.op is Ops.IF:
|
||||
# for the first arg of IF, just pass them through ignoring UNROLLS
|
||||
new_srcs.append(src)
|
||||
elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1):
|
||||
elif (root.op is Ops.STORE and i >= 2) or (root.op in {Ops.REDUCE, Ops.BUFFERIZE} and i >= 1) or (root.op is Ops.WMMA and i >= 3):
|
||||
# for any range args of STORE/REDUCE, pass them through
|
||||
new_srcs.append(src)
|
||||
elif root.op is Ops.INDEX and i >= 1 and not isinstance(root.dtype, PtrDType):
|
||||
|
||||
@@ -245,7 +245,7 @@ class Kernel:
|
||||
if axis is None: return -1
|
||||
if op is OptOps.UNROLL: return self.unrollable_dims[axis]
|
||||
if op in {OptOps.GROUP, OptOps.GROUPTOP}: return self.axes_of(AxisType.REDUCE)[axis]
|
||||
check(axis < self.shape_len, "invalid axis")
|
||||
check(axis < self.shape_len, f"invalid axis on {axis=} {op=} {self.shape_len=}")
|
||||
return axis
|
||||
except IndexError as e: raise KernelOptError from e
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ class NullRenderer(CStyleLanguage):
|
||||
device = "NULL"
|
||||
has_local = False
|
||||
float4 = "float4"
|
||||
barrier = "// BARRIER"
|
||||
code_for_op = {**CStyleLanguage.code_for_op, Ops.THREEFRY: lambda a,b,dtype: f"threefry({a},{b})", Ops.MAX: lambda a,b,dtype: f"max({a},{b})"}
|
||||
|
||||
class NullProgram:
|
||||
|
||||
@@ -329,7 +329,7 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
|
||||
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
|
||||
# NOTE: this has been fixed up a bit
|
||||
|
||||
def bufferize_to_store(x:UOp):
|
||||
def bufferize_to_store(x:UOp, locals_allowed=False):
|
||||
rngs = x.src[1:]
|
||||
shape = tuple([int(r.vmax+1) for r in rngs])
|
||||
size = prod(shape)
|
||||
@@ -339,10 +339,18 @@ def bufferize_to_store(x:UOp):
|
||||
assign_target, assign_src = x.src[0].src
|
||||
assert assign_target.op is Ops.INDEX
|
||||
return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype)
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL: buf = UOp.new_buffer(x.arg, size, x.dtype)
|
||||
else: buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
|
||||
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
||||
if sdtype.addrspace == AddrSpace.GLOBAL:
|
||||
buf = UOp.new_buffer(x.arg, size, x.dtype)
|
||||
else:
|
||||
if not locals_allowed: return None
|
||||
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg[1])
|
||||
return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
|
||||
|
||||
pm_add_buffers_local = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, True)),
|
||||
])
|
||||
|
||||
pm_add_buffers = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
||||
|
||||
|
||||
@@ -202,17 +202,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def ranges(self) -> dict[UOp, None]:
|
||||
if self.op is Ops.RANGE: return {self:None}
|
||||
if self.op in {Ops.BUFFERIZE, Ops.REDUCE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
for s in self.src[1:]:
|
||||
if s in ret: del ret[s]
|
||||
elif self.op in {Ops.STORE}:
|
||||
ret = self.src[0].ranges.copy()
|
||||
ret.update(self.src[1].ranges)
|
||||
for s in self.src[2:]:
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3}
|
||||
ret: dict[UOp, None] = {}
|
||||
if self.op in range_start.keys():
|
||||
for s in self.src[:range_start[self.op]]: ret.update(s.ranges)
|
||||
for s in self.src[range_start[self.op]:]:
|
||||
if s in ret: del ret[s]
|
||||
else:
|
||||
ret = {}
|
||||
for s in self.src: ret.update(s.ranges)
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user