From 13e45691b4b196be63ae810ab072985ce84cbb40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20O=C5=BC=C3=B3g?= <58388001+SzymonOzog@users.noreply.github.com> Date: Tue, 15 Aug 2023 19:47:12 +0200 Subject: [PATCH] Add TritonProgram --- tinygrad/runtime/ops_triton.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tinygrad/runtime/ops_triton.py b/tinygrad/runtime/ops_triton.py index e70794a0d6..920c4eb191 100644 --- a/tinygrad/runtime/ops_triton.py +++ b/tinygrad/runtime/ops_triton.py @@ -15,10 +15,19 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import prod, DEBUG from tinygrad.runtime.ops_gpu import CLBuffer +class TritonProgram: + def __init__(self, name:str, prg:str): + hash = hashlib.md5(prg.encode('utf-8')).hexdigest() fn = f"/tmp/{hash}.py" + with open(fn, "w") as f: f.write(prg) + codeObject = compile(prg, fn, "exec") exec(codeObject, globals()) + self.program = globals()["fxn"] + + def __call__(self, global_size, local_size, *args, wait=False) -> Any: + self.program(*[x._buf for x in args]) class TritonDeviceAllocation(CLBuffer): def __init__(self, size):