mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
28
test/external/external_benchmark_sdxl_softmax.py
vendored
28
test/external/external_benchmark_sdxl_softmax.py
vendored
@@ -1,28 +0,0 @@
|
||||
from tinygrad import Tensor, dtypes, GlobalCounters
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
if __name__ == "__main__":
|
||||
t = Tensor.empty(81920, 4096, dtype=dtypes.half)
|
||||
GlobalCounters.reset()
|
||||
t.softmax(-1, dtype="half").realize()
|
||||
GlobalCounters.reset()
|
||||
t.softmax(-1, dtype="half", _single_kernel=True).realize()
|
||||
|
||||
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.helpers import get_single_element
|
||||
GlobalCounters.reset()
|
||||
si = get_single_element(t.softmax(-1, dtype="half", _single_kernel=True).schedule())
|
||||
k = Kernel(si.ast)
|
||||
#k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
|
||||
k.apply_opt(Opt(OptOps.UPCAST, 1, 4))
|
||||
k.apply_opt(Opt(OptOps.LOCAL, 1, 32))
|
||||
#k.apply_opt(Opt(OptOps.LOCAL, 0, 8))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 1, 4))
|
||||
k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
|
||||
#k.apply_opt(Opt(OptOps.GROUP, 1, 256))
|
||||
#k.apply_opt(Opt(OptOps.GROUP, 0, 32))
|
||||
#k.apply_opt(Opt(OptOps.GROUP, 1, 32))
|
||||
#k.apply_opt(Opt(OptOps.GROUP, 0, 32))
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
||||
run = CompiledRunner(prg:=get_program(k.ast, k.opts, k.applied_opts))
|
||||
ExecItem(k.ast, list(si.bufs), prg=run).run()
|
||||
56
test/external/external_metal_compile_slow.py
vendored
56
test/external/external_metal_compile_slow.py
vendored
@@ -1,56 +0,0 @@
|
||||
# ruff: noqa: E501
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import Timing, getenv
|
||||
from tinygrad.codegen.opt.kernel import Opt, OptOps
|
||||
from tinygrad.engine.realize import get_program, CompiledRunner
|
||||
from tinygrad.uop.ops import UOp, Ops, AxisType
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("TC", 0) == 0:
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=0, src=())
|
||||
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
|
||||
c2 = UOp.range(UOp.const(dtypes.int, 64), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.int, 6), 2, AxisType.GLOBAL)
|
||||
c4 = UOp.range(UOp.const(dtypes.int, 6), 3, AxisType.GLOBAL)
|
||||
c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(2097152), arg=1, src=())
|
||||
c6 = UOp.range(UOp.const(dtypes.int, 64), 1004, AxisType.REDUCE)
|
||||
c7 = UOp.range(UOp.const(dtypes.int, 3), 1005, AxisType.REDUCE)
|
||||
c8 = UOp.range(UOp.const(dtypes.int, 3), 1006, AxisType.REDUCE)
|
||||
c9 = c5.index(((((((c1*UOp.const(dtypes.int, 4096))+(c3*UOp.const(dtypes.int, 8)))+c4)+(c6*UOp.const(dtypes.int, 64)))+(c7*UOp.const(dtypes.int, 8)))+c8), UOp.const(dtypes.bool, True)).load()
|
||||
c10 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=2, src=())
|
||||
c11 = c10.index(((((c2*UOp.const(dtypes.int, 576))+(c6*UOp.const(dtypes.int, 9)))+(c7*UOp.const(dtypes.int, 3)))+c8), UOp.const(dtypes.bool, True)).load()
|
||||
c12 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(64), arg=3, src=())
|
||||
c13 = c12.index(c2, UOp.const(dtypes.bool, True)).load()
|
||||
c14 = ((c9*c11).reduce(c6, c7, c8, arg=Ops.ADD)+c13)
|
||||
c15 = c0.index(((((c1*UOp.const(dtypes.int, 2304))+(c2*UOp.const(dtypes.int, 36)))+(c3*UOp.const(dtypes.int, 6)))+c4), UOp.const(dtypes.bool, True)).store(c14, c1, c2, c3, c4)
|
||||
ast = c15.sink()
|
||||
|
||||
# this does have tons of locals
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=0),
|
||||
Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=2),
|
||||
Opt(op=OptOps.GROUPTOP, axis=0, arg=16)]
|
||||
else:
|
||||
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(10616832), arg=0, src=())
|
||||
c1 = UOp.range(UOp.const(dtypes.int, 512), 0, AxisType.GLOBAL)
|
||||
c2 = UOp.range(UOp.const(dtypes.int, 64), 1, AxisType.GLOBAL)
|
||||
c3 = UOp.range(UOp.const(dtypes.int, 36), 2, AxisType.GLOBAL)
|
||||
c4 = UOp.range(UOp.const(dtypes.int, 9), 3, AxisType.GLOBAL)
|
||||
c5 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(36864), arg=1, src=())
|
||||
c6 = UOp.range(UOp.const(dtypes.int, 64), 1004, AxisType.REDUCE)
|
||||
c7 = c5.index((((c2*UOp.const(dtypes.int, 9))+c4)+(c6*UOp.const(dtypes.int, 576))), UOp.const(dtypes.bool, True)).load()
|
||||
c8 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(1179648), arg=2, src=())
|
||||
c9 = c8.index((((c1*UOp.const(dtypes.int, 2304))+c3)+(c6*UOp.const(dtypes.int, 36))), UOp.const(dtypes.bool, True)).load()
|
||||
c10 = (c7*c9).reduce(c6, arg=Ops.ADD)
|
||||
c11 = c0.index(((((c1*UOp.const(dtypes.int, 20736))+(c2*UOp.const(dtypes.int, 324)))+(c3*UOp.const(dtypes.int, 9)))+c4), UOp.const(dtypes.bool, True)).store(c10, c1, c2, c3, c4)
|
||||
ast = c11.sink()
|
||||
|
||||
opts = [Opt(op=OptOps.TC, axis=0, arg=(0, 0, 1)), Opt(op=OptOps.UPCAST, axis=2, arg=4),
|
||||
Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=0)]
|
||||
|
||||
prg = get_program(ast, opts=opts)
|
||||
print(prg.src)
|
||||
for i in range(10):
|
||||
with Timing(f"try {i}: "):
|
||||
# NOTE: this doesn't even run the kernel
|
||||
try: CompiledRunner(prg)
|
||||
except RuntimeError: pass
|
||||
@@ -1,8 +1,8 @@
|
||||
import itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start, ImageDType
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import partition, dedup
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
|
||||
def flatten_range(r:UOp):
|
||||
off = range_start[r.op]
|
||||
|
||||
Reference in New Issue
Block a user