mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
good changes from tensor_cores branch (#1005)
* good changes from tensor_cores branch
* touchups
* real_strides fixup
* refactor merge_views
This commit is contained in:
@@ -297,6 +297,8 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
|
||||
def test_gemm(self):
  # two 64x64 operands; x.matmul(y) and Tensor.dot are the callables handed to helper_test_op
  gemm_shapes = [(64,64), (64,64)]
  helper_test_op(gemm_shapes, lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
|
||||
def test_big_gemm(self):
  # two 256x256 operands; x.matmul(y) and Tensor.dot are the callables handed to helper_test_op
  gemm_shapes = [(256,256), (256,256)]
  helper_test_op(gemm_shapes, lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
|
||||
def test_broadcastdot(self):
|
||||
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
with self.assertRaises(RuntimeError):
|
||||
|
||||
@@ -33,7 +33,7 @@ def colorize_float(x):
|
||||
ret = f"{x:7.2f}x"
|
||||
if x < 0.75:
|
||||
return colored(ret, 'green')
|
||||
elif x > 1.5:
|
||||
elif x > 1.33:
|
||||
return colored(ret, 'red')
|
||||
else:
|
||||
return colored(ret, 'yellow')
|
||||
@@ -81,7 +81,7 @@ def helper_test_generic_square(name, N, f1, f2, onearg=False):
|
||||
tiny_a = Tensor(torch_a.cpu().numpy())
|
||||
tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None
|
||||
|
||||
helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))
|
||||
helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b))
|
||||
|
||||
prefix = None
|
||||
def helper_test_generic(name, f1, f1_args, f2, f2_args):
|
||||
@@ -93,10 +93,44 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
|
||||
desc = "faster" if et_torch > et_tinygrad else "slower"
|
||||
flops = save_ops*1e-6
|
||||
mem = save_mem*1e-6
|
||||
print(f"{prefix}{name:40s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:7.2f} MOPS {mem:7.2f} MB")
|
||||
print(f"{prefix}{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB")
|
||||
prefix = " "
|
||||
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
|
||||
|
||||
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
  """Set up the same bias-free 2D convolution in torch and tinygrad and pass both to helper_test_generic.

  torch is seeded with 0, a random input of shape (bs, in_chans, img_size_y, img_size_x) is drawn,
  and a torch Conv2d(in_chans, out_chans, kernel_size) is created. The input and the conv weights
  are then copied into tinygrad so both sides run on identical data.
  """
  torch.manual_seed(0)
  input_shape = (bs, in_chans, img_size_y, img_size_x)
  torch_dat = torch.rand(*input_shape).to(torch_device)
  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None).to(torch_device)

  # mirror the torch input and weights on the tinygrad side
  tiny_dat = Tensor(torch_dat.cpu().numpy())
  tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
  tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

  def run_torch(x): return torch_conv(x)
  def run_tiny(x): return tiny_conv(x).realize()

  label = f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}"
  helper_test_generic(label, run_torch, (torch_dat,), TinyJit(run_tiny), (tiny_dat,))
|
||||
|
||||
@unittest.skipIf(getenv("BIG") != 1, "no big tests")
class TestBigSpeed(unittest.TestCase):
  """Large exp / gemm / conv speed cases; the whole class is skipped unless getenv("BIG") equals 1."""

  def setUp(self):
    # prefix is a module-level global: None becomes " ", any other value becomes ""
    global prefix
    if prefix is None:
      prefix = " "
    else:
      prefix = ""
    return super().setUp()

  def test_exp(self):
    # second argument is accepted but unused (onearg=True)
    def exp_only(a, b): return a.exp()
    helper_test_generic_square('exp', 16384, exp_only, exp_only, onearg=True)

  def test_gemm_1024(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 1024, matmul, matmul)

  def test_gemm_2048(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 2048, matmul, matmul)

  def test_gemm_4096(self):
    def matmul(a, b): return a @ b
    helper_test_generic_square('gemm', 4096, matmul, matmul)

  def test_large_conv_1x1(self):
    helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)

  def test_large_conv_3x3(self):
    helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
