Tensor core ptx (#3894)

* tensor cores

* Merge from master

* faster program start in llvm (#3897)

* Fix the result permutation in einsum (#3895)

* Fix permutation of result indices in einsum.

* Delete stray line used for breaking tests

* Fix linter error by renaming twice-used variable

---------

Co-authored-by: chenyu <chenyu@fastmail.com>

* touchup einsum (#3900)

don't need rhs_letters

* hotfix check ckpts before writing achieved model (#3901)

this killed the tinybox green run

* replace dtype.name str with render_dtype (#3903)

fixed some bf16 cast issues, since bf16 does not have `.name`.
also more robust if there are language-specific type overrides
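
a minimal sketch of the idea, assuming a hypothetical render_dtype helper and
override table rather than tinygrad's actual signature: rendering goes through a
per-language mapping instead of reading dtype.name directly, so a dtype that
lacks a usable `.name` (bf16 in this sketch) still renders, and a backend can
override the spelling.

from dataclasses import dataclass

@dataclass(frozen=True)
class DType:
  itemsize: int
  c_name: str | None  # None models a dtype without a usable .name (bf16 in this sketch)

float32, bfloat16 = DType(4, "float"), DType(2, None)

def render_dtype(dt: DType, overrides: dict[DType, str]) -> str:
  # hypothetical helper: language-specific override first, generic name as fallback
  if dt in overrides: return overrides[dt]
  assert dt.c_name is not None, f"no rendering for {dt}"
  return dt.c_name

cuda_overrides = {bfloat16: "__nv_bfloat16"}       # CUDA-specific spelling
print(render_dtype(bfloat16, cuda_overrides))      # __nv_bfloat16
print(render_dtype(float32, cuda_overrides))       # float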

* add --minimal flag to nvrtc (#3899)

* wmma: fix the AMD TC threads to split the first 16 threads (#3904)

previously it was incorrectly aliasing all 16 threads into the size-8 upcast
on the store alias. now it splits them properly into 8, with the
remaining 2 going into the correct local stride
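
an illustrative sketch only (the order of the split is assumed here, this is not
the real alias tables): what splitting the first 16 threads into 8 and 2 means for
mapping a lane onto a size-8 upcast index plus a size-2 local-stride index.

# toy decomposition: 16 lanes -> (size-8 upcast index, size-2 local index),
# instead of wrapping all 16 lanes onto only the 8 upcast positions
for lane in range(16):
  upcast_idx, local_idx = lane % 8, lane // 8  # assumed order of the split
  print(f"lane {lane:2d} -> upcast {upcast_idx}, local {local_idx}")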

* training cifar with BF16 on CUDA (#3905)

* training cifar with BF16 on CUDA

memory usage is between float and half due to numpy calls in dataset preprocessing, which convert to float.
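
a rough illustration of that point (CIFAR-sized shapes, but this is not the actual
hlb_cifar10 preprocessing code): numpy promotes the dataset to float32 on the host,
so that copy costs 4 bytes per element even when training itself runs in bf16.

import numpy as np

# 50000 CIFAR images of 3x32x32 uint8 pixels
x_u8 = np.zeros((50000, 3, 32, 32), dtype=np.uint8)

# numpy-based preprocessing (normalize) promotes to float32 on the host...
x_f32 = (x_u8.astype(np.float32) / 255.0 - 0.5) / 0.25
print(f"preprocessed copy: {x_f32.nbytes / 1e6:.0f} MB in float32")

# ...so peak memory lands between an all-float and an all-half/bf16 run,
# where the same data would take half as much
print(f"as bf16/half it would be {x_f32.nbytes / 2 / 1e6:.0f} MB")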

* simpler bf16 functions

* bf16 cifar works for HSA too just very slow

* simpler bf16 functions, we love cuda

* include negative float in test_dtype (#3884)

* include negative float in test_dtype

* that is ub

* too annoying

* pack can overflow

* add to benchmark

* change var name to satisfy mypy

* spacing

* Update to new TensorCore format

* Spacing

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Co-authored-by: Alejandro F Queiruga <33233447+afqueiruga@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: sekstini <127142660+sekstini@users.noreply.github.com>
Co-authored-by: Francis Lam <flam@alum.mit.edu>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Szymon Ożóg
2024-04-04 16:32:31 +02:00
committed by GitHub
parent 92378fb5b6
commit 68fe3527f1
3 changed files with 13 additions and 2 deletions


@@ -107,10 +107,12 @@ jobs:
       run: CUDA=1 python3 test/external/external_model_benchmark.py
     - name: Test speed vs torch
       run: CUDA=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
-    - name: Run Tensor Core GEMM
+    - name: Run Tensor Core GEMM(CUDA)
       run: |
         CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
         CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
+    - name: Run Tensor Core GEMM(PTX)
+      run: CUDA=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
     - name: Run LLaMA
       run: |
         CUDA=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt


@@ -51,7 +51,7 @@ class TensorCoreOptions(NamedTuple):
 tensor_cores: Dict[str, List[TensorCore]] = {
   "METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
   "HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
-  "CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)]], # noqa: E501
+  "CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)] if getenv("PTX") else [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])], # noqa: E501
 }
 class LocalBuffer(NamedTuple):


@@ -192,6 +192,15 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
       if lang.load_global:
         dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
         kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
+    elif uop is UOps.WMMA:
+      wmma = []
+      for vv in vin[:2]:
+        for i in range(0, len(r[vv]), 2):
+          wmma.append(ssa(None, "wmma", "b32"))
+          kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
+      r[u] = r[vin[2]]
+      kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
+        {{{", ".join(r[u])}}}, {{{", ".join(wmma[:4])}}}, {{{", ".join(wmma[4:])}}}, {{{", ".join(r[u])}}};')
     else: raise NotImplementedError(f"no code for {uop}")
   return lang.render_kernel(kernel, function_name, bufs, c.items())
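
To make the generated PTX concrete, here is a small self-contained sketch of what
the new UOps.WMMA branch above emits (register names are invented for illustration):
pairs of .f16 registers for the A and B operands are packed into .b32 registers with
mov.b32, and a single mma.sync.aligned.m16n8k16 instruction then accumulates into the
four .f32 registers in place.

def render_wmma_ptx(a_halfs, b_halfs, acc_f32):
  # pack pairs of 16-bit registers into .b32 registers, the operand shape mma expects
  lines, packed = [], []
  for regs in (a_halfs, b_halfs):
    for i in range(0, len(regs), 2):
      packed.append(f"%wmma{len(packed)}")
      lines.append(f'mov.b32 {packed[-1]}, {{{regs[i]}, {regs[i+1]}}};')
  # m16n8k16 f16 mma: 4 b32 regs of A, 2 b32 regs of B, 4 f32 accumulators (in place)
  lines.append("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
               f'{{{", ".join(acc_f32)}}}, {{{", ".join(packed[:4])}}}, '
               f'{{{", ".join(packed[4:])}}}, {{{", ".join(acc_f32)}}};')
  return lines

# illustrative per-thread register names: 8 halves of A, 4 halves of B, 4 float accumulators
for line in render_wmma_ptx([f"%hA{i}" for i in range(8)],
                            [f"%hB{i}" for i in range(4)],
                            [f"%f{i}" for i in range(4)]):
  print(line)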