mirror of https://github.com/tinygrad/tinygrad.git
wmma: add test and tensor core shape (#1925)
.github/workflows/test.yml
@@ -204,6 +204,8 @@ jobs:
       run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py test/test_symbolic_jit.py
     - name: Check Device.DEFAULT
       run: WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
+    - name: Run linearizer and tensor core test
+      run: METAL=1 python -m pytest -n=auto test/test_linearizer.py
     #- name: Run webgpu pytest
     #  run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto --ignore test/models/ --ignore test/unit/test_example.py --ignore test/extra/test_lr_scheduler.py --ignore test/test_linearizer.py test/
     #- name: Build WEBGPU Efficientnet
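The added CI step runs the new linearizer/tensor-core tests with the Metal backend selected via METAL=1. A minimal sketch of the same invocation from Python, assuming pytest and pytest-xdist are installed and a Metal device is available; the environment variable has to be set before tinygrad is imported anywhere in the session:

import os, sys
import pytest

os.environ["METAL"] = "1"  # pick the Metal backend, mirroring the CI step above
# -n=auto comes from pytest-xdist, same as the workflow's command line
sys.exit(pytest.main(["-n=auto", "test/test_linearizer.py"]))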
@@ -2,9 +2,12 @@ import numpy as np
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import dtypes
-dtype = dtypes.half if getenv("HALF") else dtypes.float
-N = 4096
-a, b = Tensor.rand(N, N, dtype=dtype).realize(), Tensor.rand(N, N, dtype=dtype).realize()
-for i in range(10):
-  c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize()
-  print((c.numpy() - (a.numpy().astype(np.float32) @ b.numpy().astype(np.float32))).mean())
+dtype_in = dtypes.half if getenv("HALF") else dtypes.float
+N = getenv("N", 4096)
+CNT = getenv("CNT", 10)
+a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
+for i in range(CNT):
+  c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
+  comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
+  nc = c.numpy()
+  np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2)
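The matmul benchmark now defaults to a plain a @ b and keeps the explicit broadcast-multiply-and-reduce path behind ACCUM_FP32=1, which forces the accumulation into fp32 even for half inputs. A minimal NumPy-only sketch (not tinygrad code) of why the two formulations compute the same matmul, and of the fp16- versus fp32-accumulation difference the flag is about:

import numpy as np

N = 64
a = np.random.rand(N, N).astype(np.float16)
b = np.random.rand(N, N).astype(np.float16)

# c[i, j] = sum_k a[i, k] * b[k, j]: the broadcast-multiply-then-sum form is the
# same matmul, just with an explicit accumulation dtype.
ref = a.astype(np.float32) @ b.astype(np.float32)
accum_fp32 = (a.reshape(N, 1, N).astype(np.float32) * b.T.reshape(1, N, N).astype(np.float32)).sum(axis=2)
np.testing.assert_allclose(accum_fp32, ref, atol=1e-4, rtol=1e-2)  # same tolerances as the script

# Accumulating the same products in fp16 drifts much further from the reference.
accum_fp16 = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2, dtype=np.float16)
print(abs(accum_fp16.astype(np.float32) - ref).max())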
@@ -1,11 +1,13 @@
 import numpy as np
-import unittest
+import unittest, os
+
+from tinygrad.codegen.kernel import tensor_cores
 from tinygrad.codegen.linearizer import Linearizer, UOps
 from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
 from tinygrad.tensor import Tensor
 from tinygrad.jit import CacheCollector
 from tinygrad.lazy import _replace_bufferops
 from extra.utils import print_tree

 class TestLinearizer(unittest.TestCase):
   def test_arg_dedup(self):
@@ -18,7 +20,7 @@ class TestLinearizer(unittest.TestCase):
     rawbufs = CacheCollector.finish()[0][1]
     assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.realized, b.lazydata.realized}
     np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
-    np.testing.assert_allclose(np_c, c.numpy())
+    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

   def test_load_dedup(self):
     # for different leaves in the AST, the same loads may occur.
@@ -86,11 +88,32 @@ class TestLinearizer(unittest.TestCase):
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
     assert num_ops <= 0, "more load or alu uops than needed"

-def helper_linearizer_opt(r:Tensor, opts=[]):
-  wanna_output = None
+  def test_tensor_cores(self):
+    if not isinstance(Device[Device.DEFAULT], Compiled):
+      self.skipTest("Only Compiled uses linearizer")
+    if Device.DEFAULT not in tensor_cores:
+      self.skipTest("No tensor cores for device")
+
+    for tc in tensor_cores[Device.DEFAULT]:
+      if tc.arch is not None and tc.arch != os.uname().machine: continue
+      a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
+      np_a, np_b = a.numpy(), b.numpy()
+      if tc.dtype_out != tc.dtype_in:
+        r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2)
+      else:
+        r = a @ b
+      realized_ast, _ = helper_realized_ast(r)
+      k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
+      k.process()
+      k.hand_coded_optimizations()
+      k.linearize()
+      assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered"
+      np_c = np_a @ np_b
+      np.testing.assert_allclose(np_c, r.numpy(), atol=5e-3, rtol=1e-4)
+
+def helper_realized_ast(r:Tensor):
   realized_ast = None
   real_bufs = None

   # HACK to get real ast.
   real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
   def fake_exec_ast(ast, output=None, inputs=None, **kwargs):
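test_tensor_cores checks the tensor-core kernel against np_a @ np_b with atol=5e-3, noticeably looser than the 1e-4 used elsewhere in this file. A standalone NumPy sketch (assuming uniform inputs and a single 16x16x16 tile) of the error scales behind that choice:

import numpy as np

M = N = K = 16  # one HIP-sized tile; the METAL entries use 8x8x8
a = np.random.rand(M, K).astype(np.float16)
b = np.random.rand(K, N).astype(np.float16)
ref = a.astype(np.float64) @ b.astype(np.float64)  # high-precision reference

# dtype_in=half, dtype_out=float: fp16 inputs, fp32 accumulation.
fp32_acc = a.astype(np.float32) @ b.astype(np.float32)
assert abs(fp32_acc - ref).max() < 5e-3  # comfortably inside the test tolerance

# dtype_in=dtype_out=half: accumulate in fp16 as well; the error grows toward the tolerance.
fp16_acc = np.zeros((M, N), dtype=np.float16)
for k in range(K):
  fp16_acc += a[:, k:k+1] * b[k:k+1, :]
print(abs(fp16_acc.astype(np.float64) - ref).max())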
@@ -103,6 +126,11 @@ def helper_linearizer_opt(r:Tensor, opts=[]):
   r = r.realize() # realize an output buffer
   assert realized_ast is not None
   Device[Device.DEFAULT].exec_ast = real_dev_exec_ast
+  return realized_ast, real_bufs
+
+def helper_linearizer_opt(r:Tensor, opts=[]):
+  wanna_output = None
+  realized_ast, real_bufs = helper_realized_ast(r)

   def check_opt(x, create_k, to_prg):
     k = create_k()
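helper_realized_ast grabs the lowered AST by temporarily replacing Device[Device.DEFAULT].exec_ast with a stub that records its arguments instead of executing, then restores the real implementation (the "# HACK to get real ast" above). A generic, tinygrad-free sketch of that capture-by-monkey-patch pattern; Backend, run_graph and capture_graph are made-up names used purely for illustration:

from typing import Any

class Backend:
  # Stand-in for a device backend; run_graph plays the role of exec_ast.
  def run_graph(self, graph: Any, **kwargs): return f"executed {graph!r}"

def capture_graph(backend: Backend, graph: Any) -> Any:
  captured = {}
  real_run_graph = backend.run_graph                       # keep the real implementation
  def fake_run_graph(g, **kwargs): captured["graph"] = g   # record, don't execute
  backend.run_graph = fake_run_graph                       # patch the instance
  backend.run_graph(graph, output=None)                    # whatever normally triggers execution
  backend.run_graph = real_run_graph                       # restore, as the helper does after realize()
  return captured["graph"]

print(capture_graph(Backend(), ("ADD", 1, 2)))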
@@ -14,7 +14,7 @@ from tinygrad.ops import Device
 from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
-from tinygrad.helpers import colored, getenv, CI
+from tinygrad.helpers import colored, getenv, CI, dtypes
 from tinygrad.jit import TinyJit
 import pytest

@@ -22,6 +22,7 @@ pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exc

 IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]

+torch_dt = torch.float16 if getenv("HALF", 0) else torch.float32
 torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
 if str(torch_device) == "mps":
   import torch.mps
@@ -78,8 +79,8 @@ def helper_test_speed(f1, *args):

 def helper_test_generic_square(name, N, f1, f2, onearg=False):
   torch.manual_seed(0)
-  torch_a = (torch.rand(N, N) - 0.5).to(torch_device)
-  torch_b = (torch.rand(N, N) - 0.5).to(torch_device) if not onearg else None
+  torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device)
+  torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5).to(torch_device) if not onearg else None

   tiny_a = Tensor(torch_a.cpu().numpy())
   tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None
@@ -88,9 +89,8 @@ def helper_test_generic_square(name, N, f1, f2, onearg=False):

 def helper_test_matvec(name, N, M):
   torch.manual_seed(0)
-  dt = torch.float32
-  torch_a = (torch.rand(N, dtype=dt) - 0.5).to(torch_device)
-  torch_b = (torch.rand(N, M, dtype=dt) - 0.5).to(torch_device)
+  torch_a = (torch.rand(N, dtype=torch_dt) - 0.5).to(torch_device)
+  torch_b = (torch.rand(N, M, dtype=torch_dt) - 0.5).to(torch_device)

   tiny_a = Tensor(torch_a.cpu().numpy())
   tiny_b = Tensor(torch_b.cpu().numpy())
@@ -112,8 +112,8 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):

 def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
   torch.manual_seed(0)
-  torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x).to(torch_device)
-  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None).to(torch_device)
+  torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x, dtype=torch_dt).to(torch_device)
+  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None, dtype=torch_dt).to(torch_device)

   tiny_dat = Tensor(torch_dat.cpu().numpy())
   tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
@@ -190,7 +190,7 @@ class TestSpeed(unittest.TestCase):
   def test_double_permute(self):
     N = 64
     torch.manual_seed(0)
-    torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
+    torch_a = (torch.rand(N, N, N, N, dtype=torch_dt) - 0.5).to(torch_device)
     tiny_a = Tensor(torch_a.cpu().numpy())
     def f(a): return a.permute(1,0,3,2).contiguous()
     helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,))
@@ -268,8 +268,8 @@ class TestSpeed(unittest.TestCase):
   def test_openpilot_conv2d(self):
     bs, in_chans, out_chans = 1,12,32
     torch.manual_seed(0)
-    torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
-    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)
+    torch_dat = torch.rand(bs, 64, 128, 12, dtype=torch_dt).to(torch_device)
+    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1, dtype=torch_dt).to(torch_device)

     tiny_dat = Tensor(torch_dat.cpu().numpy())
     tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
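These speed-test hunks thread a single torch_dt through every torch tensor and Conv2d the helpers build, so HALF=1 benchmarks fp16 end to end. A small torch-only sketch of the pattern, with HALF hard-coded here instead of read from the environment:

import torch

HALF = True  # stand-in for getenv("HALF", 0)
torch_dt = torch.float16 if HALF else torch.float32

N = 256
torch_a = (torch.rand(N, N, dtype=torch_dt) - 0.5)
torch_b = (torch.rand(N, N, dtype=torch_dt) - 0.5)
conv = torch.nn.Conv2d(4, 8, 3, bias=None, dtype=torch_dt)

# The tinygrad side of each helper is built from the same data via .cpu().numpy(),
# so both frameworks benchmark identical, already-rounded fp16 inputs.
np_a = torch_a.cpu().numpy()
print(np_a.dtype, conv.weight.dtype)  # float16 when HALF is set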
@@ -5,6 +5,30 @@ from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.view import strides_for_shape
+from dataclasses import dataclass
+
+@dataclass(frozen=True)
+class TensorCore:
+  device: str
+  dims: List[int]
+  dtype_in: DType
+  dtype_out: DType
+  threads: List[int]
+  thread_local_aliases: List[List[List[int]]]
+  thread_local_sizes: List[int]
+  arch: Optional[str] = None
+  def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
+
+# TODO(TC): doesn't belong here!!!
+tensor_cores: Dict[str, List[TensorCore]] = {
+  "METAL": [
+    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, threads=[2,4,2,2], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[-1, 1, 3], [0], [2, 4]], [[2, 4], [-1, 1, 3], [0]], [[0], [-1, 1, 3], [2, 4]] ], arch="arm64"),
+    TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, threads=[2,4,2,2], thread_local_sizes=[2,2,2], thread_local_aliases=[ [[-1, 1, 3], [0], [2, 4]], [[2, 4], [-1, 1, 3], [0]], [[0], [-1, 1, 3], [2, 4]] ], arch="arm64")
+  ],
+  "HIP": [
+    TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, threads=[16,2], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[-1], [0], [1]], [[-1], [1], [0]], [[0], [1], [2, -1]] ]),
+  ]
+}
+
 class LocalBuffer(NamedTuple):
   name: str
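TensorCore entries are keyed by device name and optionally pinned to a CPU arch (the METAL entries require arm64), which is exactly what test_tensor_cores filters on. A self-contained sketch of that lookup; TC, table and usable_tensor_cores are illustrative names, not tinygrad API:

import os
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass(frozen=True)
class TC:
  device: str
  dims: List[int]             # [M, N, K]: a is (M, K) and b is (K, N) in the test above
  dtype_in: str
  dtype_out: str
  arch: Optional[str] = None  # None = any machine, else must match os.uname().machine

table: Dict[str, List[TC]] = {
  "METAL": [TC("METAL", [8,8,8], "half", "half", arch="arm64")],
  "HIP":   [TC("HIP", [16,16,16], "half", "float")],
}

def usable_tensor_cores(device: str) -> List[TC]:
  # Same filter as test_tensor_cores: skip entries pinned to a different arch.
  return [tc for tc in table.get(device, []) if tc.arch is None or tc.arch == os.uname().machine]

print(usable_tensor_cores("METAL"))  # empty unless running on arm64 (os.uname is POSIX-only)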