Add TritonProgram

This commit is contained in:
Szymon Ożóg
2023-08-15 19:47:12 +02:00
parent 83516c6ec8
commit 13e45691b4

View File

@@ -15,10 +15,19 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import prod, DEBUG
from tinygrad.runtime.ops_gpu import CLBuffer
class TritonProgram:
def __init__(self, name:str, prg:str):
hash = hashlib.md5(prg.encode('utf-8')).hexdigest()
fn = f"/tmp/{hash}.py"
with open(fn, "w") as f: f.write(prg)
codeObject = compile(prg, fn, "exec")
exec(codeObject, globals())
self.program = globals()["fxn"]
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
self.program(*[x._buf for x in args])
class TritonDeviceAllocation(CLBuffer):
def __init__(self, size):