mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Add TritonProgram
This commit is contained in:
@@ -15,10 +15,19 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.runtime.ops_gpu import CLBuffer
|
||||
|
||||
class TritonProgram:
|
||||
|
||||
def __init__(self, name:str, prg:str):
|
||||
hash = hashlib.md5(prg.encode('utf-8')).hexdigest()
|
||||
fn = f"/tmp/{hash}.py"
|
||||
with open(fn, "w") as f: f.write(prg)
|
||||
codeObject = compile(prg, fn, "exec")
|
||||
exec(codeObject, globals())
|
||||
self.program = globals()["fxn"]
|
||||
|
||||
|
||||
def __call__(self, global_size, local_size, *args, wait=False) -> Any:
|
||||
self.program(*[x._buf for x in args])
|
||||
|
||||
class TritonDeviceAllocation(CLBuffer):
|
||||
def __init__(self, size):
|
||||
|
||||
Reference in New Issue
Block a user