PTX - implement float 4, ptr arithmetics and other speed improvements (#3775)

* ptx float4 implementation

* remove from cache when trimming uops

* Gate for float4

* Linting fix

* disable test reasonable time for ptx

* import getenv

* Update uops.py

* linter

* Add div test for half

* upcast if op does not support operation

* fix offset

* Run only if dtype supported

* zero out registers when accessing by pred + cleanup

* Remove trailing whitespace

* revert

* spacing fix

* move cache clearing outside loop

* did this suddenly start working?

* unused import removed

* Remove cast

* Use pattern matching

* linting

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Szymon Ożóg
2024-03-22 16:54:02 +01:00
committed by GitHub
parent f4055439dc
commit 624bc89910
3 changed files with 118 additions and 67 deletions

View File

@@ -52,8 +52,7 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
class PTXCompiler(Compiler):
linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024],
supports_float4=False, shared_max=49152)
linearizer_opts = LinearizerOptions("CUDA", suffix="PTX", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], shared_max=49152)
def __init__(self, arch:str):
self.arch = arch
PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)