diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index e70794a0d6..920c4eb191 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -15,10 +15,19 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import prod, DEBUG from tinygrad.runtime.ops_gpu import CLBuffer +class TritonProgram: + def __init__(self, name:str, prg:str): + hash = hashlib.md5(prg.encode('utf-8')).hexdigest() fn = f"/tmp/{hash}.py" + with open(fn, "w") as f: f.write(prg) + codeObject = compile(prg, fn, "exec") exec(codeObject, globals()) + self.program = globals()["fxn"] + + def __call__(self, global_size, local_size, *args, wait=False) -> Any: + self.program(*[x._buf for x in args]) class TritonDeviceAllocation(CLBuffer): def __init__(self, size):