mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix intel wmma flop counting, add flop counting tests for different tensor cores (#6192)
* fix wmma flop counting on intel, add count tests
* half
* add half gemm
* Update test.yml
* one test
* Update test_uops_stats.py
* Update test_uops_stats.py
* Update test_uops_stats.py
* smaller matrix, use unittest skipUnless decorator
This commit is contained in:
10
.github/workflows/test.yml
vendored
10
.github/workflows/test.yml
vendored
@@ -57,9 +57,15 @@ jobs:
|
||||
- name: Test tensor cores (TC=3)
|
||||
run: |
|
||||
TC=3 DEBUG=3 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
|
||||
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
TC=3 DEBUG=3 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
|
||||
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_INTEL=1 PYTHON=1 N=16 HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
- name: Test device flop counts
|
||||
run: |
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
|
||||
- name: Test dtype with Python emulator
|
||||
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
|
||||
- name: Test ops with Python emulator
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.codegen.uopgraph import linearize_uop
|
||||
@@ -95,6 +95,16 @@ class TestUOpsStats(unittest.TestCase):
|
||||
# NOTE: it's hard to assert on the memory here, all depends on caching
|
||||
assert required_mem <= mem
|
||||
|
||||
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
|
||||
def test_simple_matmul_half(self):
|
||||
GlobalCounters.reset()
|
||||
N = 16
|
||||
a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
|
||||
c = a.matmul(b)
|
||||
c.realize()
|
||||
expected_ops = N ** 3 * 2
|
||||
assert expected_ops == GlobalCounters.global_ops
|
||||
|
||||
#MULACC should have the same stats as MUL + ADD
|
||||
def test_mulacc(self):
|
||||
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
|
||||
|
||||
@@ -472,5 +472,5 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is UOps.WMMA and u not in dont_count:
|
||||
assert u.arg[1] is not None
|
||||
flops += 2 * prod(u.arg[1]) // 32 * mults
|
||||
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return flops, mem
|
||||
|
||||
Reference in New Issue
Block a user