support pickling tensors and others (#3787)

* test pickle tensors

* pickle unrealized tensor

* pickle jit, don't save Device in every CompiledASTRunner

* real test of pickle, move delete
This commit is contained in:
George Hotz
2024-03-17 18:29:14 -07:00
committed by GitHub
parent 5ac1fa933f
commit bf3e1c4df2
9 changed files with 70 additions and 24 deletions

View File

@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
print(hex(b1._buf.value))
print(hex(b2._buf.value))
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [49, 8, 2], [8, 4, 1], precompiled=lib)
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
print("compiled")
prg([b0, b1, b2], {})
print("ran")

View File

@@ -38,7 +38,7 @@ if __name__ == "__main__":
lin.linearize()
ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
try:
ptx_prg = CompiledASTRunner(lin.name, ptx_src, dev, lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
except RuntimeError:
print("PTX FAIL")
continue