fix intel wmma flop counting, add flop counting tests for different tensor cores (#6192)

* fix wmma flop counting on intel, add count tests

* half

* add half gemm

* Update test.yml

* one test

* Update test_uops_stats.py

* Update test_uops_stats.py

* Update test_uops_stats.py

* smaller matrix, use unittest skipUnless decorator
This commit is contained in:
CaltropHungerton
2024-08-25 20:37:05 -05:00
committed by GitHub
parent 331b0f5477
commit 002f60b4c3
3 changed files with 20 additions and 4 deletions

View File

@@ -57,9 +57,15 @@ jobs:
- name: Test tensor cores (TC=3)
run: |
TC=3 DEBUG=3 EMULATE_METAL=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_AMD=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
TC=3 DEBUG=3 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 python3 ./extra/gemm/simple_matmul.py
TC=3 PYTHONPATH=. DEBUG=3 EMULATE_INTEL=1 PYTHON=1 N=16 HALF=1 python3 ./extra/gemm/simple_matmul.py
- name: Test device flop counts
run: |
PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
- name: Test dtype with Python emulator
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
- name: Test ops with Python emulator

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.uopgraph import linearize_uop
@@ -95,6 +95,16 @@ class TestUOpsStats(unittest.TestCase):
# NOTE: it's hard to assert on the memory here, all depends on caching
assert required_mem <= mem
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
def test_simple_matmul_half(self):
  # Realizes one square half-precision matmul and compares the value of
  # GlobalCounters.global_ops against the textbook GEMM op count.
  # Skipped unless the PYTHON env var is set (see the skip reason above).
  dim = 16
  # reset before building anything so only this matmul is compared
  # NOTE(review): presumably reset() zeroes global_ops -- confirm in GlobalCounters
  GlobalCounters.reset()
  lhs = Tensor.empty(dim, dim, dtype=dtypes.half)
  rhs = Tensor.empty(dim, dim, dtype=dtypes.half)
  lhs.matmul(rhs).realize()
  # a dim x dim by dim x dim product is dim**3 multiply-adds, two ops each
  wanted_ops = 2 * dim ** 3
  assert wanted_ops == GlobalCounters.global_ops
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())

View File

@@ -472,5 +472,5 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem