wmma: add test and tensor core shape (#1925)

Francis Lam
2023-09-28 18:04:28 -07:00
committed by GitHub
parent 094d3d71be
commit f445e056ed
5 changed files with 79 additions and 22 deletions


@@ -204,6 +204,8 @@ jobs:
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
- name: Run linearizer and tensor core test
run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
#- name: Run webgpu pytest
# run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/
#- name: Build WEBGPU Efficientnet


@@ -2,9 +2,12 @@ import numpy as np
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
dtype = dtypes.half if getenv("HALF") else dtypes.float
N = 4096
a, b = Tensor.rand(N, N, dtype=dtype).realize(), Tensor.rand(N, N, dtype=dtype).realize()
for i in range(10):
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize()
print((c.numpy() - (a.numpy().astype(np.float32) @ b.numpy().astype(np.float32))).mean())
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
N = getenv("N", 4096)
CNT = getenv("CNT", 10)
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
for i in range(CNT):
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
nc = c.numpy()
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
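Note: with this rewrite the matmul benchmark is configured entirely through environment variables: N (default 4096) sets the square matrix size, CNT (default 10) the number of iterations, HALF switches the inputs to half precision, and ACCUM_FP32 selects the explicit fp32-accumulate formulation instead of a plain a @ b. Each result is now checked against a float32 numpy reference with assert_allclose instead of only printing the mean error; for example, running the script with HALF=1 N=2048 CNT=5 exercises the fp16 path on a 2048x2048 problem for five iterations (the values here are illustrative only).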


@@ -1,11 +1,13 @@
import numpy as np
import unittest
import unittest, os
from tinygrad.codegen.kernel import tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.lazy import _replace_bufferops
from extra.utils import print_tree
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -18,7 +20,7 @@ class TestLinearizer(unittest.TestCase):
rawbufs = CacheCollector.finish()[0][1]
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.realized, b.lazydata.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy())
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.
@@ -86,11 +88,32 @@ class TestLinearizer(unittest.TestCase):
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
def test_tensor_cores(self):
if not isinstance(Device[Device.DEFAULT], Compiled):
self.skipTest("Only Compiled uses linearizer")
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
for tc in tensor_cores[Device.DEFAULT]:
if tc.arch is not None and tc.arch != os.uname().machine: continue
a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
np_a, np_b = a.numpy(), b.numpy()
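# when dtype_out differs from dtype_in, the matmul is written as a broadcasted multiply
# cast to dtype_out before the sum, making the accumulation dtype explicit in the AST
# (a plain a @ b would reduce in dtype_in)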
if tc.dtype_out != tc.dtype_in:
r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2)
else:
r = a @ b
realized_ast, _ = helper_realized_ast(r)
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
k.process()
k.hand_coded_optimizations()
k.linearize()
assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered"
np_c = np_a @ np_b
np.testing.assert_allclose(np_c, r.numpy(), atol=5e-3, rtol=1e-4)
def helper_realized_ast(r:Tensor):
realized_ast = None
real_bufs = None
# HACK to get real ast.
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
def fake_exec_ast(ast, output=None, inputs=None, **kwargs):
@@ -103,6 +126,11 @@ def helper_linearizer_opt(r:Tensor, opts=[]):
r = r.realize() # realize an output buffer
assert realized_ast is not None
Device[Device.DEFAULT].exec_ast = real_dev_exec_ast
return realized_ast, real_bufs
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
realized_ast, real_bufs = helper_realized_ast(r)
def check_opt(x, create_k, to_prg):
k = create_k()


@@ -14,7 +14,7 @@ from tinygrad.ops import Device
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI
from tinygrad.helpers import colored, getenv, CI, dtypes
from tinygrad.jit import TinyJit
import pytest
@@ -22,6 +22,7 @@ pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exc
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
torch_dt = torch.float16 if getenv("HALF", 0) else torch.float32
torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
if str(torch_device) == "mps":
import torch.mps
@@ -78,8 +79,8 @@ def helper_test_speed(f1, *args):
def helper_test_generic_square(name, N, f1, f2, onearg=False):
torch.manual_seed(0)
torch_a = (torch.rand(N, N) - 0.5).to(torch_device)
torch_b = (torch.rand(N, N) - 0.5).to(torch_device) if not onearg else None
torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device)
torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) if not onearg else None
tiny_a = Tensor(torch_a.cpu().numpy())
tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None
@@ -88,9 +89,8 @@ def helper_test_generic_square(name, N, f1, f2, onearg=False):
def helper_test_matvec(name, N, M):
torch.manual_seed(0)
dt = torch.float32
torch_a = (torch.rand(N, dtype=dt) - 0.5).to(torch_device)
torch_b = (torch.rand(N, M, dtype=dt) - 0.5).to(torch_device)
torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device)
torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device)
tiny_a = Tensor(torch_a.cpu().numpy())
tiny_b = Tensor(torch_b.cpu().numpy())
@@ -112,8 +112,8 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
torch.manual_seed(0)
torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x).to(torch_device)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None).to(torch_device)
torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to(torch_device)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt).to(torch_device)
tiny_dat = Tensor(torch_dat.cpu().numpy())
tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
@@ -190,7 +190,7 @@ class TestSpeed(unittest.TestCase):
def test_double_permute(self):
N = 64
torch.manual_seed(0)
torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device)
tiny_a = Tensor(torch_a.cpu().numpy())
def f(a): return a.permute(1,0,3,2).contiguous()
helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))
@@ -268,8 +268,8 @@ class TestSpeed(unittest.TestCase):
def test_openpilot_conv2d(self):
bs, in_chans, out_chans = 1,12,32
torch.manual_seed(0)
torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)
torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt).to(torch_device)
tiny_dat = Tensor(torch_dat.cpu().numpy())
tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)


@@ -5,6 +5,30 @@ from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from dataclasses import dataclass
@dataclass(frozen=True)
class TensorCore:
device: str
dims: List[int]
dtype_in: DType
dtype_out: DType
threads: List[int]
thread_local_aliases: List[List[List[int]]]
thread_local_sizes: List[int]
arch: Optional[str] = None
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
# TODO(TC): doesn't belong here!!!
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, threads=[2,4,2,2], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[-1, 1, 3], [0], [2, 4]], [[2, 4], [-1, 1, 3], [0]], [[0], [-1, 1, 3], [2, 4]] ], arch="arm64"),
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, threads=[2,4,2,2], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[-1, 1, 3], [0], [2, 4]], [[2, 4], [-1, 1, 3], [0]], [[0], [-1, 1, 3], [2, 4]] ], arch="arm64")
],
"HIP": [
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, threads=[16,2], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[-1], [0], [1]], [[-1], [1], [0]], [[0], [1], [2, -1]] ]),
]
}
class LocalBuffer(NamedTuple):
name: str
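
A minimal sketch of how the new tensor_cores registry is meant to be consumed, mirroring the lookup and arch filter in test_tensor_cores above; the names come from this diff, while the usable list and the printout are illustrative only. dims reads as [M, N, K], since the test builds a with shape (dims[0], dims[2]) and b with shape (dims[2], dims[1]):

import os
from tinygrad.codegen.kernel import tensor_cores
from tinygrad.ops import Device

# tensor cores usable on this host: match the device key and, where set, the arch field
usable = [tc for tc in tensor_cores.get(Device.DEFAULT, [])
          if tc.arch is None or tc.arch == os.uname().machine]
for tc in usable:
    M, N, K = tc.dims  # one wmma tile multiplies an MxK tile of A by a KxN tile of B
    print(f"{tc}: {M}x{K} @ {K}x{N} producing {tc.dtype_out}")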