mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Schedule item (#2012)
* ScheduleItem * put var_vals in the schedule * fix tests, wow that proliferated quickly * not ready to be in the schedule
This commit is contained in:
@@ -267,7 +267,7 @@ result = Tensor(2).realize() + Tensor(3).realize()
|
||||
# use the real Linearizer to linearize 2+3
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
sched = result.lazydata.schedule()
|
||||
linearizer = Linearizer(sched[-1][0])
|
||||
linearizer = Linearizer(sched[-1].ast)
|
||||
linearizer.linearize()
|
||||
|
||||
# print the uops
|
||||
|
||||
@@ -19,12 +19,12 @@ if __name__ == "__main__":
|
||||
x = Tensor.empty(64, 3, 224, 224)
|
||||
out = mdl(x)
|
||||
sched = out.lazydata.schedule(seen)
|
||||
sched = [x for x in sched if x[0].op not in LoadOps]
|
||||
sched = [x for x in sched if x.ast.op not in LoadOps]
|
||||
|
||||
# work with the schedule
|
||||
total_tm = 0
|
||||
for i,(op,out,inp) in enumerate(sched):
|
||||
if DEBUG >= 2: print_tree(op)
|
||||
for i,si in enumerate(sched):
|
||||
if DEBUG >= 2: print_tree(si.ast)
|
||||
|
||||
# enable only one kernel to focus on it
|
||||
#if i != 1: continue
|
||||
@@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
for big_chomp in [1,2]: #[1,2,4,8,16]:
|
||||
for lil_chomp in [2,4,7,8,14]:
|
||||
for upcasted in [0,1,2]:
|
||||
lin = Linearizer(op, LinearizerOptions(device="METAL"))
|
||||
lin = Linearizer(si.ast, LinearizerOptions(device="METAL"))
|
||||
lin.reshape_and_permute(lambda x: (4096//big_chomp,big_chomp,56//lil_chomp,lil_chomp,56//lil_chomp,lil_chomp)+x[-2:], [0,2,4,1,3,5,6,7])
|
||||
lin.upcasted += upcasted
|
||||
lin.local_dims += 3
|
||||
@@ -45,13 +45,13 @@ if __name__ == "__main__":
|
||||
else:
|
||||
# try with and without tensor cores
|
||||
for tc in [0,1]:
|
||||
lin = Linearizer(op, LinearizerOptions(device="METAL"))
|
||||
lin = Linearizer(si.ast, LinearizerOptions(device="METAL"))
|
||||
lin.hand_coded_optimizations(use_tensor_cores=tc)
|
||||
lins.append(lin)
|
||||
|
||||
# create output/input buffers
|
||||
rout = RawMetalBuffer(out.st.size(), out.dtype)
|
||||
rin = [RawMetalBuffer(x.st.size(), x.dtype) for x in inp]
|
||||
rout = RawMetalBuffer(si.out.st.size(), si.out.dtype)
|
||||
rin = [RawMetalBuffer(x.st.size(), x.dtype) for x in si.inputs]
|
||||
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# these are of size 3 to avoid float4 coalesce
|
||||
r = a[:-1] + a[1:]
|
||||
|
||||
k = Linearizer(r.lazydata.schedule()[-1][0])
|
||||
k = Linearizer(r.lazydata.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
|
||||
@@ -48,7 +48,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = a.expand([2]) + b.expand([2])
|
||||
|
||||
k = Linearizer(r.lazydata.schedule()[-1][0])
|
||||
k = Linearizer(r.lazydata.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
@@ -61,7 +61,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack([a, b])
|
||||
|
||||
k = Linearizer(r.lazydata.schedule()[-1][0])
|
||||
k = Linearizer(r.lazydata.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
@@ -106,8 +106,8 @@ def helper_realized_ast(r:Tensor):
|
||||
s = r.lazydata.schedule()
|
||||
run_schedule(s[:-1]) # run all kernels except the last one
|
||||
# now all input LazyBuffers buffers in s[-1] should be realized
|
||||
output_buffer = Device[s[-1][1].device].buffer(prod((s if isinstance(s, int) else s.max for s in s[-1][1].shape)), s[-1][1].dtype, **s[-1][1]._device_extra_args()) # allocate an output buffer
|
||||
return s[-1][0], [output_buffer] + [l.realized for l in s[-1][2]]
|
||||
output_buffer = Device[s[-1].out.device].buffer(prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype, **s[-1].out._device_extra_args()) # allocate an output buffer
|
||||
return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
|
||||
|
||||
def helper_linearizer_opt(r:Tensor, opts=[]):
|
||||
wanna_output = None
|
||||
@@ -232,7 +232,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
|
||||
@@ -244,7 +244,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.shift_to(0, 4) # float4 dimension
|
||||
k.shift_to(0, 2, insert_before=k.shape_len-1)
|
||||
k.upcast()
|
||||
@@ -260,7 +260,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.hand_coded_optimizations() # implicit trigger float4 dim
|
||||
k.linearize()
|
||||
|
||||
@@ -272,7 +272,7 @@ class TestFloat4(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
|
||||
k.upcast()
|
||||
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
|
||||
@@ -290,7 +290,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
|
||||
@@ -305,7 +305,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# don't.
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.upcast()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -321,7 +321,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.shift_to(0, 4, top=True) # top axes are float4 axes
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -337,7 +337,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.shift_to(0, 4) # float4 axis
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -352,7 +352,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# should float4 b but not a
|
||||
|
||||
s = c.lazydata.schedule()[0]
|
||||
k = Linearizer(s[0])
|
||||
k = Linearizer(s.ast)
|
||||
k.shift_to(0, 4) # float4 axis
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
|
||||
@@ -16,21 +16,21 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize:
|
||||
for s in pre.lazydata.schedule(seen.copy()):
|
||||
log_schedule_item(*s)
|
||||
seen.add(s[1])
|
||||
log_schedule_item(s)
|
||||
seen.add(s.out)
|
||||
sched = t.lazydata.schedule(seen)
|
||||
for s in sched: log_schedule_item(*s)
|
||||
if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
|
||||
for s in sched: log_schedule_item(s)
|
||||
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
for i, s in enumerate(sched):
|
||||
print("op", i)
|
||||
print_tree(s[0])
|
||||
print_tree(s.ast)
|
||||
assert len(sched) == allowed
|
||||
# test the (non loadops) ops linearize
|
||||
for s in sched:
|
||||
if s[0].op in LoadOps: continue
|
||||
l = Linearizer(s[0])
|
||||
if s.ast.op in LoadOps: continue
|
||||
l = Linearizer(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ class TestWinograd(unittest.TestCase):
|
||||
sched = out.lazydata.schedule()
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s[0].op in LoadOps: continue
|
||||
ops = s[0].get_lazyops()
|
||||
if s.ast.op in LoadOps: continue
|
||||
ops = s.ast.get_lazyops()
|
||||
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
|
||||
l = Linearizer(s[0])
|
||||
l = Linearizer(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
|
||||
|
||||
@@ -4,12 +4,10 @@ try:
|
||||
except ImportError:
|
||||
nx = None # graph won't work
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from typing import Dict, List
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
|
||||
|
||||
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
|
||||
|
||||
# **** debugging and graphing ****
|
||||
|
||||
G = nx.DiGraph() if nx is not None else None
|
||||
@@ -47,31 +45,31 @@ def str_dtype(dtyp):
|
||||
ret = str(dtyp)[7:]
|
||||
return "" if ret == 'float' else f"\n{ret}"
|
||||
|
||||
def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
|
||||
def log_schedule_item(si: ScheduleItem):
|
||||
show_graph = bool(GRAPH)
|
||||
if not DEBUG and not show_graph: return
|
||||
if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(inp[0].base))
|
||||
if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
|
||||
if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
|
||||
if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
|
||||
|
||||
op: List[Op] = [x.op for x in iop.get_lazyops()]
|
||||
op: List[Op] = [x.op for x in si.ast.get_lazyops()]
|
||||
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
|
||||
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
|
||||
cnts[optype] += 1
|
||||
if show_graph:
|
||||
assert ret.base == ret, "all outputs based"
|
||||
assert si.out.base == si.out, "all outputs based"
|
||||
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
|
||||
for x in inp:
|
||||
for x in si.inputs:
|
||||
assert x.base == x, "all inputs based"
|
||||
#assert nm(x) in G.nodes, "all inputs seen"
|
||||
G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060')
|
||||
G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
|
||||
if 'label' not in G.nodes[nm(x)]:
|
||||
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
|
||||
if nm(ret) not in G.nodes: G.add_node(nm(ret))
|
||||
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
|
||||
if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
|
||||
|
||||
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)+(f"\n{iop.op}" if iop.op in LoadOps else "")
|
||||
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
|
||||
G.nodes[nm(ret)]['color'] = 'black'
|
||||
G.nodes[nm(ret)]['style'] = 'filled'
|
||||
G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
|
||||
G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
|
||||
G.nodes[nm(si.out)]['color'] = 'black'
|
||||
G.nodes[nm(si.out)]['style'] = 'filled'
|
||||
|
||||
def _tree(lazydata, prefix=""):
|
||||
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")
|
||||
|
||||
@@ -5,7 +5,7 @@ from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
|
||||
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
|
||||
@@ -79,7 +79,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(
|
||||
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
|
||||
|
||||
def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps]))
|
||||
def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps], []))
|
||||
|
||||
lazycache: WeakValueDictionary = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], base:Optional[LazyBuffer]=None):
|
||||
@@ -159,7 +159,7 @@ class LazyBuffer:
|
||||
|
||||
# *** scheduling ***
|
||||
|
||||
def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
|
||||
def schedule(self, seen=None) -> List[ScheduleItem]:
|
||||
if seen is None: seen = set()
|
||||
if self in seen or self.realized or self.is_unrealized_const(): return []
|
||||
seen.add(self)
|
||||
@@ -180,7 +180,7 @@ class LazyBuffer:
|
||||
|
||||
# run the ast and log the op
|
||||
op, base_bufs = _replace_bufferops(op)
|
||||
return ret + [(op, self, tuple(base_bufs))]
|
||||
return ret + [ScheduleItem(op, self, tuple(base_bufs))]
|
||||
|
||||
# *** creation/special ops ***
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import numpy as np
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from dataclasses import dataclass
|
||||
|
||||
# these are the llops your accelerator must implement, along with toCpu
|
||||
@@ -38,6 +40,13 @@ class ConstBuffer:
|
||||
dtype: DType
|
||||
st: ShapeTracker
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleItem:
|
||||
ast: LazyOp
|
||||
out: LazyBuffer
|
||||
inputs: Tuple[LazyBuffer, ...]
|
||||
# TODO: move var_vals here
|
||||
|
||||
class LazyOp:
|
||||
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
|
||||
op: Op
|
||||
@@ -154,9 +163,6 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
|
||||
|
||||
# **************** for Compiled Buffers ****************
|
||||
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
|
||||
class BasicBatchExecutor:
|
||||
def __init__(self, jit_cache:List[Tuple[Any, Any, Any]]): pass
|
||||
def exec(self, jit_cache: List[Tuple[Any, Any, Any]], updatable_entries):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import List, Tuple, cast, Dict, Callable
|
||||
from typing import List, cast, Dict, Callable
|
||||
import numpy as np
|
||||
from tinygrad.ops import LazyOp, LoadOps, Device, UnaryOps, BufferOps, MemBuffer, get_lazyop_info
|
||||
import dataclasses
|
||||
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, UnaryOps, BufferOps, MemBuffer, get_lazyop_info
|
||||
from tinygrad.graph import log_schedule_item, print_tree
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE, ImageDType, dtypes
|
||||
@@ -8,59 +9,60 @@ from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE, ImageDType, dt
|
||||
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
|
||||
def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
|
||||
def fix_schedule_for_images(schedule:List[ScheduleItem]):
|
||||
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
|
||||
for op,out,buffers in schedule:
|
||||
if isinstance(out.dtype, ImageDType) and (prod(out.shape) != prod(out.dtype.shape) or not any(out.shape[x]%4 == 0 for x in out.st.unit_stride_axes())):
|
||||
out.dtype = dtypes.float32
|
||||
bops = [x for x in op.get_lazyops() if x.op == BufferOps.MEM]
|
||||
for b in bops:
|
||||
if isinstance(buffers[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
|
||||
buffers[b.arg.idx-1].dtype = dtypes.float32
|
||||
for si in schedule:
|
||||
if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
|
||||
si.out.dtype = dtypes.float32
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
|
||||
si.inputs[b.arg.idx-1].dtype = dtypes.float32
|
||||
|
||||
# now fix up the schedule to reflect the new dtypes
|
||||
fixed_schedule = []
|
||||
for op,out,buffers in schedule:
|
||||
fixed_schedule:List[ScheduleItem] = []
|
||||
for si in schedule:
|
||||
ast = si.ast
|
||||
# fix input dtypes to match what they actually are
|
||||
bops = [x for x in op.get_lazyops() if x.op == BufferOps.MEM]
|
||||
replacements = {}
|
||||
for x in bops:
|
||||
if x.arg.dtype != buffers[x.arg.idx-1].dtype:
|
||||
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(x.arg.idx, buffers[x.arg.idx-1].dtype, x.arg.st))
|
||||
if replacements: op = op.map_buffers(replacements)
|
||||
for b in si.ast.get_lazyops():
|
||||
if b.op != BufferOps.MEM: continue
|
||||
if b.arg.dtype != si.inputs[b.arg.idx-1].dtype:
|
||||
replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, si.inputs[b.arg.idx-1].dtype, b.arg.st))
|
||||
if replacements: ast = ast.map_buffers(replacements)
|
||||
|
||||
# fix the ops to create the output dtype
|
||||
if op.op not in LoadOps:
|
||||
info = get_lazyop_info(op)
|
||||
if info.dtype != out.dtype:
|
||||
op = LazyOp(UnaryOps.CAST, (op,), (out.dtype, False))
|
||||
if ast.op not in LoadOps:
|
||||
info = get_lazyop_info(ast)
|
||||
if info.dtype != si.out.dtype:
|
||||
ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
|
||||
|
||||
# put this in the fixed schedule
|
||||
fixed_schedule.append((op, out, buffers))
|
||||
fixed_schedule.append(dataclasses.replace(si, ast=ast))
|
||||
return fixed_schedule
|
||||
|
||||
# *** this is where things happen ***
|
||||
|
||||
def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
|
||||
def run_schedule(schedule:List[ScheduleItem]):
|
||||
# HACK: images can be not usable due to shape
|
||||
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
|
||||
|
||||
# NOTE: if you for loop the schedule it's slow because nothing frees
|
||||
while len(schedule):
|
||||
op,out,buffers = schedule.pop(0)
|
||||
log_schedule_item(op, out, buffers)
|
||||
assert all(x.realized for x in buffers), "can't run schedule, some buffers aren't realized"
|
||||
if DEBUG >= 3: print_tree(op)
|
||||
if op.op in LoadOps:
|
||||
si = schedule.pop(0)
|
||||
log_schedule_item(si)
|
||||
assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
|
||||
if DEBUG >= 3: print_tree(si.ast)
|
||||
if si.ast.op in LoadOps:
|
||||
# confirm the LoadOps are contiguous and in order
|
||||
for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
|
||||
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
|
||||
else:
|
||||
out.realized = Device[out.device].exec_ast(op, output=out, inputs=buffers, var_vals=out.var_vals, **out._device_extra_args())
|
||||
del out.op
|
||||
for v in out.views: del v.op
|
||||
assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
|
||||
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
|
||||
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.out.var_vals, **si.out._device_extra_args())
|
||||
del si.out.op
|
||||
for v in si.out.views: del v.op
|
||||
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
|
||||
assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
|
||||
|
||||
# *** zero op LoadOps ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user