|
||||
|
||||
class TestSpeed(unittest.TestCase):
|
||||
def setUp(self):
|
||||
global prefix
|
||||
@@ -221,21 +255,10 @@ class TestSpeed(unittest.TestCase):
|
||||
helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
|
||||
|
||||
def test_conv2d(self):
|
||||
torch.manual_seed(0)
|
||||
for bs in [32]:
|
||||
for in_chans in IN_CHANS:
|
||||
for out_chans in [32]:
|
||||
img_size = 34
|
||||
torch_dat = torch.rand(bs, in_chans, img_size, img_size).to(torch_device)
|
||||
torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None).to(torch_device)
|
||||
|
||||
tiny_dat = Tensor(torch_dat.cpu().numpy())
|
||||
tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None)
|
||||
tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
|
||||
|
||||
def f1(torch_dat): return torch_conv(torch_dat)
|
||||
def f2(tiny_dat): return tiny_conv(tiny_dat).realize()
|
||||
helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
|
||||
helper_test_conv(bs, in_chans, out_chans, 3, 34, 34)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -67,6 +67,45 @@ class CheckingShapeTracker:
|
||||
assert self.st.shape == self.shape
|
||||
assert x == y, f"mismatch shapetracker:{x} real:{y}"
|
||||
|
||||
class TestRealDoesntSimplify(unittest.TestCase):
  """Two-view ShapeTrackers whose real_strides() contain a None and which stay multi-view after simplify()."""

  def tearDown(self):
    # shared postcondition for every test: strides are read before simplify(),
    # simplify() must not leave exactly one view, and the strides must include a None
    strides = self.st.real_strides()
    print(strides)
    self.st.simplify()
    assert len(self.st.views) != 1
    assert None in strides

  def test_1(self):
    views = [
      View((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
      View((8, 6, 11), (66, 11, 1), 0, None)]
    self.st = ShapeTracker((8, 6, 11), views=views)
    assert self.st.real_strides() == (33, None, 1)

  def test_2(self):
    views = [
      View((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
      View((4, 4, 3, 3), (36, 9, 3, 1), 0, None)]
    self.st = ShapeTracker((4, 4, 3, 3), views=views)
    assert self.st.real_strides() == (None, 18, -3, -1)
|
||||
|
||||
class TestRealSimplifies(unittest.TestCase):
  """Two-view ShapeTrackers that simplify() to one view whose strides equal the earlier real_strides()."""

  def tearDown(self):
    # real_strides() is captured before simplify() and compared with the single remaining view
    expected = self.st.real_strides()
    self.st.simplify()
    assert len(self.st.views) == 1
    merged_strides = self.st.views[-1].strides
    print(merged_strides, expected)
    assert merged_strides == expected

  def test_1(self):
    views = [
      View((1, 3, 2, 11, 4, 28), (0, 308, 0, 28, 0, 1), 0, None),
      View((1, 3, 2, 11, 26, 1, 1, 3), (0, 2464, 0, 112, 1, 0, 0, 29), 0, None)]
    self.st = ShapeTracker((1, 3, 2, 11, 26, 1, 1, 3), views=views)

  def test_2(self):
    views = [
      View((8, 3, 3, 11, 2, 28), (924, 308, 0, 28, 0, 1), 0, None),
      View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)]
    self.st = ShapeTracker((8, 1, 6, 10, 28, 3, 2, 1), views=views)
|
||||
|
||||
|
||||
class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((1, 10))
|
||||
|
||||
@@ -143,16 +143,17 @@ class Linearizer:
|
||||
def shape_offsets(self, i):
  """Return every index combination over the last `self.upcasted` dims of buffer i's shape.

  The trailing dims are taken in reverse order and combined with itertools.product.
  When nothing is upcasted the result is a list holding one empty tuple.
  """
  if self.upcasted <= 0:
    return [tuple()]
  upcast_dims = self.sts[i].shape[self.shape_len-self.upcasted:]
  return itertools.product(*(list(range(dim)) for dim in upcast_dims[::-1]))
|
||||
def float4_axis(self, i):
  """Return the unit-stride axes of buffer i that lie in the upcasted range and have size divisible by 4.

  Each returned axis is relative to the first upcasted dim (shape_len - upcasted).
  """
  tracker = self.sts[i]
  first_upcast = self.shape_len-self.upcasted
  candidates = []
  for axis in tracker.unit_stride_axes():
    if axis >= first_upcast and tracker.shape[axis]%4 == 0:
      candidates.append(axis-first_upcast)
  return candidates
|
||||
|
||||
# TODO: this stride is only on the last view, and may not be real
|
||||
def upcasted_axis(self, i):
|
||||
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
|
||||
self.sts[i].views[-1].strides[self.shape_len-self.upcasted:], # WRONG
|
||||
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
|
||||
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
|
||||
|
||||
# TODO: is there a better way to write this?
|
||||
def acc_offsets(self, i):
|
||||
if self.upcasted == 0: return [0]
|
||||
acc_strides = [x*(1-self.upcasted_axis(i)[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.upcasted_axis(i)[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.upcasted_axis(i)[::-1])])]
|
||||
upcasted_i = self.upcasted_axis(i)
|
||||
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
|
||||
|
||||
def _group_float4(self, i, store_offset):
|
||||
store_offset_float4 = {}
|
||||
@@ -506,12 +507,12 @@ class Linearizer:
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
|
||||
# NOTE: this is using views[-1]
|
||||
if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
|
||||
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if len(xb_choices):
|
||||
xb_choices = sorted(xb_choices)
|
||||
@@ -519,6 +520,7 @@ class Linearizer:
|
||||
self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
|
||||
self.upcast()
|
||||
self.simplify_ones()
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ class _CL:
|
||||
for q in self.cl_queue: q.finish()
|
||||
CL = _CL()
|
||||
|
||||
# TODO: merge CLImage in here
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
assert not OSX or dtype != dtypes.float64, "OpenCL on Mac doesn't support float64"
|
||||
|
||||
@@ -4,17 +4,11 @@ from enum import Enum, auto
|
||||
import functools
|
||||
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
|
||||
|
||||
# these ops live here
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
|
||||
|
||||
def check_no_mul(test, var):
  """Return True if `test` equals `var`, or reaches `var` only through SumNode terms and ModNodes with b%4 == 0.

  Any other node kind yields False.
  """
  if test == var:
    return True
  node_type = test.__class__
  if node_type is SumNode:
    # in a sum is okay
    return any(check_no_mul(term, var) for term in test.nodes)
  if node_type is ModNode and test.b%4 == 0:
    # removing a mod is okay
    return check_no_mul(test.a, var)
  return False
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
|
||||
assert len(shape) == len(strides)
|
||||
@@ -97,22 +91,10 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
if vm2.mask: return None # this isn't supported yet
|
||||
new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
|
||||
assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
|
||||
for s,st in zip(vm1.shape, vm1.strides):
|
||||
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
|
||||
if s == 1:
|
||||
new_strides.append(0) # all shape 1 can have stride 0
|
||||
elif this_dim.__class__ is NumNode and this_dim.b == 0:
|
||||
new_strides.append(0)
|
||||
elif this_dim.__class__ is Variable:
|
||||
new_strides.append(1)
|
||||
elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable:
|
||||
new_strides.append(this_dim.b)
|
||||
else:
|
||||
if DEBUG >= 4: print("can't simplify", s, this_dim.render())
|
||||
break
|
||||
return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None
|
||||
mst = ShapeTracker(vm1.shape, [vm2, vm1])
|
||||
strides = mst.real_strides()
|
||||
if None in strides: return None
|
||||
return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
@@ -168,15 +150,33 @@ class ShapeTracker:
|
||||
# this is the real size (ish)
|
||||
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
|
||||
|
||||
def unit_stride_axes(self) -> List[int]:
|
||||
ret, acc = [], 1
|
||||
for j,s in reversed(list(enumerate(self.shape))):
|
||||
if s == 1: continue
|
||||
# these are multiview strides, value is None if it's not a simple strided dimension
|
||||
# TODO: this can be shared code between simplify and merge_views
|
||||
def real_offset(self) -> int:
  """Return the constant obtained by evaluating expr_node on a variable fixed at zero.

  Asserts that the resulting node is a NumNode and returns its `b` value.
  """
  zero_idx = Variable('zero', 0, 0)
  real_offset, mask = self.expr_node(zero_idx)
  assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
  return real_offset.b
|
||||
|
||||
def real_strides(self) -> Tuple[Optional[int], ...]:
|
||||
if len(self.views) == 1: return self.views[-1].strides
|
||||
ret: List[Optional[int]] = []
|
||||
acc, real_offset = 1, self.real_offset()
|
||||
for s in reversed(self.shape):
|
||||
if s == 1: # fast path, all shape 1 have stride 0
|
||||
ret.append(0)
|
||||
continue
|
||||
var = Variable('idx', 0, s-1)
|
||||
this_dim = self.expr_node(var*acc)
|
||||
this_dim, _ = self.expr_node(var*acc)
|
||||
this_dim -= real_offset
|
||||
acc *= s
|
||||
if check_no_mul(this_dim[0], var): ret.append(j)
|
||||
return ret
|
||||
# TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
|
||||
# if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
|
||||
if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
|
||||
elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
|
||||
elif this_dim.__class__ is Variable: ret.append(1)
|
||||
else: ret.append(None)
|
||||
return tuple(ret[::-1])
|
||||
def unit_stride_axes(self) -> List[int]:
  """Return the indices of the axes whose entry in real_strides() equals 1."""
  unit_axes = []
  for axis, stride in enumerate(self.real_strides()):
    if stride == 1:
      unit_axes.append(axis)
  return unit_axes
|
||||
|
||||
def _expr_idx(self, idx, valid):
|
||||
for v in reversed(self.views[0:-1]):
|
||||
|
||||
Reference in New Issue
Block a user