Mirror of https://github.com/tinygrad/tinygrad.git
@@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker

 # allocate some buffers + load in values
@@ -49,14 +49,14 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 # NOTE: a._buf is the same as the return from MallocAllocator.alloc

 # describe the computation
-buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
-buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
-ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
-ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
+buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
+buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
+ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
+ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
 alu = ld_1 + ld_2
-output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
-st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
-s = UOp(UOps.SINK, dtypes.void, (st_0,))
+output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
+st_0 = UOp(Ops.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
+s = UOp(Ops.SINK, dtypes.void, (st_0,))

 # convert the computation to a "linearized" format (print the format)
 from tinygrad.engine.realize import get_kernel, CompiledRunner
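For downstream code that has to straddle this rename, a small import-time shim is enough. A hedged sketch (the try/except fallback is illustrative and not part of this diff; the enum members themselves are unchanged):

try:
    from tinygrad.ops import Ops            # new spelling introduced by this change
except ImportError:
    from tinygrad.ops import UOps as Ops    # older tinygrad releases
print(Ops.SINK, Ops.LOAD, Ops.STORE)        # same members, only the class name changed
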
@@ -4,7 +4,7 @@
 members_order: source
 show_labels: false

-::: tinygrad.ops.UOps
+::: tinygrad.ops.Ops
 options:
 members: true
 members_order: source
@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import UOps, sym_infer
+from tinygrad.ops import Ops, sym_infer
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
@@ -66,7 +66,7 @@ if __name__ == "__main__":
 print(f"optimizing for {Device.DEFAULT}")

 sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
-sched = [x for x in sched if x.ast.op is UOps.SINK]
+sched = [x for x in sched if x.ast.op is Ops.SINK]

 # focus on one kernel
 if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
@@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import get_kernel, run_schedule
 from tinygrad.engine.memory import memory_planner
-from tinygrad.ops import MetaOps, UOps
+from tinygrad.ops import MetaOps, Ops

 TIMING = getenv("TIMING")
@@ -41,7 +41,7 @@ if __name__ == "__main__":
 print(f"calls {i}:", len(sched))
 #run_schedule(sched[:])
 sched = memory_planner(sched)
-ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
+ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
 srcs = {}
 for ast in ast_dedup:
 k = get_kernel(Device["CLANG"].renderer, ast)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
 for i,si in enumerate(sched):
 bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
 all_bufs += bufs
-if si.ast.op is not UOps.SINK:
+if si.ast.op is not Ops.SINK:
 print(f"// {si.ast.op}", bufs)
 else:
 print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")
@@ -19,7 +19,7 @@ from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.tensor import _to_np_dtype
 Device.DEFAULT = "GPU"
@@ -50,7 +50,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
 print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

 # confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
-assert all(si.ast.op is UOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
+assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
 return schedule, schedule_independent, inputs

 def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
@@ -106,7 +106,7 @@ if __name__ == "__main__":
 #exit(0)

 schedule, schedule_independent, inputs = get_schedule(onnx_data)
-schedule, schedule_input = partition(schedule, lambda x: x.ast.op is UOps.SINK)
+schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
 print(f"{len(schedule_input)} inputs")

 run_schedule(schedule_independent)
@@ -1,5 +1,5 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
-from tinygrad.codegen.kernel import UOps, MemOp, UOp
+from tinygrad.codegen.kernel import Ops, MemOp, UOp
 from tinygrad.ops import BinaryOps, UnaryOps
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import DEBUG
@@ -23,7 +23,7 @@ class Register(NamedTuple):
 return []

 class AssemblyInstruction(NamedTuple):
-op: UOps
+op: Ops
 out: Optional[Register]
 vin: List[Union[Register, int, float]]
 arg: Any = None
@@ -49,21 +49,21 @@ class AssemblyLanguage:

 def render_numnode(self, b) -> Register:
 key = ("num", b)
-if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
+if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
 return self.tor[key]

 def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
 key = (op, a, b)
 if key not in self.tor:
 #if not isinstance(b, Register): b = render_numnode(b)
-self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
+self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
 return self.tor[key]

 def render_cast(self, a:Register, new_dtype:DType) -> Register:
 if a.dtype == new_dtype: return a
 key = (a, new_dtype)
 if key not in self.tor:
-self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
+self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
 return self.tor[key]

 render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
@@ -87,7 +87,7 @@ class AssemblyLanguage:
 if self.supports_load3:
 if reg.scalar:
 new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
-self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
 reg = new_reg
 return self.tor[args.name], reg, off
 reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
@@ -98,91 +98,91 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
|
||||
lang.ins.clear()
|
||||
lang.tor.clear()
|
||||
lang.cnts.clear()
|
||||
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
|
||||
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
|
||||
global_size, local_size = [], []
|
||||
skipload_branch = 0
|
||||
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||
lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
|
||||
for u in uops:
|
||||
uop,dtype,vin,args,_ = u
|
||||
if uop == UOps.DEFINE_LOCAL:
|
||||
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
||||
elif uop == UOps.LOOP:
|
||||
if uop == Ops.DEFINE_LOCAL:
|
||||
lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
|
||||
elif uop == Ops.LOOP:
|
||||
if args[1] == "global":
|
||||
for i,var in enumerate(args[0]):
|
||||
global_size.append(var.max+1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
||||
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
|
||||
elif args[1] == "local":
|
||||
for i,var in enumerate(args[0]):
|
||||
local_size.append(var.max+1)
|
||||
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
||||
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
|
||||
else:
|
||||
for var in args[0]:
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
|
||||
elif uop == UOps.ENDLOOP:
|
||||
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
|
||||
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
|
||||
elif uop == Ops.ENDLOOP:
|
||||
if args[1] not in ["global", "local", "global+local"]:
|
||||
for var in reversed(args[0]):
|
||||
if not isinstance(var, NumNode): # TODO: why is this coming through?
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
|
||||
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
||||
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
|
||||
elif args[1] == "global+local":
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
|
||||
elif args[1] == 'local':
|
||||
for i, var in enumerate(reversed(args[0])):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
||||
elif uop == UOps.CAST:
|
||||
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
|
||||
elif uop == Ops.CAST:
|
||||
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
|
||||
out = lang.newreg(u, dtype)
|
||||
for i,sr in enumerate(out.subregs()):
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
||||
elif uop == UOps.ALU:
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
|
||||
elif uop == Ops.ALU:
|
||||
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
|
||||
# this is the only thing that can violate SSA
|
||||
if args in [BinaryOps.CMPLT]:
|
||||
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
||||
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
|
||||
lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
|
||||
elif args == BinaryOps.DIV and lang.no_div:
|
||||
tmp = lang.newreg((u, "rcp"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
|
||||
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
|
||||
tmp = lang.newreg((u, "2pi"))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
|
||||
else:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
|
||||
elif uop == UOps.DEFINE_ACC:
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
|
||||
elif uop == Ops.DEFINE_ACC:
|
||||
reg = lang.newreg(u, dtype=dtype)
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
|
||||
elif uop == UOps.SPECIAL:
|
||||
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
|
||||
elif uop == Ops.SPECIAL:
|
||||
lang.tor[u] = lang.tor[args]
|
||||
elif uop == UOps.CONST:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop == Ops.CONST:
|
||||
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
|
||||
elif uop == Ops.LOAD:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
|
||||
if args.valid.min == 0:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
|
||||
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
|
||||
if args.valid.max == 1:
|
||||
pred = args.valid.render(lang.render_ops, lang)
|
||||
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
|
||||
if args.valid.max == 1:
|
||||
# NOTE: you can't compute the index in here, because it assumes it's all available later
|
||||
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
|
||||
skipload_branch += 1
|
||||
elif uop == UOps.STORE:
|
||||
elif uop == Ops.STORE:
|
||||
if args is None:
|
||||
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
||||
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
|
||||
else:
|
||||
idx, treg, off = lang.addr_w_offset(args)
|
||||
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
|
||||
|
||||
if DEBUG >= 4:
|
||||
for tins in lang.ins: print(tins)
|
||||
|
||||
@@ -3,7 +3,7 @@ from platform import system
 from typing import Tuple, Dict, List, Optional
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad.helpers import CI
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
@@ -19,7 +19,7 @@ class ARM64Language(AssemblyLanguage): pass

 def specialize_to_arm64(fn_nm, asm):
 var_size = 16
-prev_uop:Optional[UOps] = None
+prev_uop:Optional[Ops] = None
 ins = []
 x_regs = ['x' + str(i) for i in reversed(range(12))]
 s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
@@ -81,7 +81,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
ins.append(f"mov x15, {mem_vars[v.nm]}")
|
||||
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
|
||||
|
||||
if uop == UOps.SPECIAL:
|
||||
if uop == Ops.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
# data 8 to n into the stack
|
||||
if int(arg[4:]) >= 8:
|
||||
@@ -90,7 +90,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
else:
|
||||
ins.append(f"mov {rtor[out.nm]}, #0")
|
||||
ins.append(f"loop_{arg}:")
|
||||
elif uop == UOps.CAST:
|
||||
elif uop == Ops.CAST:
|
||||
if arg == BinaryOps.CMPLT:
|
||||
if rtor[out.nm][0] == 's':
|
||||
mov_imm(0.0, 's0')
|
||||
@@ -102,7 +102,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
|
||||
else:
|
||||
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
|
||||
elif uop == UOps.ALU:
|
||||
elif uop == Ops.ALU:
|
||||
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
@@ -136,7 +136,7 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
|
||||
else:
|
||||
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop == Ops.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
mov_imm(arg, rtor[out.nm])
|
||||
else:
|
||||
@@ -146,20 +146,20 @@ def specialize_to_arm64(fn_nm, asm):
|
||||
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
|
||||
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
|
||||
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
|
||||
elif uop == UOps.STORE:
|
||||
elif uop == Ops.STORE:
|
||||
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
|
||||
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
|
||||
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
|
||||
ins.append(f"mov x15, #{arg[0]}")
|
||||
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
elif uop == Ops.COND_BRANCH:
|
||||
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
|
||||
if prev_uop == UOps.LOAD:
|
||||
if prev_uop == Ops.LOAD:
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
|
||||
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
|
||||
elif uop == UOps.LABEL:
|
||||
elif uop == Ops.LABEL:
|
||||
ins.append(f"{arg[1:]}:")
|
||||
elif uop == UOps.ENDLOOP:
|
||||
elif uop == Ops.ENDLOOP:
|
||||
mov_imm(arg[0], "x15")
|
||||
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
|
||||
ins.append(f"cmp {rtor[vin[0].nm]}, x15")
|
||||
|
||||
@@ -1,7 +1,7 @@
 from typing import List
 import struct
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_cuda import arch
@@ -37,11 +37,11 @@ def specialize_to_ptx(lang, function_name):
|
||||
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
|
||||
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
|
||||
for uop, out, vin, arg in lang.ins:
|
||||
if uop == UOps.ENDLOOP:
|
||||
if uop == Ops.ENDLOOP:
|
||||
ins.append("bar.sync 0;")
|
||||
elif uop == UOps.DEFINE_LOCAL:
|
||||
elif uop == Ops.DEFINE_LOCAL:
|
||||
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
|
||||
elif uop == UOps.SPECIAL:
|
||||
elif uop == Ops.SPECIAL:
|
||||
if arg.startswith('data'):
|
||||
param_cnt += 1
|
||||
ins.append(f"ld.param.u64 {out}, [{arg}];")
|
||||
@@ -51,7 +51,7 @@ def specialize_to_ptx(lang, function_name):
|
||||
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
|
||||
elif arg.startswith('lid'):
|
||||
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
|
||||
elif uop == UOps.ALU:
|
||||
elif uop == Ops.ALU:
|
||||
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
|
||||
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
|
||||
else:
|
||||
@@ -64,7 +64,7 @@ def specialize_to_ptx(lang, function_name):
|
||||
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
|
||||
vin = vin[1:] + [reg]
|
||||
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop == Ops.LOAD:
|
||||
if arg.__class__ in (int, float):
|
||||
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
|
||||
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
|
||||
@@ -74,7 +74,7 @@ def specialize_to_ptx(lang, function_name):
|
||||
render_cast(ins, reg, out)
|
||||
else:
|
||||
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
|
||||
elif uop == UOps.STORE:
|
||||
elif uop == Ops.STORE:
|
||||
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
|
||||
if arg[2] == dtypes.bool != vin[1].dtype:
|
||||
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
|
||||
@@ -85,11 +85,11 @@ def specialize_to_ptx(lang, function_name):
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
|
||||
else:
|
||||
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
|
||||
elif uop == UOps.CAST:
|
||||
elif uop == Ops.CAST:
|
||||
render_cast(ins, vin[0], out)
|
||||
elif uop == UOps.LABEL:
|
||||
elif uop == Ops.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
elif uop == Ops.COND_BRANCH:
|
||||
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
|
||||
|
||||
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
|
||||
|
||||
@@ -2,7 +2,7 @@ import yaml
 from typing import Tuple, Set, Dict
 from tinygrad import dtypes
 from tinygrad.codegen.assembly import AssemblyCodegen, Register
-from tinygrad.codegen.kernel import UOps
+from tinygrad.codegen.kernel import Ops
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
@@ -61,7 +61,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
def reg_out(x):
|
||||
return rtor[x]
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if uop == Ops.DEFINE_REGISTER:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
@@ -86,7 +86,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
||||
else:
|
||||
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
||||
elif uop == UOps.SPECIAL:
|
||||
elif uop == Ops.SPECIAL:
|
||||
if arg.startswith('buf'):
|
||||
i = int(arg[3:])
|
||||
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
|
||||
@@ -106,7 +106,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
pend_regs.clear()
|
||||
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
|
||||
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
|
||||
elif uop == UOps.CONST:
|
||||
elif uop == Ops.CONST:
|
||||
if arg == float('inf'): arg = "0x7f800000"
|
||||
elif arg == float('-inf'): arg = "0xff800000"
|
||||
if out.dtype == dtypes.float.vec(4):
|
||||
@@ -114,7 +114,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
||||
elif uop == UOps.ALU:
|
||||
elif uop == Ops.ALU:
|
||||
if arg in [BinaryOps.CMPLT]:
|
||||
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
else:
|
||||
@@ -127,7 +127,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
||||
else:
|
||||
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop == Ops.LOAD:
|
||||
if out.scalar:
|
||||
# swap arg order
|
||||
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
||||
@@ -135,13 +135,13 @@ class RDNACodegen(AssemblyCodegen):
|
||||
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
pend_regs.add(out)
|
||||
for r in out.subregs(): pend_regs.add(r)
|
||||
elif uop == UOps.STORE:
|
||||
elif uop == Ops.STORE:
|
||||
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
||||
elif uop == UOps.LABEL:
|
||||
elif uop == Ops.LABEL:
|
||||
ins.append(f"{arg}:")
|
||||
elif uop == UOps.COND_BRANCH:
|
||||
elif uop == Ops.COND_BRANCH:
|
||||
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
|
||||
elif uop == UOps.CAST:
|
||||
elif uop == Ops.CAST:
|
||||
if vin[0].dtype == dtypes.bool:
|
||||
if out.dtype == dtypes.float32:
|
||||
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, List, Final, Callable, DefaultDict
 from collections import defaultdict
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
 from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
-from tinygrad.codegen.kernel import UOp, UOps
+from tinygrad.codegen.kernel import UOp, Ops
 from triton.compiler import compile as triton_compile
 import linecache
 import math
@@ -75,32 +75,32 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop == UOps.LOOP:
|
||||
if uop == Ops.LOOP:
|
||||
kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
|
||||
depth += 1
|
||||
elif uop == UOps.END: depth -= 1
|
||||
elif uop == UOps.ALU:
|
||||
elif uop == Ops.END: depth -= 1
|
||||
elif uop == Ops.ALU:
|
||||
assert dtype is not None
|
||||
val = code_for_op[args](*[r[x] for x in vin])
|
||||
if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
|
||||
else: kk(f"{ssa(u, 'alu')} = ({val})")
|
||||
elif uop == UOps.LOAD:
|
||||
elif uop == Ops.LOAD:
|
||||
assert dtype is not None
|
||||
if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
|
||||
else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
|
||||
elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
|
||||
elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
|
||||
elif uop == UOps.ASSIGN:
|
||||
elif uop == Ops.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
|
||||
elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
|
||||
elif uop == Ops.ASSIGN:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop == UOps.STORE:
|
||||
elif uop == Ops.STORE:
|
||||
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
|
||||
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
elif uop == Ops.DEFINE_GLOBAL:
|
||||
bufs.append(args)
|
||||
signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
|
||||
r[u] = args
|
||||
elif uop == UOps.SPECIAL:
|
||||
elif uop == Ops.SPECIAL:
|
||||
dims.append(args[1])
|
||||
valid.append(f"{args[1]}<{get_max(args[2])}")
|
||||
if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
|
||||
@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
||||
kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
|
||||
local_size.append(args[2])
|
||||
r[u] = args[1]
|
||||
elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
|
||||
elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
|
||||
else: raise NotImplementedError(f"unimplemented: {uop}")
|
||||
|
||||
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
|
||||
|
||||
@@ -2,11 +2,12 @@
 from typing import Tuple
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, UOps, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
+from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 inf, nan = float('inf'), float('nan')
+UOps = Ops

 # kernel unpacker
 from tinygrad.codegen.kernel import Kernel
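The UOps = Ops alias added above presumably exists so that pre-rename AST strings, as this module evaluates them (see ast_str_to_ast), keep resolving without rewriting the datasets. A hedged sketch of the idea (the string literal is hypothetical):

from tinygrad.ops import Ops, UOp
from tinygrad.dtype import dtypes
UOps = Ops  # same backwards-compat alias as in the hunk above
ast = eval("UOp(UOps.SINK, dtypes.void, arg=None, src=())")  # old spelling resolves via the alias
assert ast.op is Ops.SINK
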
@@ -4,7 +4,7 @@ from collections import defaultdict
 from typing import List, Tuple, DefaultDict
 from extra.optimization.helpers import load_worlds, ast_str_to_ast
 from tinygrad.helpers import prod, tqdm
-from tinygrad.ops import UOp, UOps
+from tinygrad.ops import UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.ops import sym_infer, Node
@@ -136,7 +136,7 @@ def test_rebuild(st: ShapeTracker):
 assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"

 def test_rebuild_bufferop_st(ast:UOp):
-if ast.op is UOps.SHAPETRACKER:
+if ast.op is Ops.SHAPETRACKER:
 test_rebuild(ast.arg)
 for src in ast.src: test_rebuild_bufferop_st(src)
test/external/external_benchmark_schedule.py (vendored, 4 changes)
@@ -2,7 +2,7 @@ from typing import List
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor, Device, nn
 from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
 from tinygrad.codegen.linearize import linearize_uop
@@ -27,7 +27,7 @@ if __name__ == "__main__":
 sched = out.schedule()

 if not SCHEDULE_ONLY:
-asts = list({x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values())
+asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
 if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
 kernels: List[Kernel] = []
 with Timing(f"***** model opts({len(asts):2d}) in "):
test/external/external_test_valid_remove.py (vendored, 140 changes)
@@ -2,7 +2,7 @@
 import unittest

 from tinygrad import Device
-from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -13,40 +13,40 @@ class TestOpenpilotValidhack(unittest.TestCase):
|
||||
def test_valid_removal(self):
|
||||
Device.DEFAULT = "GPU"
|
||||
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
x19:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
|
||||
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
|
||||
x20,)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
x5,
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.4426950408889634, src=(
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
|
||||
x20,)),)),)),
|
||||
x29:=UOp(UOps.CONST, dtypes.float, arg=-1.0, src=(
|
||||
x29:=UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
|
||||
x20,)),)),)),
|
||||
x19,)),
|
||||
x29,)),)),)),))
|
||||
@@ -62,51 +62,51 @@ class TestOpenpilotValidhack(unittest.TestCase):
|
||||
def test_const_idx(self):
|
||||
Device.DEFAULT = "GPU"
|
||||
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
x18,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
|
||||
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
|
||||
kernel = Kernel(ast)
|
||||
|
||||
test/external/fuzz_linearizer.py (vendored, 4 changes)
@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
-from tinygrad.ops import UnaryOps, UOp, UOps
+from tinygrad.ops import UnaryOps, UOp, Ops
 from test.helpers import is_dtype_supported

 def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
 def _is_simple(lin: Kernel) -> bool:
 if len(lin.ast.src) > 1: return False
 ast:UOp = lin.ast.src[0]
-if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True
+if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
 return False

 if __name__ == "__main__":
@@ -2,7 +2,7 @@ import sys, time, logging, difflib
 from typing import Callable, Optional, Tuple
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.ops import UOp, UOps, sint
+from tinygrad.ops import UOp, Ops, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
@@ -69,7 +69,7 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
 def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
 if st_src is None:
 st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
-return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))

 def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
 st = time.perf_counter_ns()
@@ -77,7 +77,7 @@ def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
 return ret, (time.perf_counter_ns()-st)*1e-6

 def eval_uop(uop:UOp):
-g = UOp(UOps.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
+g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
 rw = full_graph_rewrite(UOp.store(g, UOp.const(dtypes.int, 0), uop).sink(), PythonRenderer)
 prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render("run", linearize_uop(rw))))
 buf = PythonAllocator().alloc(uop.dtype.itemsize)
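A hedged usage sketch for eval_uop: tests can evaluate a single UOp expression on the Python backend, assuming the helper (whose tail is cut off above) copies the stored scalar out of buf and returns it:

from tinygrad import dtypes
from tinygrad.ops import UOp
x = UOp.const(dtypes.int, 3) * UOp.const(dtypes.int, 7)  # UOps compose with Python operators
assert eval_uop(x) == 21
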
@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.helpers import CI
 import numpy as np
@@ -9,7 +9,7 @@ from test.helpers import is_dtype_supported
 def _check_ast_count(desired_count:int, t:Tensor):
 # NOTE: this has side effect because everything can be scheduled only once
 schedule = create_schedule(t.lazydata.lbs)
-asts = [s for s in schedule if s.ast.op is UOps.SINK]
+asts = [s for s in schedule if s.ast.op is Ops.SINK]
 assert len(asts) == desired_count

 class TestUnaryOpsConstFolding(unittest.TestCase):
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
 from tinygrad.engine.schedule import create_schedule
@@ -15,9 +15,9 @@ class TestConvShapetracker(unittest.TestCase):
 # first run to init the weights, they are scheduled.
 create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
 # run it again to get the kernels
-sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK]
+sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
 assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
-for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]:
+for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]:
 assert len(st.views) == 1

 def test_conv_2x2_backward_one_view(self):
@@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase):
 conv(X).mean().backward()
 si = X.grad.schedule()[-1]
 print(si)
-ldb = [x for x in si.ast.parents if x.op is UOps.LOAD][0]
+ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0]
 st: ShapeTracker = ldb.st_arg.simplify()
 # NOTE: st.real_size() is broken
 print(si.inputs[0].size)
@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import UnaryOps, UOps
+from tinygrad.ops import UnaryOps, Ops
 from tinygrad.tensor import _to_np_dtype
 from test.helpers import is_dtype_supported
 import pytest
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
 np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
 else: np.testing.assert_equal(tensor_value, numpy_value)
 if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
-op = [x for x in ast.parents if x.op is UOps.ALU and x.arg in UnaryOps][0]
+op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0]
 assert op.dtype == dtype

 def universal_test_cast(a, in_dtype, dtype):
@@ -2,7 +2,7 @@
 import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.engine.lazy import LazyBuffer, MetaOps
 from tinygrad.engine.schedule import create_schedule
@@ -75,7 +75,7 @@ class TestReduceOp(unittest.TestCase):
 a = a.sum()
 sched = create_schedule([a.lazydata])
 assert len(sched) == 1
-self.assertIs(sched[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
+self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)

 def test_split_reduce_kernel_dim0(self):
 a = Tensor.rand(256, 255).realize()
@@ -83,7 +83,7 @@ class TestReduceOp(unittest.TestCase):
|
||||
sched = create_schedule([a.lazydata])
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
def test_split_reduce_kernel_dim1(self):
|
||||
a = Tensor.rand(255, 256).realize()
|
||||
@@ -91,7 +91,7 @@ class TestReduceOp(unittest.TestCase):
|
||||
sched = create_schedule([a.lazydata])
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
|
||||
File diff suppressed because it is too large
@@ -5,7 +5,7 @@
import unittest
from test.helpers import ast_const
from tinygrad import Device, dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@@ -14,26 +14,26 @@ from tinygrad.codegen.kernel import Kernel
class TestLinearizerDumb(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
def test_unmerged_ifs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
|
||||
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
ast_const(dtypes.half, 0.9999950000374996, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast_const(dtypes.half, 0.0, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
|
||||
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)]
k = Kernel(ast, opts=Device["METAL"].renderer)
k.required_optimizations()
@@ -43,88 +43,88 @@ class TestLinearizerDumb(unittest.TestCase):
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
assert len([u for u in k.uops if u.op is Ops.IF]) == 1, "must have a single IF"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_max_simplify_and_cancel(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.CAST, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast_const(dtypes.bool, True, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, 1000, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is UOps.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_expander_new_srcs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is UOps.IF]
if_uops = [u for u in k.uops if u.op is Ops.IF]
self.assertIn(len(if_uops), {1,3})
conditions = if_uops[0].src[0].sparents
self.assertLessEqual(len(conditions), 9)

# this was a bug in embedding, someday we should fold this anyway
def test_llama_embedding(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
ast_const(dtypes.int, 1, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
|
||||
ast_const(dtypes.int, -1, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast_const(dtypes.bool, True, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
prg = k.to_program()
print(prg.src)
@@ -132,88 +132,88 @@ class TestLinearizerDumb(unittest.TestCase):
# from process replay https://github.com/tinygrad/tinygrad/actions/runs/10389229290/job/28766762085#step:18:6490
@unittest.expectedFailure
def test_unaligns_idxs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.CAST, dtypes.long, arg=None, src=(
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.CAST, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
ast_const(dtypes.bool, True, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 3]
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3]
assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"

@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
def test_unrolled_float4_align(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
ast_const(dtypes.long, -1, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast_const(dtypes.bool, True, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
ast_const(dtypes.float, 0.0, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 2]
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"

@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
@unittest.skipIf(getenv("PTX"), "this is somehow correct in PTX")
def test_upcasted_stores_out_of_order(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
store_idxs = [x.src[1] for x in k.uops if x.op is UOps.STORE]
store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]
for i in range(len(store_idxs) - 1):
first_bounds = store_idxs[i].vmin+store_idxs[i].vmax
next_bounds = store_idxs[i+1].vmin+store_idxs[i+1].vmax

File diff suppressed because it is too large
@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin

# stuff needed to unpack a kernel
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

@@ -24,143 +24,143 @@ def _test_overflow(ast, opts):
@unittest.skip("unneeded without launch bounds")
class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_1(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||
x16:=ast_const(dtypes.float, 0.0, st_src=(
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
x23:=ast_const(dtypes.float, 1.0, st_src=(
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
x23,
|
||||
ast_const(dtypes.float, 1e-05, st_src=(
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
x16,)),)),))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
_test_overflow(ast, opts)

# From BEAM on hlb_cifar.py
def test_overflow_2(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
_test_overflow(ast, opts)

# from BEAM on default simple_conv.py (which is quite large):
def test_overflow_3(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts)

# from BEAM on BS=4 simple_conv.py:
def test_overflow_4(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts)

# from BEAM on BS=2 simple_conv.py:
def test_overflow_5(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts)

# from BEAM on BS=3 simple_conv.py:
def test_overflow_6(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
_test_overflow(ast, opts)

# from BEAM on BS=3 simple_conv.py: (alt)
def test_overflow_7(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||
UOp(UOps.STORE, None, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
ast = UOp(Ops.SINK, None, arg=None, src=(
|
||||
UOp(Ops.STORE, None, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts)

@@ -169,26 +169,26 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||
class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
def test_overflow_1(self):
|
||||
BS = 2
|
||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(UOps.STORE, src=(g0, ot_st, UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
|
||||
ast = UOp(UOps.SINK, src=(store,))
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
|
||||
ast = UOp(Ops.SINK, src=(store,))
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
_test_overflow(ast, opts)
|
||||
def test_overflow_2(self):
|
||||
BS = 2
|
||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
|
||||
prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(UOps.STORE, src=(g0, ot_st, UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
|
||||
ast = UOp(UOps.SINK, src=(store,))
|
||||
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
|
||||
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)]
_test_overflow(ast, opts)
@@ -1,7 +1,7 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps
from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
@@ -617,24 +617,24 @@ class TestMultiTensor(unittest.TestCase):
t = t + 1
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.MUL
assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is UOps.LOAD
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is Ops.LOAD
t = t + t.full_like(3)
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.helpers import CI, Context
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
@@ -517,7 +517,7 @@ class TestNN(unittest.TestCase):
[12, 19, 8, 1]])
result = layer(a)
schedule = create_schedule([result.lazydata])
self.assertEqual(3, len([item for item in schedule if item.ast.op is UOps.SINK]), "first run realizes arange, weight, and embedding")
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
run_schedule(schedule)
b = Tensor([[1, 2, 3],
@@ -525,7 +525,7 @@ class TestNN(unittest.TestCase):
[7, 8, 9]])
result = layer(b)
schedule = create_schedule([result.lazydata])
self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
run_schedule(schedule)
def test_embedding_shape(self):
@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, UOps
from tinygrad.ops import BinaryOps, UOp, Ops
from tinygrad.renderer import Program
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.lazy import LazyBuffer
@@ -20,7 +20,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render("test", uops)
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
@@ -30,13 +30,13 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
class TestCStyleFailures(unittest.TestCase):
def test_inline_const_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
ld = UOp(Ops.LOAD, dtypes.int, (b, idx))
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
store = UOp.store(a, idx, alu)
sink = UOp(UOps.SINK, dtypes.void, (store,))
sink = UOp(Ops.SINK, dtypes.void, (store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
# CLANG doesn't use the max function
ret = _test_uop_result([Tensor([1])], uops)[0]
@@ -45,21 +45,21 @@ class TestCStyleFailures(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
class TestPTXFailures(unittest.TestCase):
def test_gated_store_with_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
def test_gated_store_with_if(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
val = UOp.const(dtypes.int, 1)
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, val, if_uop))
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, val, if_uop))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
@@ -29,7 +29,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize: pre.schedule()
|
||||
sched = create_schedule(outs)
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is UOps.SINK]
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
for i, s in enumerate(sched):
|
||||
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
|
||||
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
# test the (sink) ops linearize
|
||||
for s in sched:
|
||||
if s.ast.op is not UOps.SINK: continue
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
l = Kernel(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
@@ -58,7 +58,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
dtypes.default_float = old_default_float
|
||||
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
|
||||
run_schedule(s.copy())
|
||||
cnt = len([si for si in s if si.ast.op is UOps.SINK])
|
||||
cnt = len([si for si in s if si.ast.op is Ops.SINK])
|
||||
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
@@ -191,7 +191,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r1 = (x - r0).sum(axis=0).div(2)
|
||||
out = r0 + r1
|
||||
schedule = check_schedule(out, 2)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_cache_reduce_multiple_children(self):
|
||||
@@ -202,7 +202,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out0 = r0 + y
|
||||
out1 = r1 + y
|
||||
schedule = check_schedule([out0, out1], 4)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
@@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = (a.sum(0) + a.max(1)) + 2
|
||||
schedule = check_schedule(b, 2)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_midreduce_nochase(self):
|
||||
@@ -1117,7 +1117,7 @@ class TestSchedule(unittest.TestCase):
|
||||
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
|
||||
# schedule = check_schedule(b, 2)
|
||||
schedule = check_schedule(b, 4)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -1352,7 +1352,7 @@ class TestIndexing(unittest.TestCase):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
lst = [xt] if isinstance(xt, Tensor) else xt
|
||||
s = Tensor.schedule(*lst)
|
||||
kernels = [si for si in s if si.ast.op is UOps.SINK]
|
||||
kernels = [si for si in s if si.ast.op is Ops.SINK]
|
||||
for si in kernels: verify_ast(si.ast)
|
||||
run_schedule(s)
|
||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
||||
@@ -1607,20 +1607,20 @@ class TestIndexing(unittest.TestCase):
|
||||
self.assertLess(et, 1200)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
self.assertEqual(rsink.key, sink.key)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
@@ -1628,21 +1628,21 @@ class TestIndexing(unittest.TestCase):
|
||||
verify_ast(rsink)
|
||||
|
||||
def test_no_reshape_reduceop(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
verify_ast(sink)
|
||||
self.assertEqual(sink.key, rsink.key)
|
||||
|
||||
def test_reshape_many(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
@@ -1656,11 +1656,11 @@ class TestIndexing(unittest.TestCase):
|
||||
sizes = [10*(i+1) for i in range(SZ)]
|
||||
tms: List[float] = []
|
||||
for sz in sizes:
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(rsink)
|
||||
@@ -1676,20 +1676,20 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
def test_swizzle_rewrite(self):
|
||||
# graph rewrite
|
||||
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
x8,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
@@ -1705,13 +1705,13 @@ class TestIndexing(unittest.TestCase):
|
||||
a = Tensor.randint(4,).realize()
|
||||
expected_out = a.numpy().sum(0)+1
|
||||
# LazyBuffer to pre-rewrite AST
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
||||
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
||||
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
||||
const = ast_const(dtypes.int, 1, ())
|
||||
alu = swizzle_r+const
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
@@ -1728,13 +1728,13 @@ class TestIndexing(unittest.TestCase):
|
||||
b = Tensor.randint(4,).realize()
|
||||
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
|
||||
# LazyBuffer to pre-rewrite AST
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r2 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
|
||||
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
||||
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
|
||||
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
@@ -1745,51 +1745,51 @@ class TestIndexing(unittest.TestCase):
|
||||
np.testing.assert_equal(c.numpy(), expected_out)
|
||||
|
||||
def test_swizzle_rewrite_alt(self):
|
||||
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
# there's an EXPAND pushing through the REDUCE_AXIS
|
||||
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
||||
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
|
||||
# EXPAND is rewritten
|
||||
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
|
||||
# and pushed to the LOAD
|
||||
new_load_st = unwrap([x for x in ret.parents if x.op is UOps.VIEW][0].st)
|
||||
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
|
||||
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
||||
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
||||
|
||||
def test_permute_rewrite(self):
|
||||
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
x1,
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2,)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11,)),)),)),)),))
|
||||
@track_rewrites()
|
||||
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
ret = rewrite(sink)
|
||||
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||
assert len([x for x in ret.sparents if x.op is Ops.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||
|
||||
if __name__ == '__main__':
unittest.main(verbosity=2)
@@ -3,7 +3,7 @@ import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -16,15 +16,15 @@ from tinygrad.shape.view import View
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0]
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is UOps.LOAD}
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0]
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
|
||||
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(lin.membufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
@@ -34,7 +34,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
def test_bufs_from_lin_alt(self):
|
||||
a = Tensor.randn(4, 4).realize()
|
||||
b = a+a[0]
|
||||
si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0]
|
||||
si = [si for si in b.schedule() if si.ast.op is Ops.SINK][0]
|
||||
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(k.membufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
@@ -46,12 +46,12 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
Ensure that the kernel count is not incremented by time_linearizer when clearing l2
|
||||
"""
|
||||
# ast of Tensor.zeros(16).contiguous().realize()
|
||||
ast = UOp(UOps.SINK, src=(
|
||||
UOp(UOps.STORE, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
|
||||
ast = UOp(Ops.SINK, src=(
|
||||
UOp(Ops.STORE, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
|
||||
ast_const(dtypes.float, 0.0, st_src=(
|
||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
|
||||
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
|
||||
lin = Kernel(ast)
|
||||
bufs = bufs_from_lin(lin)
|
||||
|
||||
@@ -103,37 +103,37 @@ class TestBEAM(unittest.TestCase):
|
||||
|
||||
def test_filter_global_buffer(self):
|
||||
# taken from https://github.com/tinygrad/tinygrad/issues/4612
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
ast_const(dtypes.float, 1.4285714285714286, st_src=(
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
|
||||
lin = Kernel(ast)
|
||||
|
||||
bufs = bufs_from_lin(lin)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List
|
||||
import unittest, time
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
|
||||
from tinygrad.ops import UPat, PatternMatcher
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
||||
@@ -24,30 +24,30 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||
c1 = UOp.const(dtypes.int, 1)
|
||||
c2 = UOp.const(dtypes.int, 2)
|
||||
st = time.perf_counter()
|
||||
uops = [UOp(UOps.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
|
||||
uops = [UOp(Ops.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
|
||||
et = time.perf_counter() - st
|
||||
print(f"created {len(uops)} uops in {et*1000:.2f} ms")
|
||||
|
||||
def test_expand_rewrite(self):
|
||||
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
|
||||
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
|
||||
offset=0, mask=None, contiguous=False),)), src=()),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
|
||||
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
|
||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
|
||||
mask=None, contiguous=False))), src=()),)),
|
||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
|
||||
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
|
||||
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
|
||||
@@ -82,7 +82,7 @@ class TestGraphRewriteConst(unittest.TestCase):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
|
||||
ret = graph_rewrite(v1+v2, sym)
|
||||
self.assertEqual(ret.op, UOps.VCONST)
|
||||
self.assertEqual(ret.op, Ops.VCONST)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, (5,7,9))
|
||||
|
||||
@@ -90,31 +90,31 @@ class TestGraphRewriteConst(unittest.TestCase):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
|
||||
ret = graph_rewrite(v1+v2, sym)
|
||||
self.assertEqual(ret.op, UOps.CONST)
|
||||
self.assertEqual(ret.op, Ops.CONST)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, 2)
|
||||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
v1 = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
v2 = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
nout = graph_rewrite(v1+v2, PatternMatcher([]))
|
||||
self.assertIs(nout.src[0], nout.src[1])
|
||||
|
||||
# NOTE: this shows why we can't have a UOp in arg
|
||||
@unittest.expectedFailure
|
||||
def test_no_dedup_args(self):
|
||||
a1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
|
||||
sink = a1.sink(a2)
|
||||
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is UOps.DEFINE_VAR]
|
||||
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR]
|
||||
self.assertEqual(len(define_vars), 1)
|
||||
|
||||
def test_simple(self):
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
nout = graph_rewrite(c1+c2, simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 3.0)
|
||||
|
||||
def test_depth_2_late(self):
|
||||
@@ -122,7 +122,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
c3 = UOp.const(dtypes.float, 3.0)
|
||||
nout = graph_rewrite(c1*c2*(c3+c3), simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 12.0)
|
||||
|
||||
def test_double(self):
|
||||
@@ -130,7 +130,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
c3 = UOp.const(dtypes.float, 3.0)
|
||||
nout = graph_rewrite(c1+c2+c3, simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 6.0)
|
||||
|
||||
def test_triple(self):
|
||||
@@ -139,7 +139,7 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
c3 = UOp.const(dtypes.float, 3.0)
|
||||
c4 = UOp.const(dtypes.float, 4.0)
|
||||
nout = graph_rewrite(c1+c2+c3+c4, simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 10.0)
|
||||
|
||||
def test_diamond(self):
|
||||
@@ -147,23 +147,23 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
c3 = UOp.const(dtypes.float, 3.0)
|
||||
nout = graph_rewrite((c1+c2)+(c1+c3), simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 7.0)
|
||||
|
||||
def test_magic_4(self):
|
||||
c1 = UOp.const(dtypes.int, 4.0)
|
||||
nout = graph_rewrite(c1, simple_pm)
|
||||
self.assertEqual(nout.op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.CONST)
|
||||
self.assertEqual(nout.arg, 3.0)
|
||||
|
||||
def test_depth_2_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
v = UOp(Ops.DEFINE_VAR, dtypes.float)
|
||||
c1 = UOp.const(dtypes.float, 1.0)
|
||||
c2 = UOp.const(dtypes.float, 2.0)
|
||||
nout = graph_rewrite(v+c1+c2, simple_pm)
|
||||
self.assertEqual(nout.op, UOps.ALU)
|
||||
self.assertEqual(nout.src[0].op, UOps.DEFINE_VAR)
|
||||
self.assertEqual(nout.src[1].op, UOps.CONST)
|
||||
self.assertEqual(nout.op, Ops.ALU)
|
||||
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
|
||||
self.assertEqual(nout.src[1].op, Ops.CONST)
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_commutative_work(self):
|
||||
@@ -182,77 +182,77 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
b = UOp.variable('b', 0, 1)
|
||||
c = UOp.variable('c', 0, 1)
|
||||
d = UOp.variable('d', 0, 1)
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, sym)
|
||||
print(sink.render())
|
||||
self.assertEqual(sink.op, UOps.ALU)
|
||||
self.assertEqual(sink.src[1].op, UOps.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
|
||||
self.assertEqual(sink.op, Ops.ALU)
|
||||
self.assertEqual(sink.src[1].op, Ops.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
|
||||
|
||||
class TestUOpGraph(unittest.TestCase):
|
||||
def test_add_constant_fold(self):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(Ops.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
def test_where_same_fold(self):
|
||||
v = UOp.variable('tmp', 0, 1)
|
||||
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||
c0 = UOp(Ops.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(Ops.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(Ops.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 1.0)
|
||||
|
||||
def test_where_const_fold(self):
|
||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(Ops.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 2.0)
|
||||
|
||||
def test_const_cast(self):
|
||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
out = UOp(UOps.CAST, dtypes.int, (bf,))
|
||||
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
out = UOp(Ops.CAST, dtypes.int, (bf,))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 1)
|
||||
out = uops[-1]
|
||||
self.assertEqual(out.op, UOps.CONST)
|
||||
self.assertEqual(out.op, Ops.CONST)
|
||||
self.assertEqual(out.arg, 0)
|
||||
|
||||
@unittest.skip("this test isn't valid uops")
|
||||
def test_noop_vectorize_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
|
||||
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
|
||||
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
|
||||
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,))
|
||||
x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
|
||||
alu = UOp(Ops.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
|
||||
self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0)
|
||||
|
||||
def test_gep_vec_fold(self):
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
|
||||
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
d2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
def _test_vec(geps, count=4):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
||||
out = UOp(UOps.STORE, dtypes.void, (d0, idx, vec))
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(count), geps)
|
||||
out = UOp(Ops.STORE, dtypes.void, (d0, idx, vec))
|
||||
uops = to_uops_list([out])
|
||||
if DEBUG >= 4:
|
||||
from tinygrad import Device
|
||||
@@ -260,53 +260,53 @@ class TestUOpGraph(unittest.TestCase):
|
||||
return uops[-1].src[-1]
|
||||
|
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in range(4))
self.assertIs(_test_vec(xyzw).op, UOps.LOAD)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4))
self.assertIs(_test_vec(xyzw).op, Ops.LOAD)

# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
wzyx = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
self.assertIs(_test_vec(wzyx).op, UOps.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
self.assertIs(_test_vec(wzyx).op, Ops.VECTORIZE)

# different_size
val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy+xy).op, Ops.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy, count=2).op, Ops.VECTORIZE)

# different vals
val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
val2 = UOp(UOps.LOAD, dtypes.float.vec(2), (d2, idx))
xy1 = tuple(UOp(UOps.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2, idx))
xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy1+xy2).op, Ops.VECTORIZE)

def test_gep_vec_const_fold(self):
for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
for uop, const in zip(uops, consts):
self.assertEqual(uop, const)

def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)

for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
@@ -314,109 +314,109 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("wmma is wrong here, it needs an arg")
def test_wmma_vectorize_no_fold(self):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)

for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)

for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)

for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)

def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
alu = ld.lt(1).cast(dtypes.bool)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)

def test_double_cast_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

def test_depth_2_const_fold(self):
v = UOp.variable("tmp", 0, 1)
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
c2 = UOp(Ops.CONST, dtypes.int, arg=2)
c4 = UOp(Ops.CONST, dtypes.int, arg=4)
vc = UOp(Ops.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(Ops.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = to_uops_list([out])
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.op, Ops.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
self.assertEqual(out.src[1].op, UOps.CONST)
self.assertEqual(out.src[1].op, Ops.CONST)
self.assertEqual(out.src[1].arg, 6)

def test_fold_gated_load(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))

||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
|
||||
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(Ops.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
|
||||
ld0 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
|
||||
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
|
||||
ld0 = uops[-1].src[-1]
|
||||
# the gate and invalid value are deleted from ld1
|
||||
self.assertEqual(ld0.src[0], smem.index(lidx+2))
|
||||
|
||||
def test_fold_gated_store(self):
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
idx0 = UOp.const(dtypes.int, 0)
|
||||
idx1 = UOp.const(dtypes.int, 0)
|
||||
val = UOp.const(dtypes.int, 42)
|
||||
st0 = UOp(UOps.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
|
||||
st1 = UOp(UOps.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
|
||||
st0 = UOp(Ops.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
|
||||
st1 = UOp(Ops.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
|
||||
uops = to_uops_list([st0, st1])
|
||||
# only the second store happens
|
||||
self.assertEqual(len(uops), 5)
|
||||
@@ -424,23 +424,23 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
@unittest.skip("this is a uop type error")
|
||||
def test_asserts_bad_gate(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
bad_gate = UOp.const(dtypes.int, 1)
|
||||
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||
|
||||
def test_switched_range_order(self):
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
c0 = UOp.const(dtypes.int, 0)
|
||||
c2 = UOp.const(dtypes.int, 2)
|
||||
cf = UOp.const(dtypes.float, 0.0)
|
||||
r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
|
||||
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
|
||||
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
|
||||
store = UOp(UOps.STORE, dtypes.void, (glbl, alu, cf))
|
||||
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
|
||||
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
|
||||
alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
|
||||
store = UOp(Ops.STORE, dtypes.void, (glbl, alu, cf))
|
||||
uops = to_uops_list([store])
|
||||
ranges = [x for x in uops if x.op is UOps.RANGE]
|
||||
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
|
||||
ranges = [x for x in uops if x.op is Ops.RANGE]
|
||||
endranges = [x for x in uops if x.op is Ops.ENDRANGE]
|
||||
# ranges are closed in the right order
|
||||
self.assertEqual(endranges[-1].src[0], ranges[0])
|
||||
|
||||
@@ -449,90 +449,90 @@ def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
|
||||
|
||||
class TestExpander(unittest.TestCase):
|
||||
def test_expand_add_broadcast(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+3)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
|
||||
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 4
|
||||
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
|
||||
|
||||
def test_contract_simple(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
self.assertEqual(sink.op, UOps.VCONST)
|
||||
self.assertEqual(sink.op, Ops.VCONST)
|
||||
self.assertTupleEqual(sink.arg, (0,1,2,3))
|
||||
|
||||
def test_contract_axis_1(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
|
||||
assert sink.src[0].op is Ops.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
|
||||
|
||||
def test_contract_axis_2(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is UOps.VCONST
|
||||
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
|
||||
assert sink.src[0].op is Ops.VCONST
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
|
||||
|
||||
def test_contract_axis_2_big(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
|
||||
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
|
||||
|
||||
def test_contract_multi_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
|
||||
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
|
||||
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
|
||||
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
|
||||
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
|
||||
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
|
||||
|
||||
def test_contract_mid(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
|
||||
assert sink.op is Ops.EXPAND and sink.arg == ((1,2),(3,2))
|
||||
assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
|
||||
|
||||
def test_contract_no_expand(self):
|
||||
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VECTORIZE and len(sink.src) == 2
|
||||
assert sink.op is Ops.VECTORIZE and len(sink.src) == 2
|
||||
assert sink.src[0] == sink.src[1]
|
||||
|
||||
def test_contract_half_expand(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
|
||||
sink = expander_rewrite(con)
|
||||
assert sink.op is UOps.VCONST and len(sink.arg) == 8
|
||||
assert sink.op is Ops.VCONST and len(sink.arg) == 8
|
||||
assert sink.arg[0] == sink.arg[1]
|
||||
assert sink.arg[0] != sink.arg[2]
|
||||
assert sink.arg[6] == sink.arg[7]
|
||||
|
||||
def test_expand_same_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
sink = expander_rewrite(e1+e2)
|
||||
self.assertEqual(sink.op, UOps.EXPAND)
|
||||
self.assertEqual(sink.src[0].op, UOps.VCONST)
|
||||
self.assertEqual(sink.op, Ops.EXPAND)
|
||||
self.assertEqual(sink.src[0].op, Ops.VCONST)
|
||||
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
|
||||
|
||||
def test_expand_different_axis(self, flip=False):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
|
||||
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
|
||||
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
|
||||
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16
|
||||
assert sink.arg == ((1, 4), (2, 4))
|
||||
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
|
||||
|
||||
@@ -540,47 +540,47 @@ class TestExpander(unittest.TestCase):
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_reduce_known_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = UOp(UOps.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
|
||||
sink = expander_rewrite(sink)
|
||||
assert sink.op is UOps.CONST
|
||||
assert sink.op is Ops.CONST
|
||||
self.assertEqual(sink.arg, 3*(0+1+2+3))
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_reduce_const(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = UOp(UOps.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
|
||||
sink = expander_rewrite(sink)
|
||||
assert sink.op is UOps.CONST
|
||||
assert sink.op is Ops.CONST
|
||||
self.assertEqual(sink.arg, 3*4)
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_double_expand(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
|
||||
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((1,2),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
|
||||
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((1,2),))
|
||||
sink = expander_rewrite(e)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 8
|
||||
assert sink.op is Ops.EXPAND and len(sink.src) == 8
|
||||
assert sink.arg == ((1, 2), (2, 4))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_double_expand_reverse(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
|
||||
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
|
||||
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
|
||||
sink = expander_rewrite(e)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 8
|
||||
assert sink.op is Ops.EXPAND and len(sink.src) == 8
|
||||
assert sink.arg == ((1, 4), (2, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
|
||||
|
||||
@unittest.skip("no longer supported")
|
||||
def test_double_expand_middle(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
|
||||
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
|
||||
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
|
||||
sink = expander_rewrite(e)
|
||||
assert sink.op is UOps.EXPAND and len(sink.src) == 8
|
||||
assert sink.op is Ops.EXPAND and len(sink.src) == 8
|
||||
assert sink.arg == ((1, 2), (2, 2), (3, 2))
|
||||
self.assertListEqual([x.arg for x in sink.src], [0, 1, 4, 5, 2, 3, 6, 7])
|
||||
|
||||
@@ -588,106 +588,106 @@ class TestExpander(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
@unittest.skip
|
||||
def test_reduce_different_axis(self):
|
||||
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
sink = UOp(UOps.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
|
||||
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
|
||||
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
|
||||
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
|
||||
sink = expander_rewrite(sink)
|
||||
print(sink)
|
||||
|
||||
class TestLoadStoreFolder(unittest.TestCase):
|
||||
def test_simple_load_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
|
||||
|
||||
def test_two_load_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2
|
||||
|
||||
def test_simple_load_fold_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp(Ops.DEFINE_VAR, dtypes.bool)
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
|
||||
single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0]
|
||||
self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
|
||||
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3
|
||||
|
||||
def test_simple_store_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
|
||||
one_store = [x for x in sink.sparents if x.op is Ops.STORE][0]
|
||||
assert len(one_store.src) == 3
|
||||
assert str(one_store.src[2]) == str(gate) # huh, why do i need str here?
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
|
||||
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3
|
||||
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
st = UOp(UOps.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
|
||||
lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
|
||||
store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (store,))
|
||||
st = UOp(Ops.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
|
||||
lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
|
||||
store = UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
self.assertEqual(len(st.src), 2)
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid&(lidx.ne(2))
|
||||
st = UOp(UOps.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
|
||||
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
|
||||
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
@@ -696,14 +696,14 @@ class TestIFUOps(unittest.TestCase):
|
||||
# this will be fixed with the merge gated stores bounty
|
||||
@unittest.expectedFailure
|
||||
def test_expand_ifs_dumb(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
|
||||
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
|
||||
sink = full_graph_rewrite(sink)
|
||||
if_uops = [u for u in sink.parents if u.op is UOps.IF]
|
||||
if_uops = [u for u in sink.parents if u.op is Ops.IF]
|
||||
self.assertEqual(len(if_uops), 1)
|
||||
self.assertEqual(if_uops[0].src[0], gate)
|
||||
for st in sink.src:
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
@@ -23,18 +23,18 @@ def _uops_to_prg(uops_list):
|
||||
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
|
||||
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
|
||||
|
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(src), arg))
|
||||
return uops[-1]
|
||||
|
||||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i], uop(uops, Ops.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
|
||||
prg = _uops_to_prg([out])
|
||||
@@ -46,10 +46,10 @@ def _test_single_value(vals, op, dts):
|
||||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
prg = _uops_to_prg([out])
|
||||
prg.exec([buf])
|
||||
@@ -59,9 +59,9 @@ def _test_single_value_const(vals, op, dts):
|
||||
|
||||
def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||
# res = output_fn(uops)
|
||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
prg = _uops_to_prg([out])
|
||||
prg.exec([buf])
|
||||
@@ -244,63 +244,63 @@ class TestConstantFolding(unittest.TestCase):
|
||||
si = create_schedule([t.lazydata])
|
||||
assert len(si) == 1
|
||||
ji = lower_schedule_item(si[-1])
|
||||
assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
|
||||
assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
|
||||
|
||||
class TestGatedStoreRewrite(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
store = UOp(UOps.STORE, dtypes.void, (gmem, idx, val, gate))
|
||||
store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate))
|
||||
uops = to_uops_list([store])
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is UOps.IF)
|
||||
endif = next(u for u in uops if u.op is UOps.ENDIF)
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
endif = next(u for u in uops if u.op is Ops.ENDIF)
|
||||
assert endif.src[0] is if_uop
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
|
||||
self.assertEqual(len(gated_uops), 1)
|
||||
self.assertIs(gated_uops[-1].op, UOps.STORE)
|
||||
self.assertIs(gated_uops[-1].op, Ops.STORE)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
|
||||
uops = linearize_uop(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.op is UOps.IF)
|
||||
endif = next(u for u in uops if u.op is UOps.ENDIF)
|
||||
if_uop = next(u for u in uops if u.op is Ops.IF)
|
||||
endif = next(u for u in uops if u.op is Ops.ENDIF)
|
||||
assert endif.src[0] is if_uop
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
|
||||
self.assertEqual(len(gated_uops), 1)
|
||||
self.assertIs(gated_uops[-1].op, UOps.STORE)
|
||||
self.assertIs(gated_uops[-1].op, Ops.STORE)
|
||||
|
||||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||
@unittest.expectedFailure
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
|
||||
uops = linearize_uop(stores)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
ifs = [u for u in uops if u.op is UOps.IF]
|
||||
endifs = [u for u in uops if u.op is UOps.ENDIF]
|
||||
ifs = [u for u in uops if u.op is Ops.IF]
|
||||
endifs = [u for u in uops if u.op is Ops.ENDIF]
|
||||
self.assertEqual(len(ifs), 1)
|
||||
self.assertEqual(len(endifs), 1)
|
||||
gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
|
||||
self.assertEqual(len(gated_uops), 2)
|
||||
for x in gated_uops: self.assertIs(x.op, UOps.STORE)
|
||||
for x in gated_uops: self.assertIs(x.op, Ops.STORE)
|
||||
|
||||
class TestLocalAccess(unittest.TestCase):
|
||||
# NOTE: this is failing on METAL CI, no idea why. Works locally.
|
||||
@@ -308,44 +308,44 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
|
||||
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
|
||||
st = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), barr))
|
||||
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
|
||||
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
|
||||
ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), barr))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
|
||||
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
|
||||
st1 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 2), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))
|
||||
ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), barr))
|
||||
sres = uop(uops, Ops.LOAD, dtypes.int32, (smem, ofs))
|
||||
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
|
||||
|
||||
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_bitshift_left(self):
|
||||
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
|
||||
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
|
||||
|
||||
def test_bitshift_right(self):
|
||||
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
|
||||
@@ -354,38 +354,38 @@ class TestAssembly(unittest.TestCase):
|
||||
class TestUOpMethod(unittest.TestCase):
|
||||
@unittest.skip("uops lt no longer ordered")
|
||||
def test_compare_alu_same_src_different_arg(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
|
||||
b = UOp(UOps.CONST, dtypes.float, (), 3.0)
|
||||
a = UOp(Ops.CONST, dtypes.float, (), 2.0)
|
||||
b = UOp(Ops.CONST, dtypes.float, (), 3.0)
|
||||
|
||||
add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
|
||||
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
|
||||
add = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.ADD)
|
||||
mul = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.MUL)
|
||||
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
|
||||
|
||||
def test_uop_variables(self):
|
||||
a = UOp.variable("a", 1, 10)
|
||||
uop_var = UOp.const(dtypes.int, a)
|
||||
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
|
||||
st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
|
||||
ShapeTracker.from_shape((2, a)).to_uop()))
|
||||
ast_vars = (st_var+uop_var).variables()
|
||||
self.assertEqual(len(ast_vars), 1)
|
||||
self.assertEqual(ast_vars[0], a)
|
||||
|
||||
def test_const_factor(self):
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 8))
|
||||
self.assertEqual(UOp(UOps.CONST, dtypes.int, (), 17).const_factor(), 17)
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))
|
||||
self.assertEqual(UOp(Ops.CONST, dtypes.int, (), 17).const_factor(), 17)
|
||||
self.assertEqual(gidx0.const_factor(), 1)
|
||||
self.assertEqual((gidx0*3).const_factor(), 3)
|
||||
self.assertEqual((gidx0*3+6).const_factor(), 3)
|
||||
self.assertEqual((gidx0*3+1).const_factor(), 1)
|
||||
|
||||
def test_replace(self):
|
||||
x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
x = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||
self.assertIs(x.replace(arg=None).arg, None)
|
||||
with self.assertRaises(AssertionError): x.replace(field="a")
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
def test_uop_str(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
|
||||
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
|
||||
for _ in range(20): a = a + a
|
||||
assert len(str(a)) < 10_000, "exponential string growth"
|
||||
assert str(eval(str(a))) == str(a)
|
||||
@@ -394,11 +394,11 @@ class TestUOpStr(unittest.TestCase):
|
||||
t = t + t * Tensor.rand(10)
|
||||
# nice big complicated uop
|
||||
with Context(NOOPT=1):
|
||||
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
|
||||
self.assertEqual(sink, eval(str(sink)))
|
||||
|
||||
def test_vectorized_str(self):
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
|
||||
vec = UOp(Ops.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
|
||||
assert str(eval(str(vec))) == str(vec)
|
||||
|
||||
@unittest.skip("uop no longer has order like this")
|
||||
@@ -406,23 +406,23 @@ class TestIndexingOrdering(unittest.TestCase):
# NOTE: these tests skip type_verify since they add dtype to STORE
@unittest.expectedFailure
def test_simple_order(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = to_uops_list([st1, st0], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
@unittest.expectedFailure
def test_ordering_multi_output(self):
buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
st0_0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_0 = UOp(Ops.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(Ops.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(Ops.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
print("\n".join(map(str, stores)))
# buf0 stores come first
self.assertEqual(stores[0].src[0].arg, stores[1].src[0].arg)
@@ -433,12 +433,12 @@ class TestIndexingOrdering(unittest.TestCase):
assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
def test_simple_order_with_special(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
class TestUPatHelpers(unittest.TestCase):
@@ -447,7 +447,7 @@ class TestUPatHelpers(unittest.TestCase):
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__':
@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -119,23 +119,23 @@ class TestUOpsStats(unittest.TestCase):
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(Ops.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = linearize_uop(u5.sink())
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = linearize_uop(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, List, Optional
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
|
||||
graph_rewrite, contexts, track_rewrites
|
||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestViz(unittest.TestCase):
|
||||
pm = PatternMatcher([
|
||||
(UPat.var("x")*1, lambda x:x),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a*1, pm)
|
||||
self.assertEqual(len(uops), 1)
|
||||
self.assertEqual(uops[0], a)
|
||||
@@ -37,21 +37,21 @@ class TestViz(unittest.TestCase):
|
||||
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
||||
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||
])
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
uops = helper_test_viz(a+a, pm)
|
||||
self.assertEqual(len(uops), 2)
|
||||
self.assertEqual(uops[0], a*2)
|
||||
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
|
||||
|
||||
def test_rewrite_with_ctx(self):
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
|
||||
def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
|
||||
if x in ctx: return None
|
||||
ctx[x] = None
|
||||
return UOp.store(*x.src, x)
|
||||
pm = PatternMatcher([
|
||||
(UPat(UOps.LOAD, name="x"), store_load),
|
||||
(UPat(Ops.LOAD, name="x"), store_load),
|
||||
])
|
||||
uops = helper_test_viz(a+b, pm, {})
|
||||
self.assertEqual(len(uops), 2)
|
||||
@@ -61,7 +61,7 @@ class TestViz(unittest.TestCase):
|
||||
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
|
||||
@track_rewrites(named=True)
|
||||
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
do_rewrite(ld*1)
|
||||
do_rewrite(ld*2)
|
||||
ret = get_metadata(contexts)
|
||||
@@ -79,13 +79,13 @@ class TestViz(unittest.TestCase):
|
||||
def do_rewrite(x:UOp):
|
||||
x = graph_rewrite(x, simple) # NOTE: viz tracks this
|
||||
raise Exception("test")
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
|
||||
with self.assertRaises(Exception): do_rewrite(ld*1)
|
||||
ret = get_metadata(contexts)
|
||||
self.assertEqual(len(ret), 1)
|
||||
|
||||
def test_fold_const(self):
|
||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||
graph = uop_to_json(a)
|
||||
assert not any(v[0].startswith("CONST") for v in graph.values())
|
||||
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.ops import UOps
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
@@ -23,7 +23,7 @@ class TestWinograd(unittest.TestCase):
|
||||
sched = create_schedule([out.lazydata])
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s.ast.op is not UOps.SINK: continue
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
ops = s.ast.parents
|
||||
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
|
||||
l = Kernel(s.ast)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, math
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import all_same
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, exec_alu
|
||||
from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
||||
|
||||
# Helper function to apply the graph rewrite
|
||||
@@ -9,12 +9,12 @@ def apply_rewrite(expr):
|
||||
return full_graph_rewrite(expr.sink()).src[0]
|
||||
|
||||
def evaluate_uop(uop, variables):
|
||||
if uop.op == UOps.CONST:
|
||||
if uop.op == Ops.CONST:
|
||||
return uop.arg
|
||||
elif uop.op == UOps.DEFINE_VAR:
|
||||
elif uop.op == Ops.DEFINE_VAR:
|
||||
var_name = uop.arg[0]
|
||||
return variables[var_name]
|
||||
elif uop.op == UOps.ALU:
|
||||
elif uop.op == Ops.ALU:
|
||||
src_values = [evaluate_uop(src, variables) for src in uop.src]
|
||||
return exec_alu(uop.arg, uop.dtype, src_values)
|
||||
else:
|
||||
@@ -23,12 +23,12 @@ def evaluate_uop(uop, variables):
|
||||
class TestArithmeticSimplifications(unittest.TestCase):
|
||||
def test_full_graph_rewrite_division_by_zero(self):
|
||||
optimized_div_uop = apply_rewrite(UOp.const(dtypes.float32, 10.0) / UOp.const(dtypes.float32, 0.0))
|
||||
self.assertEqual(optimized_div_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_div_uop.op, Ops.CONST)
|
||||
self.assertTrue(math.isinf(optimized_div_uop.arg) or math.isnan(optimized_div_uop.arg))
|
||||
|
||||
def test_full_graph_rewrite_redundant_operations(self):
|
||||
optimized_uop = apply_rewrite((UOp.const(dtypes.float32, 10.0) + UOp.const(dtypes.float32, 0.0)) * UOp.const(dtypes.float32, 1.0))
|
||||
self.assertEqual(optimized_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_uop.arg, 10.0)
|
||||
|
||||
def test_full_graph_rewrite_large_graph(self):
|
||||
@@ -36,17 +36,17 @@ class TestArithmeticSimplifications(unittest.TestCase):
|
||||
for i in range(1, 101):
|
||||
prev_uop += UOp.const(dtypes.int32, i)
|
||||
optimized_uop = apply_rewrite(prev_uop)
|
||||
self.assertEqual(optimized_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_uop.arg, sum(range(1, 101)))
|
||||
|
||||
def test_full_graph_rewrite_division_by_one(self):
|
||||
optimized_uop = apply_rewrite(UOp.const(dtypes.float32, 42.0) / UOp.const(dtypes.float32, 1.0))
|
||||
self.assertEqual(optimized_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_uop.arg, 42.0)
|
||||
|
||||
def test_full_graph_rewrite_modulo_by_one(self):
|
||||
optimized_uop = apply_rewrite(UOp.const(dtypes.int32, 42) % UOp.const(dtypes.int32, 1))
|
||||
self.assertEqual(optimized_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_uop.arg, 0)
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class TestFoldingAndReduction(unittest.TestCase):
|
||||
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
|
||||
expr = (outer_range * 10) + inner_range
|
||||
optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
|
||||
self.assertEqual(optimized_reduce_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))
|
||||
|
||||
|
||||
@@ -98,32 +98,32 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
|
||||
def test_full_graph_rewrite_modulo_folding_with_define_var(self):
|
||||
x_var_uop = UOp.variable('x', 0, 100)
|
||||
optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
|
||||
self.assertEqual(optimized_mod_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_mod_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_mod_uop.arg, 2)
|
||||
|
||||
def test_full_graph_rewrite_division_folding_with_define_var(self):
|
||||
n_var_uop = UOp.variable('n', 1, 1000)
|
||||
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
|
||||
self.assertEqual(optimized_div_uop.op, UOps.ALU)
|
||||
self.assertEqual(optimized_div_uop.op, Ops.ALU)
|
||||
self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
|
||||
self.assertEqual(optimized_div_uop.src[1].arg, 2)
|
||||
|
||||
def test_full_graph_rewrite_complex_mod_div_folding(self):
|
||||
k_var_uop = UOp.variable('k', 0, 50)
|
||||
optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
|
||||
self.assertEqual(optimized_div_uop.op, UOps.CONST)
|
||||
self.assertEqual(optimized_div_uop.op, Ops.CONST)
|
||||
self.assertEqual(optimized_div_uop.arg, 1)
|
||||
|
||||
def test_graph_rewrite_div_folding_bug(self):
|
||||
lhs = UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
|
||||
UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
|
||||
lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
|
||||
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
|
||||
rhs = UOp.const(dtypes.int.vec(4), 2)
|
||||
unopt = lhs.lt(rhs)
|
||||
opt = apply_rewrite(unopt)
|
||||
print(unopt)
|
||||
print(opt)
|
||||
if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
|
||||
if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src))
|
||||
|
||||
def test_full_graph_rewrite_modulo_large_divisor(self):
|
||||
x_var_uop = UOp.variable('x', 1, 5)
|
||||
@@ -180,12 +180,12 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
|
||||
|
||||
def test_gep_on_vconst(self):
|
||||
# GEP on a VCONST to extract a single element
|
||||
vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
|
||||
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
|
||||
self.assertEqual(apply_rewrite(vconst.gep(2)).arg, 3.0)
|
||||
|
||||
def test_gep_tuple_on_vconst(self):
|
||||
# GEP on a VCONST using a tuple to extract multiple elements
|
||||
vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
|
||||
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
|
||||
optimized_uop = apply_rewrite(vconst.gep((1, 3)))
|
||||
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [8.0, 10.0])
|
||||
|
||||
@@ -198,7 +198,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
|
||||
def test_vectorize_multiple_elements(self):
|
||||
# Vectorizing multiple elements using GEP
|
||||
base_vector = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
|
||||
vectorized_uop = UOp(UOps.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
|
||||
vectorized_uop = UOp(Ops.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
|
||||
optimized_uop = apply_rewrite(vectorized_uop)
|
||||
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [5.0, 10.0, 15.0, 20.0])
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import unittest, itertools
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOps, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
|
||||
from tinygrad.ops import PatternMatcher, UPat
|
||||
|
||||
class TestPatternMatcher(unittest.TestCase):
|
||||
def test_simple_match(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
nonlocal match_cnt
|
||||
match_cnt += 1
|
||||
assert len(x.src) == 0
|
||||
return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
c1 = matcher.rewrite(c1)
|
||||
c1 = matcher.rewrite(c1)
|
||||
@@ -42,9 +42,9 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
def fxn(ctx, x):
|
||||
ctx.append(True)
|
||||
assert len(x.src) == 0
|
||||
return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
ctx = []
|
||||
c1 = matcher.rewrite(c1, ctx)
|
||||
@@ -52,33 +52,33 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(len(ctx), 1)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_uop_set(self):
|
||||
matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c4 = UOp(Ops.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c2)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
|
||||
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
|
||||
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
|
||||
])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
|
||||
c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
|
||||
c4 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
|
||||
c5 = UOp(Ops.CONST, dtypes.int, arg=-1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c2)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
@@ -87,17 +87,17 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
|
||||
def test_filter_arg(self):
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
|
||||
lambda x,c: x if c.arg in {1, -1} else None)
|
||||
])
|
||||
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
y2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
|
||||
c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
|
||||
c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
|
||||
c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
|
||||
c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
|
||||
c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
|
||||
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
y2 = UOp(Ops.CONST, dtypes.int, arg=2)
|
||||
y3 = UOp(Ops.CONST, dtypes.int, arg=-1)
|
||||
c1 = UOp(Ops.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
|
||||
c2 = UOp(Ops.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
|
||||
c3 = UOp(Ops.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
|
||||
c4 = UOp(Ops.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
|
||||
c5 = UOp(Ops.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
@@ -105,37 +105,37 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
|
||||
def test_dup_name(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
|
||||
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
|
||||
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
|
||||
c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c1)
|
||||
|
||||
def test_dtype(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_dtype_set(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=1)
|
||||
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
|
||||
c4 = UOp(Ops.CONST, dtypes.int, arg=1)
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), c2)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_src_one(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
# that CONST/ALU -> ALU/CONST rewrite is now instant
|
||||
@@ -149,46 +149,46 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_src_permutations(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
|
||||
c6 = UOp(UOps.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c4 = UOp(Ops.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
|
||||
c5 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
|
||||
c6 = UOp(Ops.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), None)
|
||||
self.assertEqual(matcher.rewrite(c4), c4)
|
||||
self.assertEqual(matcher.rewrite(c5), c5)
|
||||
self.assertEqual(matcher.rewrite(c6), None)
|
||||
|
||||
def test_src_repeat(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c4 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
|
||||
self.assertEqual(matcher.rewrite(c3), c3)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
|
||||
def test_allow_len(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
|
||||
c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
|
||||
c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
|
||||
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
|
||||
c4 = UOp(Ops.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
|
||||
c5 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
|
||||
c6 = UOp(Ops.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
|
||||
self.assertEqual(matcher.rewrite(c4), None)
|
||||
self.assertEqual(matcher.rewrite(c5), None)
|
||||
self.assertEqual(matcher.rewrite(c6), c6)
|
||||
|
||||
def test_deep_src_permutations(self):
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
|
||||
u1 = (c1 + c2) + c1
|
||||
u2 = (c2 + c1) + c1
|
||||
matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, src=[UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
|
||||
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
|
||||
])
|
||||
self.assertIsNotNone(matcher.rewrite(u1))
|
||||
self.assertIsNotNone(matcher.rewrite(u2))
|
||||
|
||||
@@ -5,14 +5,14 @@ from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad import Variable
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import sym
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st:ShapeTracker, val:int):
|
||||
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
|
||||
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
|
||||
assert idx.op is UOps.CONST and valid.op is UOps.CONST
|
||||
assert idx.op is Ops.CONST and valid.op is Ops.CONST
|
||||
return idx.arg, valid.arg
|
||||
|
||||
class CheckingShapeTracker:
|
||||
|
||||
@@ -3,27 +3,27 @@ from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, simplify_valid
|
||||
from tinygrad.ops import UOp, Ops, simplify_valid
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(UOps.LOAD, dtypes.float, (
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||
idx,
|
||||
UOp.const(dtypes.float, 0.0),
|
||||
valid
|
||||
))
|
||||
|
||||
def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
|
||||
return UOp(UOps.LOAD, dtypes.float.vec(4), (
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
|
||||
UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx),
|
||||
UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
|
||||
return UOp(Ops.LOAD, dtypes.float.vec(4), (
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
|
||||
UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx),
|
||||
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
|
||||
valid
|
||||
))
|
||||
|
||||
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
|
||||
def Range(n, nmax): return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_is_increasing(self):
|
||||
@@ -43,7 +43,7 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertTrue(is_increasing(f2))
|
||||
self.assertTrue(is_increasing(f3))
|
||||
|
||||
rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),))
|
||||
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=0, src=()), UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
|
||||
self.assertTrue(is_increasing(rng))
|
||||
self.assertTrue(is_increasing(rng+2))
|
||||
|
||||
@@ -87,7 +87,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
idx = load.src[0].src[1]
|
||||
self.assertEqual(idx.op, UOps.VECTORIZE)
|
||||
self.assertEqual(idx.op, Ops.VECTORIZE)
|
||||
self.assertEqual(len(idx.src), 2)
|
||||
idx0, idx1 = idx.src[0], idx.src[1]
|
||||
self.assertEqual(idx0.render(simplify=False), sidx0)
|
||||
@@ -152,7 +152,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
# empty -> invalid
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
self.assertEqual(load.op, UOps.VECTORIZE)
|
||||
self.assertEqual(load.op, Ops.VECTORIZE)
|
||||
self.assertEqual(load.dtype.count, 4)
|
||||
|
||||
def test_openpilot_conv1(self):
|
||||
|
||||
@@ -9,15 +9,15 @@ from typing import Tuple
|
||||
from tinygrad.dtype import dtypes, ConstType
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite, sym_infer
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
|
||||
from tinygrad import Variable
|
||||
import functools
|
||||
|
||||
def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||
# NOTE: we need STORE so the ALU op has children
|
||||
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
||||
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
|
||||
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
||||
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
|
||||
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest, math
|
||||
from tinygrad.ops import UOp, UOps
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
class TestVminVmaxProperties(unittest.TestCase):
|
||||
@@ -35,7 +35,7 @@ class TestVminVmaxProperties(unittest.TestCase):
|
||||
def test_vmin_vmax_multiplication_0_inf(self):
|
||||
# vmin and vmax for multiplication with a variable
|
||||
x = UOp.const(dtypes.float, 0.0)
|
||||
y = UOp.load(UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
|
||||
y = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
|
||||
uop = x * y
|
||||
# TODO: these should be 0, but definitely should not be nan
|
||||
self.assertEqual(uop.vmin, -math.inf)
|
||||
|
||||
@@ -4,7 +4,7 @@ import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
|
||||
from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad import dtypes
|
||||
@@ -12,7 +12,7 @@ from tinygrad.shape.view import View
|
||||
|
||||
class InvalidASTException(Exception): pass
|
||||
def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||
sink = UOp(UOps.SINK, dtypes.void, stores)
|
||||
sink = UOp(Ops.SINK, dtypes.void, stores)
|
||||
if DEBUG >= 3:
|
||||
for op in stores: print(op)
|
||||
try: verify_ast(sink)
|
||||
@@ -26,59 +26,59 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||
class TestVerifyAST(unittest.TestCase):
|
||||
def test_tiny_add(self):
|
||||
dtype = dtypes.int
|
||||
buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0)
|
||||
buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1)
|
||||
buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2)
|
||||
a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
|
||||
buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
|
||||
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
|
||||
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
|
||||
a = UOp(Ops.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
store = UOp(Ops.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
|
||||
helper_test_verify_ast(store)
|
||||
|
||||
def test_exactly_one_full_shape(self):
|
||||
dtype = dtypes.int
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
|
||||
a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
|
||||
a = UOp(UOps.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
b = UOp(UOps.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
a = UOp(Ops.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
||||
|
||||
def test_no_implicit_broadcasting(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
||||
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
||||
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
def test_shrink_ok(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
|
||||
b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
|
||||
b = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||
helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_store(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
def test_reduce_add_store(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||
|
||||
def test_buffer_uops_st(self):
|
||||
a = Tensor.randn(4, 4)+2
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
|
||||
graph_rewrite, track_rewrites, Variable, sint
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
@@ -54,7 +54,7 @@ class TensorCoreOptions:
|
||||
|
||||
class Kernel:
|
||||
def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
|
||||
if ast.op is UOps.SINK: self.ast = ast
|
||||
if ast.op is Ops.SINK: self.ast = ast
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
try: uop_sts_map = verify_ast(self.ast)
|
||||
@@ -65,7 +65,7 @@ class Kernel:
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
|
||||
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is UOps.REDUCE_AXIS])
|
||||
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
|
||||
|
||||
self.vars: List[Variable] = self.ast.variables()
|
||||
self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
|
||||
@@ -125,7 +125,7 @@ class Kernel:
|
||||
return ret
|
||||
|
||||
@property
|
||||
def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}])
|
||||
def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
|
||||
|
||||
# TODO: these need more tests or it might silently be no-op
|
||||
def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
|
||||
@@ -273,16 +273,16 @@ class Kernel:
|
||||
|
||||
def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
if has_cast and not (reduceop.src[0].op is UOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
|
||||
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
|
||||
|
||||
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
|
||||
if mul_op.arg is not BinaryOps.MUL: return None
|
||||
|
||||
def buf_index(src:UOp) -> Optional[int]:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op is UOps.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
|
||||
if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
|
||||
try:
|
||||
if opt_level >= 1 and src.op is UOps.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
|
||||
if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
|
||||
@@ -442,7 +442,7 @@ class Kernel:
|
||||
check(axis < self.first_upcast, "cannot pad upcasted")
|
||||
# ok to pad SUM if all parent ALU ops have f(0) = 0
|
||||
if (r:=self.reduceop) is not None and self.first_reduce <= axis:
|
||||
check(r.arg[0] is BinaryOps.ADD and not any(u.op is UOps.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
|
||||
check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
|
||||
padded = False
|
||||
for i,st in enumerate(self.sts):
|
||||
if (s:=st.shape[axis]) == 1: continue # reduced
|
||||
@@ -472,7 +472,7 @@ class Kernel:
|
||||
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
|
||||
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
|
||||
self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
|
||||
(mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is UOps.LOAD and mulop.src[1].op is UOps.LOAD:
|
||||
(mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
|
||||
st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
|
||||
strides0, strides1 = st0.real_strides(), st1.real_strides()
|
||||
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
|
||||
@@ -615,20 +615,20 @@ class Kernel:
|
||||
arg = op.arg
|
||||
if op.op in BUFFER_UOPS:
|
||||
# for locals, we use the ShapeTracker that's in the srcs
|
||||
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
|
||||
st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
|
||||
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
|
||||
if op.op is UOps.VALID: return op.replace(src=(st_uop,))
|
||||
if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
|
||||
if op.op is Ops.VALID: return op.replace(src=(st_uop,))
|
||||
if op.op is Ops.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
|
||||
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
|
||||
if op.op is UOps.REDUCE_AXIS:
|
||||
if op.op is Ops.REDUCE_AXIS:
|
||||
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
|
||||
alu_op: BinaryOps = op.arg[0]
|
||||
axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
|
||||
if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
|
||||
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
|
||||
rsrc = op.src[0]
|
||||
if rsrc.op is UOps.CAST: rsrc = rsrc.src[0]
|
||||
assert rsrc.op is UOps.ALU and rsrc.arg is BinaryOps.MUL
|
||||
if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
|
||||
assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL
|
||||
|
||||
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
|
||||
wd, tcd = self.global_dims, self.first_upcast
|
||||
@@ -654,30 +654,30 @@ class Kernel:
|
||||
for i,s in enumerate(self.full_shape))
|
||||
srcs = []
|
||||
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is Ops.LOAD]
|
||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
||||
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
||||
membuf = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
local_store = fixup_ast(UOp(Ops.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
||||
srcs.append(UOp(Ops.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
||||
else:
|
||||
# for TC=2, we can't do the shapetracker fixup
|
||||
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
|
||||
# MUL/SUM instead of WMMA
|
||||
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
else:
|
||||
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
|
||||
wmma_upcast_axes = wmma_arg[-2]
|
||||
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
|
||||
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
|
||||
UOp(Ops.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
|
||||
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
|
||||
ret = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
|
||||
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in [ax for ax, _ in tc.reduce_axes])
|
||||
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
|
||||
if self.group_for_reduces:
|
||||
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
|
||||
start = UOp(Ops.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
|
||||
second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
|
||||
if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
|
||||
# NOTE: if there's a grouped reduce, but no reduce axes for this reduce, we can skip it
|
||||
@@ -687,14 +687,14 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce
st_uop = ShapeTracker.from_shape(tuple([1 if i in second_axis else a for i,a in enumerate(local_shape)])).to_uop()
return UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
arg = (alu_op, axis)
elif op.op is UOps.SINK:
elif op.op is Ops.SINK:
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
# NOTE: rewrite with an empty PatternMatcher to dedup UOps
@@ -728,7 +728,7 @@ class Kernel:
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is UOps.DEFINE_GLOBAL],
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
@@ -738,15 +738,15 @@ class Kernel:
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if not uop.has_st or uop in sts: return
# restore globals from the two stage reduce
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
sts[uop] = sts[local_reduce]
return
for x in uop.src: _assert_valid_uop(x, st, sts)
# only reduceuop is allowed to change shape, limited to turning n to 1
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
# movementops are pushed to VIEW
elif uop.op is UOps.VIEW: st = uop.arg
elif uop.op is Ops.VIEW: st = uop.arg
# everything else inherits shape
else:
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
@@ -756,7 +756,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
sts[uop] = st

def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
sts: Dict[UOp, ShapeTracker] = {}
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
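As a rough aside on what _assert_valid_uop/verify_ast enforce above: shapes may only change at a reduce (and then only by collapsing an axis to 1); everything else inherits the shape of its sources. The standalone sketch below illustrates that rule with plain tuples instead of ShapeTrackers (the node layout and names are made up for the example, not tinygrad's API).

# toy shape-propagation check: nodes are (op, srcs, shape); only REDUCE may shrink an axis to 1
def check_shapes(node, sts):
  op, srcs, shape = node
  for s in srcs: check_shapes(s, sts)
  if srcs:
    in_shape = sts[id(srcs[0])]
    if op == "REDUCE": assert all(a == b or a == 1 for a, b in zip(shape, in_shape)), "reduce can only turn n into 1"
    else: assert shape == in_shape, f"{op} changed shape {in_shape} -> {shape}"
  sts[id(node)] = shape

sts = {}
check_shapes(("STORE", [("REDUCE", [("LOAD", [], (4, 4))], (4, 1))], (4, 1)), sts)
print("checked", len(sts), "nodes")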
@@ -1,6 +1,6 @@
from typing import List, Set, Dict, Tuple
import functools, heapq
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.helpers import DEBUG

@@ -10,13 +10,13 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
children[u] = []
for x in u.src:
srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
if x.op is UOps.RANGE and x.arg[1]: srcs[u][x] = None
if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
children[x].append(u)
in_degree[u] = len(u.src)
return srcs[u]

def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
# filter nodes that don't link to a sink
# BFS toposort
children: Dict[UOp, List[UOp]] = {}
@@ -25,38 +25,38 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
get_children_dfs(sink, children, range_srcs, in_degree)

@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
if x.op is UOps.SINK: return set()
def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
if x.op is Ops.SINK: return set()
return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))

# scope children impact the toposort and END* insertion
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE}
range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}

# assign priorities
def get_priority(u:UOp):
priority = 0
# prefer ranges that depend on the least number of independent ranges
if u.op is UOps.RANGE and u.arg[1]:
if u.op is Ops.RANGE and u.arg[1]:
priority += u.arg[0]
for p in range_phi[u]:
priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
elif u.op is UOps.CONST:
elif u.op is Ops.CONST:
# place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
priority -= 100000000000
else:
# prefer uops that are loop children
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
return priority
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}

# prevent priority inversion
@functools.lru_cache(None)
def fix_priority(u:UOp, lowest_priority):
if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}:
if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
priorities[u] = min(priorities[u], lowest_priority)
if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here)
if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)

@@ -73,8 +73,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
p,_,x = heapq.heappop(queue)
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
if x in scope_children: scope_end[x] = x
if x.op is UOps.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
if x.op is Ops.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
_uops.insert(idx, x)
else: _uops.append(x)
for u, ss in scope_children.items():
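linearize_uop above turns the UOp graph into a list by running a toposort whose ready queue is a priority heap (consts sink to the front, loads are penalized, loop children stay near their RANGE). A minimal standalone sketch of that scheme, with made-up node names and priorities rather than real UOps:

import heapq

def toposort_with_priorities(children, priorities):
  # children: node -> list of dependents; priorities: node -> int (lower pops first)
  in_degree = {u: 0 for u in children}
  for u in children:
    for c in children[u]: in_degree[c] += 1
  heap = [(priorities[u], u) for u, d in in_degree.items() if d == 0]
  heapq.heapify(heap)
  order = []
  while heap:
    _, u = heapq.heappop(heap)
    order.append(u)
    for c in children[u]:
      in_degree[c] -= 1
      if in_degree[c] == 0: heapq.heappush(heap, (priorities[c], c))
  return order

# both "const" and "load" are ready immediately, but the const surfaces first
print(toposort_with_priorities({"const": ["alu"], "load": ["alu"], "alu": []},
                               {"const": -100000, "load": 100, "alu": 0}))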
@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten

@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
if reverse: dims = dims[::-1]
limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
if limited != dims:
ret = []
# cast for mypy, get_contraction won't be None
@@ -57,8 +57,8 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
first_upcasted = len(full_shape)-ki.upcasted
first_output_st: ShapeTracker = ast.src[0].st_arg
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is UOps.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]

# reduce loops
idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))

# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))

return IndexContext(idxs, ridxs)

@@ -98,50 +98,50 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:

def lower_reduce_axis(ctx: IndexContext, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is UOps.RANGE)
assert all(x.op is UOps.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
alu_op: BinaryOps = x.arg[0]
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
return UOp(Ops.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret

def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
has_valid = valid.op is not Ops.CONST or valid.arg is not True
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
if x.op is Ops.LOAD:
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \
x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
if x.src[0].op is Ops.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
has_valid = valid.op is not Ops.CONST or valid.arg is not True
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))

pm_lowerer = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(UOps.VALID, src=(UPat(UOps.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])

def do_reduce(ctx:List[int], root:UOp):
acc = UOp(UOps.DEFINE_ACC, root.dtype,
acc = UOp(Ops.DEFINE_ACC, root.dtype,
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
ctx[0] += 1
return acc.assign(acc.alu(root.arg, root.src[0]))

just_reduce = PatternMatcher([
# do reduce
(UPat(UOps.REDUCE, name="root"), do_reduce),
(UPat(Ops.REDUCE, name="root"), do_reduce),
])

def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
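pm_lowerer and just_reduce above are tables of (pattern, rewrite) rules that graph_rewrite applies until nothing matches. As a loose standalone sketch of that rewrite style (tuples stand in for UOps; this is not the tinygrad PatternMatcher API):

# toy bottom-up rewriting to a fixed point over (op, *args) tuples
def rewrite(term, rules):
  if not isinstance(term, tuple): return term
  term = tuple(rewrite(t, rules) for t in term)   # rewrite children first
  for matches, fxn in rules:
    if matches(term): return rewrite(fxn(term), rules)
  return term

# one rule in the spirit of the folds above: ("add", x, 0) -> x
rules = [(lambda t: t[0] == "add" and t[-1] == 0, lambda t: t[1])]
print(rewrite(("mul", ("add", "x", 0), 2), rules))   # ('mul', 'x', 2)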
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -16,15 +16,15 @@ def fold_expanded(ex, buf):
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
new_srcs = dedup(list(ex.src))
old_new_srcs = new_srcs[:]
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)

# first, extract all the relevant offsets
offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
@@ -45,20 +45,20 @@ def fold_expanded(ex, buf):
if is_image:
# for images, we rewrite the index. it must evenly divide 4 from the above check
new_src[0] = buf.index(
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
# vectorize the store
if not is_load:
new_src[1] = UOp(UOps.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
# generate the folded new_srcs
if is_load:
new_load = UOp(UOps.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
else:
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
for i in range(fold_length): used.add((rootsrc,o+i))

# dedup expand for LOAD
@@ -72,15 +72,15 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
new_src = list(load.src)
# TODO: copied logic from above
new_src[0] = load.src[0].src[0].index(
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
vec_load = UOp(UOps.LOAD, load.dtype.vec(4), tuple(new_src))
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))

buf_idx_pat = UPat(UOps.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
float4_folding = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
])

# ***** image load valid simplification *****
@@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
pat += [
(UPat(UOps.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
(UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(UOps.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
(UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
@@ -175,8 +175,8 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr
if vec is not None:
# add, mul, loop_start, loop_end
def dvec(x:UOp):
if x.op is UOps.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
if mul.vmin > 0 and ne is not None:
comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
@@ -209,7 +209,7 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(UOps.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)

def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
@@ -218,9 +218,9 @@ def no_vectorized_wmma(wmma:UOp):
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))

def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
@@ -231,81 +231,81 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret

acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng")
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)

index_load = UPat.var("buf").index(rng_aug).load(name="ld")

arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(UOps.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))

# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(UPat(UOps.ALU, src=(UPat(UOps.VECTORIZE, src=UPat(name='x')), UPat(UOps.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
(UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
(UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# push some GEPs through WMMAs
(UPat(UOps.GEP, src=(UPat(UOps.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# tensor core with a 0 input is acc
(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
# tensor core cleanups
(UPat.var("add") + UPat(UOps.WMMA, name="wmma"),
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
(UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
# ** self folding **
(UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
# x!=0 -> (bool)x
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** load/store folding **
(UPat.store(UPat(UOps.INDEX, name="index"), UPat.load(UPat(UOps.INDEX, name="index"))), lambda index: UOp(UOps.NOOP)),
(UPat.store(UPat(UOps.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(UOps.INDEX, name="index")))),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
# fold gated LOAD/STORE
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
(UPat(UOps.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(UOps.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(UOps.NOOP)), # NULL pointer store does nothing
(UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
# remove EXPANDs from SINK/BARRIER
(UPat(UOps.BARRIER, src=(UPat((UOps.VECTORIZE, UOps.SINK), name='sink'),)), lambda sink: UOp(UOps.BARRIER, dtypes.void, sink.src)),
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, tuple(flatten(x.src if x.op in {UOps.SINK, UOps.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {UOps.SINK, UOps.EXPAND} for x in root.src) else None),
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None),
])

# *** uop expander ***
@@ -325,10 +325,10 @@ def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int]
return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]

def do_expand(root:UOp):
expands = [x for x in root.src if x.op is UOps.EXPAND]
expands = [x for x in root.src if x.op is Ops.EXPAND]
if len(expands) == 0: return None
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
# if there's only one expand arg, it's okay to use it (optimization)
expand_args = expands[0].arg
@@ -338,8 +338,8 @@ def do_expand(root:UOp):
expand_sz = prod([x[1] for x in expand_args])
new_srcs = []
for i,src in enumerate(root.src):
if src.op is UOps.EXPAND:
if root.op is UOps.IF and i == 0:
if src.op is Ops.EXPAND:
if root.op is Ops.IF and i == 0:
# IF means OR on first arg to IF
new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
elif expand_args == src.arg:
@@ -352,70 +352,70 @@ def do_expand(root:UOp):
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-EXPAND input
if (root.op is UOps.IF) or (root.op is UOps.REDUCE and i != 0):
if (root.op is Ops.IF) or (root.op is Ops.REDUCE and i != 0):
# for the first arg of IF and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(UOps.VECTORIZE,
new_srcs.append(UOp(Ops.VECTORIZE,
src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))

new_arg = root.arg
if root.op is UOps.GEP:
if root.op is Ops.GEP:
assert root.dtype.count == 1
# is this right?
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
return UOp(Ops.EXPAND, root.dtype, (nsrc,), expand_args)

def do_contract(con:UOp):
ex = con.src[0]
# CONTRACT without EXPAND repeats the element VECTORIZED
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
if ex.op is not Ops.EXPAND: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from EXPAND
assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
idxs = []
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
return UOp(Ops.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)

def no_vectorized_alu(alu):
if alu.dtype.vcount == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(UOps.VECTORIZE, alu.dtype, alus)
return UOp(Ops.VECTORIZE, alu.dtype, alus)

def create_gate(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def _gate_srcs(u:UOp, gate:UOp) -> UOp:
if u.op is UOps.BARRIER: return u
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, dtypes.void, (gate, u.src[-1])),), u.arg)
if u.op is Ops.BARRIER: return u
if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, dtypes.void, (gate, u.src[-1])),), u.arg)
return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
idx = root.src[0]
if idx.op is UOps.CAST: idx = idx.src[0]
return None if idx.op is not UOps.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
if idx.op is Ops.CAST: idx = idx.src[0]
return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret

expander = PatternMatcher([
# double expand
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
(UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN,
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
# BARRIERs aren't actually expanded
(UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)),
lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
(UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)),
lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
# empty EXPAND is NOOP
(UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
(UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
(UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
(UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])

def no_vectorized_load_store(ls:UOp):
@@ -423,75 +423,75 @@ def no_vectorized_load_store(ls:UOp):
assert isinstance(idx.dtype, PtrDType)
if idx.dtype.v == 1: return None
tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv))
return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))

def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
return UOp(Ops.VECTORIZE, acc.dtype, alus)

devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.INDEX), name="alu"), no_vectorized_alu),
(UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])

def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is UOps.IF]: return None
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None
# remove the gate from the index
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)

load_store_indexing = PatternMatcher([
# late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
(UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# delete_redundant_gates (after expand)
(UPat(UOps.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val"))), delete_redundant_gates),
])

def idx_load_store(x:UOp):
idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None)
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is UOps.LOAD else x.src[3:])
if x.op is UOps.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:])
if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg)

migrate_indexing = PatternMatcher([
# use indexing for LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
# create gate MUST BE BEFORE expander
(UPat(UOps.STORE, name="root"), create_gate),
(UPat(Ops.STORE, name="root"), create_gate),
])

def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
# this moves the mask from the indexing to the load/store op for rendering
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])

pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# move masks of loads/stores
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
])

# *** uop graph ***

def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])

@@ -3,7 +3,7 @@ from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
from tinygrad.ops import UOps
from tinygrad.ops import Ops

# **************** memory planning ****************

@@ -47,5 +47,5 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]

@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast, Generator, Tuple
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
from tinygrad.dtype import dtypes
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, Program
@@ -181,18 +181,18 @@ class ExecItem:
return et

def lower_schedule_item(si:ScheduleItem) -> ExecItem:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is UOps.COPY
if si.ast.op is UOps.SINK:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
out, arg = si.outputs[0], si.ast.arg
if si.ast.op is UOps.COPY:
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is UOps.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")

def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:

@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -15,7 +15,7 @@ from tinygrad.device import Buffer
sys.setrecursionlimit(10000)

BUF_LIMIT = {"METAL":32}
METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
METAOPS = {MetaOps.COPY:Ops.COPY, MetaOps.EMPTY:Ops.EMPTY, MetaOps.VIEW:Ops.BUFFER_VIEW}

# **** ScheduleItem return type

@@ -34,7 +34,7 @@ class ScheduleItem:
"""Read only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
@functools.cached_property
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)

# **** small wrapper for LazyBuffer -> UOp

@@ -62,21 +62,21 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
# consts are always fused and generated
if buf.op is MetaOps.CONST:
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
return UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
# everything else has BUFFER
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
# if the buffer is already realized we just load it
if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
# everything else needs sources
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in METAOPS: ret = UOp(METAOPS[cast(MetaOps, buf.op)], buf.dtype, (ubuf, *src), buf.arg)
elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
if buf.forced_realize: ctx.realizes[ubuf] = ubuf
return ret
@@ -87,7 +87,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
|
||||
|
||||
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
|
||||
if (n:=cache.get(u)) is not None: return n
|
||||
if u.op is UOps.VIEW: return u.replace(arg=apply_to_st(u.arg))
|
||||
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
|
||||
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
|
||||
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
|
||||
return ret
|
||||
@@ -123,7 +123,7 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
|
||||
|
||||
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0]
|
||||
swizzles = [x for x in root.src if x.op is Ops.VIEW and len(x.src) != 0]
|
||||
if len(swizzles) == 0: return None
|
||||
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
||||
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
||||
@@ -131,34 +131,34 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
fixup_cache: Dict[UOp, UOp] = {}
|
||||
new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src]
|
||||
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
|
||||
return ret if ret.op is UOps.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
|
||||
|
||||
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view before ALU
|
||||
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||
(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
])
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# ASSIGN can override st
|
||||
(UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))),
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
|
||||
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
|
||||
# VIEW on a reduce creates a new VIEW
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
# ** ScheduleItem context builder
|
||||
@@ -180,18 +180,18 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
||||
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)])
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
if b in ctx.assigned: ctx.assign_preloads.append(b)
|
||||
return x.replace(op=UOps.LOAD)
|
||||
return x.replace(op=Ops.LOAD)
|
||||
|
||||
to_si = PatternMatcher([
|
||||
(UPat(UOps.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
|
||||
(UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
(UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
(UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
|
||||
])
|
||||
|
||||
# ** fusion
|
||||
@@ -206,15 +206,15 @@ def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src})
|
||||
# check that there are no cyclic dependencies
|
||||
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD,UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
|
||||
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
|
||||
if not all_same([x.op for x in ops]):
|
||||
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
|
||||
# do movementops
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
|
||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
|
||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.PRELOAD and x.src[0] in assign_targets):
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
# convert to AST
|
||||
@@ -232,13 +232,13 @@ if getenv("RUN_PROCESS_REPLAY"):
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
|
||||
ctx[b] = store
|
||||
return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop()))
|
||||
return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop()))
|
||||
|
||||
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
|
||||
do_realize = PatternMatcher([
|
||||
# always realize meta ops
|
||||
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
(UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
|
||||
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
|
||||
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
|
||||
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
|
||||
])
|
||||
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
|
||||
import itertools, functools, random, math, time, multiprocessing, traceback, signal
|
||||
from collections import defaultdict
|
||||
from dataclasses import replace
|
||||
from tinygrad.ops import UOp, UOps, Variable, sym_infer
|
||||
from tinygrad.ops import UOp, Ops, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer, Compiler
|
||||
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
@@ -88,7 +88,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
|
||||
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
|
||||
bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
|
||||
for x in lin.bufs:
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
|
||||
if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
|
||||
rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)
|

tinygrad/ops.py
@@ -123,8 +123,8 @@ REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)

# the order of these UOps controls the order of the toposort
class UOps(FastEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto()
CONTIGUOUS = auto()
@@ -183,9 +183,9 @@ class UOps(FastEnum):
VCONST = auto()
CONST = auto()

BUFFER_UOPS = {UOps.LOAD, UOps.PRELOAD, UOps.STORE, UOps.VALID}
BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}

# With True as the default, this matches the old symbolic behavior
def resolve(x, default:bool=True):
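Since the rename only touches the enum's name, op identity checks keep the same shape with Ops, and the declaration order still drives the toposort as the comment above notes. A small sketch (assumed to hold for this revision):

from tinygrad.ops import UOp, Ops
from tinygrad.dtype import dtypes

c = UOp.const(dtypes.int, 3)
print(c.op is Ops.CONST)                 # True: constants are Ops.CONST nodes
print(Ops.SINK.value < Ops.CONST.value)  # True: SINK is declared earlier, so it sorts first
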
@@ -218,14 +218,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
|
||||
|
||||
class UOpMetaClass(type):
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
|
||||
return ret
|
||||
|
||||
class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
# TODO: instant check rules here make debugging easier
|
||||
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
|
||||
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
|
||||
@@ -243,7 +243,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@functools.cached_property # parents with self
|
||||
@@ -251,31 +251,31 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
@functools.cached_property
|
||||
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
|
||||
return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
||||
return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
|
||||
|
||||
# *** uop shape stuff ***
|
||||
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR}
|
||||
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
def st(self) -> Optional[ShapeTracker]:
|
||||
if not self.has_st: return None
|
||||
if self.op in BUFFER_UOPS: return self.st_arg
|
||||
if self.op is UOps.VIEW: return self.arg
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
src_sts = [x.st for x in self.src if x.st is not None]
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
|
||||
# *** uop evaluation ***
|
||||
|
||||
def simplify(self):
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
||||
def _eval(self, dtype, expected_type:Type[T]) -> T:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
vmin, vmax = (simple_self:=self.simplify())._min_max
|
||||
@@ -294,98 +294,98 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
|
||||
ret = self.src[0 if self.op is UOps.VALID else 1]
|
||||
assert ret.op is UOps.VIEW, f"st_arg trying to return {ret}"
|
||||
ret = self.src[0 if self.op is Ops.VALID else 1]
|
||||
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
@property
|
||||
def axis_arg(self) -> Tuple[int, ...]:
|
||||
assert self.op in {UOps.REDUCE_AXIS, UOps.WMMA}, f"axis_arg called on {self.op}"
|
||||
ret = self.arg[1] if self.op is UOps.REDUCE_AXIS else self.arg[7]
|
||||
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
|
||||
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
|
||||
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
|
||||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st)
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
if count == 1: return self
|
||||
return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
||||
def cast(self, dtype:DType): return UOp(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DType): return UOp(UOps.BITCAST, dtype, (self,))
|
||||
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
|
||||
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:Union[Tuple[int, ...], int]):
|
||||
if isinstance(i, int):
|
||||
# NOTE: these are just shortcuts to not have to create and fold later
|
||||
if self.op is UOps.VECTORIZE: return self.src[i]
|
||||
if self.op is UOps.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||
if self.op is UOps.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||
if self.op is Ops.VECTORIZE: return self.src[i]
|
||||
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
|
||||
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
|
||||
i = (i,)
|
||||
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self
|
||||
assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}"
|
||||
return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(UOps.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
|
||||
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def alu(self, arg, *src:UOp):
|
||||
out_dtype = (self, *src)[-1].dtype
|
||||
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
|
||||
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
|
||||
return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
|
||||
return UOp(Ops.ALU, out_dtype, (self,)+src, arg)
|
||||
@staticmethod
|
||||
def const(dtype:DType, b:ConstLike):
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
|
||||
def assign(self, x:UOp): return UOp(UOps.ASSIGN, self.dtype, (self,x))
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
|
||||
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
|
||||
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@property
|
||||
def expr(self):
|
||||
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
return self.arg[0]
|
||||
def bind(self, val:int):
|
||||
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
|
||||
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
|
||||
return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
|
||||
return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
|
||||
def unbind(self) -> Tuple[Variable, int]:
|
||||
assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
|
||||
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
|
||||
return self.src[0], self.src[1].arg
|
||||
@property
|
||||
def val(self) -> int: return self.unbind()[1]
|
||||
def vars(self) -> Set[UOp]:
|
||||
bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
|
||||
bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
|
||||
bound_var_base = set(x.src[0] for x in bound_vars)
|
||||
all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
|
||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg)
|
||||
if self.op is UOps.ALU:
|
||||
if self.op is Ops.CONST: return self.arg
|
||||
if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
|
||||
if self.op is Ops.ALU:
|
||||
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
||||
return 1
|
||||
def divides(self, v) -> Optional[UOp]:
|
||||
if v==1: return self
|
||||
if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is UOps.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
||||
if self.op is UOps.ALU:
|
||||
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
|
||||
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
|
||||
if self.op is Ops.ALU:
|
||||
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
|
||||
if self.arg is BinaryOps.MUL:
|
||||
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
|
||||
@@ -398,20 +398,20 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[ConstType, ConstType]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
if self.op is UOps.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value
|
||||
if self.op in {UOps.EXPAND, UOps.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
if self.op is Ops.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value
|
||||
if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
|
||||
if self.op is UOps.CONST: return self.arg, self.arg
|
||||
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
|
||||
if self.op is UOps.ALU and not dtypes.is_float(self.dtype):
|
||||
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
|
||||
if self.op is Ops.CONST: return self.arg, self.arg
|
||||
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
|
||||
if self.op is Ops.ALU and not dtypes.is_float(self.dtype):
|
||||
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
|
||||
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
||||
if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
|
||||
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST:
|
||||
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
|
||||
if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
|
||||
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
|
||||
@@ -430,7 +430,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def _sym_fxn(self):
|
||||
sself = self.simplify()
|
||||
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is UOps.DEFINE_VAR)
|
||||
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
|
||||
# TODO: sanitize varnames, or don't use naked eval while staying fast
|
||||
return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
|
||||
|
||||
@@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

def render(self, simplify=True) -> str:
ret = graph_rewrite(self.simplify() if simplify else self, renderer)
return ret.arg if ret.op is UOps.NOOP else str(ret)
return ret.arg if ret.op is Ops.NOOP else str(ret)

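render ties the symbolic machinery together: simplify folds the graph with the symbolic PatternMatcher and the renderer turns it into a string. A short sketch of that round trip (assumed usage; the exact rendered string is an assumption, not taken from this hunk):

from tinygrad.ops import UOp

x = UOp.variable("x", 0, 10)   # an Ops.DEFINE_VAR constrained to [0, 10]
expr = x * 4 + 2               # builds an ALU graph, nothing is evaluated
print(expr.render())           # renders the folded symbolic form, e.g. "((x*4)+2)"
print(expr.vmin, expr.vmax)    # bounds propagated from the variable's range: 2 and 42
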
@dataclass(frozen=True)
|
||||
class KernelInfo:
|
||||
@@ -476,7 +476,7 @@ def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
|
||||
|
||||
def print_uops(uops:List[UOp]):
|
||||
for i,u in enumerate(uops):
|
||||
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
|
||||
formatted_parents = [uops.index(x) if x.op is not Ops.CONST else f"{x.arg}" for x in u.src]
|
||||
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):25s} " f"{str(formatted_parents):32s} {u.arg}")
|
||||
|
||||
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
@@ -487,26 +487,26 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
dont_count: Set[UOp] = set()
|
||||
if ignore_indexing:
|
||||
for u in uops:
|
||||
if u.op in {UOps.LOAD, UOps.STORE}:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
|
||||
elif u.op is UOps.IF:
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].sparents)
|
||||
for u in uops:
|
||||
if u.op is UOps.RANGE:
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= (u.src[1] - u.src[0]).ssimplify()
|
||||
elif u.op is UOps.ENDRANGE:
|
||||
elif u.op is Ops.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.op is UOps.SPECIAL:
|
||||
elif u.op is Ops.SPECIAL:
|
||||
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
|
||||
elif u.op is UOps.LOAD:
|
||||
elif u.op is Ops.LOAD:
|
||||
mem += u.dtype.itemsize * mults
|
||||
elif u.op is UOps.STORE:
|
||||
elif u.op is Ops.STORE:
|
||||
mem += u.src[1].dtype.itemsize * mults
|
||||
elif u.op is UOps.ALU and u not in dont_count:
|
||||
elif u.op is Ops.ALU and u not in dont_count:
|
||||
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
|
||||
elif u.op is UOps.WMMA and u not in dont_count:
|
||||
elif u.op is Ops.WMMA and u not in dont_count:
|
||||
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
|
||||
return flops, mem
|
||||
|
||||
@@ -525,11 +525,11 @@ def lines(fn) -> List[str]:
|
||||
|
||||
class UPat(MathTrait):
|
||||
__slots__ = ["op", "dtype", "arg", "name", "src"]
|
||||
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
|
||||
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
|
||||
name:Optional[str]=None, allow_any_len:bool=False, location=None,
|
||||
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
|
||||
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
|
||||
custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
|
||||
self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op
|
||||
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
|
||||
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
|
||||
self.src: Any = None
|
||||
@@ -561,24 +561,24 @@ class UPat(MathTrait):
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
|
||||
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
|
||||
return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
|
||||
@staticmethod
|
||||
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
|
||||
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
|
||||
|
||||
# copied from UOp
|
||||
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
|
||||
def load(self, *src:UPat, **kwargs): return UPat(UOps.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat): return UPat(UOps.ASSIGN, self.dtype, (self,x))
|
||||
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
|
||||
def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
|
||||
def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
|
||||
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
|
||||
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
|
||||
def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
|
||||
def alu(self, arg, *src:UPat):
|
||||
asrc = (self,)+src
|
||||
return UPat(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
|
||||
return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
|
||||
|
||||
def printable(self:UPat) -> str:
|
||||
try: return lines(self.location[0])[self.location[1]-1].strip()
|
||||
@@ -627,7 +627,7 @@ class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
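The PatternMatcher is the engine behind every rewrite table in this file: each entry is a UPat describing a subgraph plus a callable that receives the named captures. A hedged sketch of a single rule, folding x+0 away with graph_rewrite (the rule shape is assumed from the symbolic patterns in this file, not a rule that exists verbatim in the source):

from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.dtype import dtypes

fold_add_zero = PatternMatcher([(UPat.var("x") + UPat.const(None, 0), lambda x: x)])
expr = UOp.variable("i", 0, 5) + UOp.const(dtypes.int, 0)
print(graph_rewrite(expr, fold_add_zero))  # the +0 is rewritten away, leaving the bare DEFINE_VAR
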
@@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions return True or False (or None to not match)
|
||||
spec = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(UPat(UOps.SPECIAL, src=()), lambda: True),
|
||||
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(UPat(Ops.SPECIAL, src=()), lambda: True),
|
||||
|
||||
# TODO: confirm the args of both of these are shapetrackers
|
||||
(UPat(UOps.VIEW, src=()), lambda: True),
|
||||
(UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
|
||||
(UPat(Ops.VIEW, src=()), lambda: True),
|
||||
(UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
|
||||
|
||||
(UPat(UOps.VALID, dtypes.bool, (UPat(UOps.VIEW),)), lambda: True),
|
||||
(UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
|
||||
# early LOAD has a <buf, shapetracker, store?>
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW))), lambda: True),
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat(UOps.STORE))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
|
||||
|
||||
# early STORE has a <buf, shapetracker, val>
|
||||
(UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
|
||||
|
||||
# **** new style load/store ****
|
||||
|
||||
# INDEX is used in new style load/store
|
||||
(UPat(UOps.INDEX, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, gate?, barrier?>
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)),)), lambda: True),
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
|
||||
|
||||
# STORE takes a <bufidx, val, gate?>
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat())), lambda: True),
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(UOps.IF))), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
|
||||
|
||||
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
|
||||
(UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
|
||||
(UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
|
||||
lambda w,x,y: w.dtype == x.dtype == y.dtype),
|
||||
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
|
||||
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
|
||||
# and SHL/SHR, the shift distance is an int
|
||||
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
|
||||
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
|
||||
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
|
||||
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
|
||||
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
|
||||
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
(UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
|
||||
(UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True),
|
||||
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
|
||||
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
|
||||
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
|
||||
|
||||
# all WMMA has 3 args, <x, w, acc>
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# if has a <gate, barrier?>
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
|
||||
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
|
||||
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
|
||||
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
|
||||
|
||||
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
|
||||
(UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
|
||||
|
||||
# NOTE: for testing, we let sinks be anything
|
||||
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
|
||||
(UPat(UOps.SINK, dtypes.void), lambda: True),
|
||||
(UPat(UOps.NOOP), lambda: True),
|
||||
(UPat(Ops.SINK, dtypes.void), lambda: True),
|
||||
(UPat(Ops.NOOP), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
(UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
||||
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
|
||||
])
|
||||
|
||||
def type_verify(uops:List[UOp]):
|
||||
@@ -842,7 +842,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
# *** most of symbolic lives here now ***

def split_uop(x:UOp, sep:BinaryOps):
if x.op is UOps.ALU and x.arg is sep:
if x.op is Ops.ALU and x.arg is sep:
for s in x.src: yield from split_uop(s, sep)
else: yield x

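split_uop is the small helper the folding rules below lean on: it flattens a chain of a given binary op so mod/div folding can inspect each term on its own. A quick sketch (assumed usage):

from tinygrad.ops import UOp, BinaryOps, split_uop

x = UOp.variable("x", 0, 100)
terms = list(split_uop(x*4 + x*2 + 3, BinaryOps.ADD))
print(len(terms))  # 3: the addends x*4, x*2, and the constant 3 are yielded separately
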
@@ -859,7 +859,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
assert divides is not None
|
||||
remainder.append(divides)
|
||||
something_changed = True
|
||||
elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0:
|
||||
elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
|
||||
remainder.append(u.src[0])
|
||||
something_changed = True
|
||||
else: remainder.append(u)
|
||||
@@ -874,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
|
||||
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
|
||||
for u in split_uop(x, BinaryOps.ADD):
|
||||
if u.op is UOps.CONST:
|
||||
if u.op is Ops.CONST:
|
||||
# add all const together first
|
||||
if rem_const != 0: something_changed = True
|
||||
rem_const += u.arg
|
||||
@@ -886,7 +886,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
something_changed = True
|
||||
else:
|
||||
# divisor is the smallest common divisor of all MULs
|
||||
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
|
||||
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
|
||||
remainder.append(u)
|
||||
gcd = math.gcd(gcd, factor)
|
||||
|
||||
@@ -917,11 +917,11 @@ def fold_unrolled_divs(divs:UOp):
|
||||
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
||||
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
|
||||
for u in add_chain:
|
||||
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None
|
||||
if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None
|
||||
if denominator is None: denominator = u.src[1].arg
|
||||
if denominator != u.src[1].arg: return None
|
||||
# assumed CONST is the last of an ADD
|
||||
if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST:
|
||||
if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
|
||||
seen_const.append(s0.src[1].arg)
|
||||
s0 = s0.src[0]
|
||||
else: seen_const.append(0)
|
||||
@@ -933,7 +933,7 @@ def fold_unrolled_divs(divs:UOp):
|
||||
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
|
||||
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
||||
|
||||
def is_irreducible(u:UOp): return u.op in (UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
|
||||
def is_irreducible(u:UOp): return u.op in (Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE)
|
||||
|
||||
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
@@ -941,7 +941,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
changed, ret = False, []
|
||||
for u in split_uop(X, BinaryOps.ADD):
|
||||
# assumed the const is the last src of MUL
|
||||
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
|
||||
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
u = u.src[0]
|
||||
if not (is_irreducible(u) and u.vmin >= 0): return None
|
||||
@@ -951,8 +951,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
def is_increasing(f:UOp) -> bool:
|
||||
# is f a monotonically increasing function with regard to its input
|
||||
if is_irreducible(f): return True
|
||||
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
return False # False if not sure
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
||||
@@ -960,10 +960,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
||||
# if it's X >= c, returns X, False, c
|
||||
|
||||
# (X < c).ne(True) -> X >= c
|
||||
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is UOps.CONST and valid.src[1].arg == 1 and \
|
||||
(s0:=valid.src[0]).op is UOps.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is UOps.CONST: return s0.src[0], False, s0.src[1].arg
|
||||
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
|
||||
(s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
|
||||
# X < c -> X <= c-1
|
||||
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
|
||||
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
|
||||
raise ValueError(f"not able to parse {valid=}")
|
||||
|
||||
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
@@ -983,7 +983,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
|
||||
# every candidate is a set of constrained UOps based on valid, and if every item in a set simplifies the uop into the same output, we rewrite uop
|
||||
candidates = []
|
||||
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
|
||||
if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
|
||||
# try checking the whole clause
|
||||
@@ -992,7 +992,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
for candidate in candidates:
|
||||
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
|
||||
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
|
||||
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
|
||||
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
|
||||
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
|
||||
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
|
||||
elif all_same(newuops): uop = newuops[0]
|
||||
@@ -1037,17 +1037,17 @@ symbolic_simple = PatternMatcher([
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
|
||||
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
|
||||
# ** COMMUTATIVE flipping **
|
||||
*[(UPat(UOps.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
|
||||
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
|
||||
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
|
||||
# *** cast ***
|
||||
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
|
||||
])
|
||||
|
||||
symbolic = symbolic_simple+PatternMatcher([
@@ -1063,7 +1063,7 @@ symbolic = symbolic_simple+PatternMatcher([
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ALU min==max -> CONST (slow!)
(UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
@@ -1090,11 +1090,11 @@ symbolic = symbolic_simple+PatternMatcher([
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# unrolled arange div folding
(UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
# generic lt folding
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
# canonicalize a simplex with positive coefficients > 0
@@ -1115,23 +1115,23 @@ symbolic_flat = symbolic+PatternMatcher([
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])

_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])

# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
renderer = PatternMatcher([
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
])

# *** what was symbolic.py ***

@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.ops import Op, Ops, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.dtype import DType

@dataclass(frozen=True)
@@ -42,10 +42,10 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(u)
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0][0] == 'i': self.local_size = None
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size

@@ -2,66 +2,66 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore

base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(UOps.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(UOps.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda ctx: "}"),
(UPat(UOps.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
(UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(UOps.RANGE, name="x"),
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
(UPat(UOps.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(UPat(UOps.BARRIER), lambda ctx: ctx.barrier),
(UPat(UOps.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(UOps.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# const
(UPat(UOps.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(UOps.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(UOps.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(UOps.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to larger type and casted
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
# default const render
(UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)),
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
(UPat(UOps.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(UOps.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(UOps.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(UOps.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(UOps.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])

extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(UOps.BITCAST, name="x"),
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# gate any stores that aren't gated with ifs
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(UOps.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])

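The MAX rewrite above replaces max with a compare-and-select, which cstyle backends render as (a<b)?b:a. A minimal check that the two forms agree (plain Python, not the renderer itself):

# max(a, b) is exactly "b if a < b else a", the compare-and-select form the rewrite produces.
import itertools
for a, b in itertools.product(range(-3, 4), repeat=2):
    assert max(a, b) == (b if a < b else a)
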
def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -125,41 +125,41 @@ class CStyleLanguage(Renderer):
depth = 1
c: DefaultDict[str, int] = defaultdict(int)
for u in uops:
if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0]
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
bufs[u] = (r[u], (u.dtype, False))
continue

# mark buffers that we store to writable
if u.op is UOps.STORE:
if u.op is Ops.STORE:
for up in u.src[0].sparents:
if up.op is UOps.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

# naming
prefix = None
if u.op is UOps.SPECIAL:
if u.op is Ops.SPECIAL:
r[u] = u.arg[0]
else:
prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast", UOps.NOOP: "precast",
UOps.INDEX: "bidx", UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk")
r[u] = f"{prefix}{c[prefix]}"

l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
if u.op in {UOps.CONST, UOps.GEP, UOps.INDEX} or (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST}
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST}
and child_count[u] == 1 and not getenv("EXPAND_SSA")):
r[u] = l
else:
if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op is UOps.ASSIGN: r[u] = r[u.src[0]]
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "")
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {UOps.IF, UOps.RANGE}: depth += 1
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
del self.r

# NOTE: this relies on bufs dict preserving order
@@ -189,7 +189,7 @@ class ClangRenderer(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
# https://github.com/corsix/amx
for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
prefix += [
'#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
'#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
@@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage):
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }

string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
# load/store image (OpenCL)
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(UPat(UOps.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
]) + base_rewrite

@@ -234,13 +234,13 @@ class IntelRenderer(OpenCLRenderer):
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]

string_rewrite = PatternMatcher([
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
]) + OpenCLRenderer.string_rewrite

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
@@ -272,17 +272,17 @@ class MetalRenderer(CStyleLanguage):
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
]) + extra_pm

string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
]) + base_rewrite

def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -335,7 +335,7 @@ class CUDARenderer(CStyleLanguage):
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
@@ -353,7 +353,7 @@ class CUDARenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

def get_kernel_modifier(self, uops:List[UOp]) -> str:
maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
return f"__launch_bounds__({maxThreadsPerBlock}) "

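A small standalone sketch of the launch-bounds computation above: SPECIAL uops carry a (name, size) arg, and the product of the local ("l"-prefixed) sizes becomes the CUDA __launch_bounds__ hint. The specials list below is made up for illustration.

from math import prod
specials = [("gidx0", 64), ("lidx0", 16), ("lidx1", 8)]   # hypothetical SPECIAL args (name, size)
max_threads_per_block = prod(size for name, size in specials if name[0] == "l")
print(f"__launch_bounds__({max_threads_per_block}) ")     # -> __launch_bounds__(128)
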
@@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
# bfloat16 casting
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
(UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -414,7 +414,7 @@ class AMDRenderer(CStyleLanguage):
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -423,7 +423,7 @@ class AMDRenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix)

def get_kernel_modifier(self, uops:List[UOp]) -> str:
requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

@@ -1,7 +1,7 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.renderer import Renderer

MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -71,7 +71,7 @@ class LLVMRenderer(Renderer):
module = ir.Module(name=__file__)

# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

# create llvm function
@@ -90,14 +90,14 @@ class LLVMRenderer(Renderer):

for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.INDEX:
if uop is Ops.INDEX:
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
elif uop is UOps.STORE:
elif uop is Ops.STORE:
if len(src) > 2:
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
else:
bb[-1].store(lvars[src[1]], lvars[src[0]])
elif uop is UOps.ENDRANGE:
elif uop is Ops.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
@@ -105,7 +105,7 @@ class LLVMRenderer(Renderer):
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
else:
if uop is UOps.RANGE:
if uop is Ops.RANGE:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)

@@ -119,10 +119,10 @@ class LLVMRenderer(Renderer):
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.DEFINE_ACC:
lvars[u] = const(src[0].arg, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
if len(src) > 1:
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
with then:
@@ -135,17 +135,17 @@ class LLVMRenderer(Renderer):
else:
val = bb[-1].load(lvars[src[0]])
lvars[u] = val
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
lvars[u] = lvars[src[1]]
# ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC
backward = src[0]
while backward.op is UOps.ASSIGN: backward = backward.src[0]
while backward.op is Ops.ASSIGN: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
elif uop is Ops.ALU:
lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is Ops.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")

bb[-1].ret_void()

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -38,20 +38,20 @@ ptx_matcher = PatternMatcher([
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
# load/store bool -> uint8
(UPat(UOps.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
(UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
# ptx shr and shl instructions require y to be uint
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
])

class PTXRenderer(Renderer):
@@ -146,51 +146,51 @@ class PTXRenderer(Renderer):

for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.IF:
if uop is Ops.IF:
pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
elif uop is Ops.ENDRANGE:
kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
elif uop is UOps.ENDIF:
elif uop is Ops.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
elif uop is UOps.STORE:
elif uop is Ops.STORE:
assert src[0].dtype == dtypes.int64, "store isn't int64"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not UOps.IF else ""
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else ""
if src[1].dtype.count > 1:
kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
else:
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is UOps.ALU:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is Ops.ALU:
src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
elif uop is UOps.SPECIAL:
elif uop is Ops.SPECIAL:
assert args[0][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is UOps.DEFINE_VAR:
elif uop is Ops.DEFINE_VAR:
bufs.append((args[0], dtype))
r[u] = f"%{args[0]}"
kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP:
elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True)
elif uop is Ops.GEP:
assert len(u.arg) == 1
r[u] = r[src[0]][u.arg[0]]
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
assert src[0].dtype == dtypes.int64, "load isn't int64"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is UOps.ALU
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is Ops.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if has_gate:
@@ -200,25 +200,25 @@ class PTXRenderer(Renderer):
else:
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
if dtype.count > 1:
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
# NOTE: casting to str is fine because you can't vectorize a vectorize
elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {UOps.CAST, UOps.BITCAST}:
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {Ops.CAST, Ops.BITCAST}:
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u)
elif uop is Ops.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_GLOBAL:
elif uop is Ops.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args}", dtype))
r[u] = f"%{nm}"
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
elif uop is Ops.WMMA:
_, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args
wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
dt_map = { dtypes.half: "f16" }

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer

@@ -26,7 +26,7 @@ def _store(m, i, v):

class PythonProgram:
def __init__(self, name:str, lib:bytes):
self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
self.uops: List[Tuple[Ops, Optional[DType], List[int], Any]] = pickle.loads(lib)
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
@@ -40,12 +40,12 @@ class PythonProgram:
loop_ends: Dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
if uop is Ops.DEFINE_ACC: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is UOps.STORE:
if uop is Ops.STORE:
if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True
if dtp[1].count > 1:
for j,val in enumerate(inp[1]):
@@ -56,32 +56,32 @@ class PythonProgram:
if g: _store(m, o, v)
i += 1
continue
if uop is UOps.ENDRANGE:
if uop is Ops.ENDRANGE:
loop_ends[idp[0]] = i
i = idp[0]
continue
if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
# in the python emulator, the warp is always in sync
i += 1
continue
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is UOps.DEFINE_GLOBAL:
if uop is Ops.DEFINE_GLOBAL:
assert dtype.fmt is not None
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
elif uop is UOps.DEFINE_LOCAL:
elif uop is Ops.DEFINE_LOCAL:
assert dtype.fmt is not None
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
elif uop is UOps.DEFINE_VAR:
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is UOps.SPECIAL:
elif uop is Ops.SPECIAL:
if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
elif uop is UOps.CONST: ul[i] = [arg] * warp_size
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.CONST: ul[i] = [arg] * warp_size
elif uop is Ops.DEFINE_ACC:
ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
elif uop is UOps.INDEX:
elif uop is Ops.INDEX:
ret = []
if isinstance(dtp[0], ImageDType):
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
@@ -90,9 +90,9 @@ class PythonProgram:
else:
for m,o in zip(inp[0], inp[1]): ret.append((m,o))
ul[i] = ret
elif uop is UOps.CAST and isinstance(dtype, PtrDType):
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
ul[i] = inp[0]
elif uop is UOps.RANGE:
elif uop is Ops.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else:
for j in range(len(ul[i])):
@@ -101,11 +101,11 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
elif uop is UOps.VECTORIZE: ul[i] = inp
elif uop in {UOps.CAST, UOps.BITCAST}:
elif uop is Ops.VECTORIZE: ul[i] = inp
elif uop in {Ops.CAST, Ops.BITCAST}:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
if dtypes.is_int(dtype):
@@ -114,18 +114,18 @@ class PythonProgram:
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
else:
ul[i] = load(inp)
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
ul[i] = inp[0]
elif uop is UOps.GEP:
elif uop is Ops.GEP:
assert len(arg) == 1
ul[i] = inp[0][arg[0]]
elif uop is UOps.WMMA:
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
@@ -180,7 +180,7 @@ class PythonProgram:
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
elif uop is Ops.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]

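A standalone sketch of the control flow the Python emulator above relies on: the uops form a straight-line list walked with a program counter, and ENDRANGE jumps back toward its RANGE until the trip count is reached. The three-entry program and opcode names below are invented for illustration, not the real uop encoding.

prog = [("RANGE", 3), ("BODY", None), ("ENDRANGE", 0)]   # ENDRANGE points back at index 0
counters, trace, i = {}, [], 0
while i < len(prog):
    op, arg = prog[i]
    if op == "RANGE": counters.setdefault(i, 0)
    elif op == "BODY": trace.append(counters[0])
    elif op == "ENDRANGE":
        counters[arg] += 1
        if counters[arg] < prog[arg][1]:
            i = arg + 1   # jump back into the loop body
            continue
    i += 1
print(trace)   # -> [0, 1, 2]
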
@@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint
from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint

@dataclass(frozen=True, order=True)
class ShapeTracker:
@@ -40,7 +40,7 @@ class ShapeTracker:

def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))

def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)

def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idx, valid = self.views[-1].to_indexed_uops(_idxs)
@@ -75,20 +75,20 @@ class ShapeTracker:
ret: List[Optional[sint]] = [None] * len(self.shape)
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
for c in split_uop(idx, BinaryOps.ADD):
if c.op is UOps.RANGE: ret[c.arg] = 1
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is UOps.RANGE]
if c.op is Ops.RANGE: ret[c.arg] = 1
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:
for masked_axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE]: ret[masked_axis] = None
for masked_axis in [x.arg for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
return tuple(ret)

def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is UOps.RANGE]
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]

def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:

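A standalone sketch of the stride recovery in real_strides above: the flattened index is split into ADD terms, a bare range term contributes stride 1 on its axis, a range*const term contributes that constant, and axes whose range never appears get stride 0. The (axis, multiplier) tuples below stand in for the real UOp nodes.

shape = (4, 3)                    # contiguous view: idx = ridx0*3 + ridx1
terms = [(0, 3), (1, 1)]          # (axis, multiplier) for each ADD term of idx
strides = [None] * len(shape)
for axis, mul in terms: strides[axis] = mul
used = {axis for axis, _ in terms}
print(tuple(s if i in used else 0 for i, s in enumerate(strides)))   # -> (3, 1), matching strides_for_shape((4, 3))
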
@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@@ -136,7 +136,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, UOp):
assert data.op is UOps.BIND and data.src[0].op is UOps.DEFINE_VAR and data.src[1].op is UOps.CONST, f"can't create tensor from UOp {data}"
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
@@ -382,9 +382,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method

@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is UOps.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is UOps.ALU:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))

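A hedged usage sketch for Tensor.from_uop above: a bound symbolic Variable is a BIND uop (DEFINE_VAR plus CONST, matching the assert in __init__), and simple MUL/ADD/MAX combinations recurse through from_uop. This assumes Variable(name, min, max).bind(value) produces that BIND structure.

from tinygrad import Tensor
from tinygrad.ops import Variable          # Variable is exported from tinygrad.ops, as imported above

i = Variable("i", 1, 10).bind(3)           # assumed to build the BIND(DEFINE_VAR, CONST) uop
t = Tensor.from_uop(i)                     # BIND case: a Tensor carrying the bound variable
u = Tensor.from_uop(i * 2 + 1)             # ALU case: recurses through MUL and ADD
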
@@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines
from tinygrad.codegen.kernel import Kernel

uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.INDEX: "#e8ffa0", UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"}

# ** API spec

@@ -50,7 +50,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
for k,ctxs in contexts:
name = to_function_name(k.name) if isinstance(k, Kernel) else k
for ctx in ctxs:
if ctx.sink.op is UOps.CONST: continue
if ctx.sink.op is Ops.CONST: continue
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
@@ -60,11 +60,11 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
if u.op is UOps.CONST: continue
if u.op is Ops.CONST: continue
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
if (found:=replaces.get(base)) is not None: return found
@@ -88,7 +88,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
if new_sink is sink:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
g.graphs.append(sink:=new_sink)
return g
