Files
tinygrad/tinygrad/realize.py
George Hotz 6d6eb9302d ruff checks the max line length is 150 (#2734)
* ruff checks the max line length is 150

* fix tensor.py

* a lot more

* done
2023-12-12 17:34:47 -08:00

38 lines
1.8 KiB
Python

from typing import List, Dict, Optional
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable
class CustomOp(JITRunner):
  """Runner that hands the realized buffers straight to a Python callable.

  The callable given at construction is stored as `self.fxn`. When the runner is called,
  it is invoked positionally with every buffer in `rawbufs`. `var_vals`, `wait` and `jit`
  are part of the runner call signature but are ignored here, and nothing is returned.
  """
  def __init__(self, fxn):
    self.fxn = fxn  # set before the base initializer runs, as in the original ordering
    super().__init__()

  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    # var_vals / wait / jit are accepted only for signature compatibility with JITRunner
    self.fxn(*rawbufs)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
  """Choose the runner that executes schedule item `si`, or None if there is nothing to run.

  Dispatch on `si.ast.op`:
    LoadOps.EMPTY  -> None
    LoadOps.COPY   -> the module-level `BufferCopy` object itself (no new runner is built)
    LoadOps.CUSTOM -> a new CustomOp wrapping the callable stored in `si.ast.arg`
    otherwise      -> `Device[si.out.device].get_runner(si.ast)`

  Via `assert`, every input must share the output's device unless the op is a COPY.
  """
  same_device = all(si.out.device == x.device for x in si.inputs)
  op = si.ast.op
  # the message (including print_tree of the AST) is only evaluated if the assert fails;
  # `or ''` covers print_tree returning a falsy value (presumably it prints and returns None)
  assert same_device or op is LoadOps.COPY, f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" # noqa: E501
  if op is LoadOps.EMPTY:
    return None
  if op is LoadOps.COPY:
    return BufferCopy
  if op is LoadOps.CUSTOM:
    return CustomOp(si.ast.arg)
  device = Device[si.out.device]
  return device.get_runner(si.ast)
def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
  """Execute `schedule` front to back, consuming it as it goes.

  Items are popped from the front of the caller's list, so the list is empty after a normal
  return. If an item raises, the items after it are left in the list. For each item:
    1. log it (unless `disable_logging`)
    2. assert that every input is already realized
    3. lower it to a runner
    4. set `out.realized`, reusing `out.output_buffer` when present, otherwise allocating a Buffer
    5. delete `op` from the output and from each of its views
    6. run the program on [output, *inputs] with the item's `var_vals`, if a program exists
  """
  while schedule:
    si = schedule.pop(0)
    if not disable_logging:
      log_schedule_item(si)
    assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"

    prg = lower_schedule_item(si)

    out = si.out
    if out.output_buffer is None:
      # no preallocated output: symbolic dims use their `.max`, so the buffer fits any binding
      size = prod(d if isinstance(d, int) else d.max for d in out.shape)
      out.realized = Buffer(out.device, size, out.dtype)
    else:
      out.realized = out.output_buffer

    # NOTE(review): presumably releases the op graph once realized; rationale not visible here
    del out.op
    for view in out.views:
      del view.op

    if prg:
      prg.exec([out.realized, *(x.realized for x in si.inputs)], si.var_vals)