multitensor jit (#3149)

* initial multitensor jit support and tests

* Added graphs to multitensor jit and updated tests

* update unbind api

* fix set device, add TinyJit to resnet

* update_stats includes device

---------

Co-authored-by: ramenguy99 <ramenguy99@gmail.com>
This commit is contained in:
George Hotz
2024-01-16 09:09:15 -08:00
committed by GitHub
parent b9d470577c
commit 228f30b96a
9 changed files with 125 additions and 40 deletions

View File

@@ -22,7 +22,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
c[idx] = atan2(a[idx], b[idx]);
}"""
lib = Device[ret.device].compiler(src)
CompiledASTRunner(None, "atan2_gpu", src, lib, global_size=[ret.size]).build(Device[ret.device].runtime).exec([ret, a, b])
CompiledASTRunner(None, "atan2_gpu", src, lib, Device[ret.device], global_size=[ret.size]).build(Device[ret.device].runtime).exec([ret, a, b])
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)