mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
few shapetracker cleanups (#12741)
This commit is contained in:
46
test/external/external_debug_metal_sd_conv.py
vendored
46
test/external/external_debug_metal_sd_conv.py
vendored
@@ -1,46 +0,0 @@
|
||||
# ruff: noqa: E501
|
||||
from tinygrad.codegen.opt.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.codegen.opt.search import bufs_from_lin
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(81920, 0, 64, 8, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.CAST, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
|
||||
UOp(Ops.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.MUL, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 2560, 4, 10, 4, 10), strides=(0, 163840, 0, 64, 0, 8, 0, 1), offset=-9, mask=((0, 1), (0, 2), (0, 1), (0, 2560), (0, 4), (1, 9), (0, 4), (1, 9)), contiguous=False), View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(4096000, 0, 0, 40, 1, 1600, 440, 11), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 2560, 3, 3), strides=(0, 0, 23040, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
|
||||
x17:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 1280, 8, 8, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()),
|
||||
x17,)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)]
|
||||
|
||||
k = Kernel(ast)
|
||||
k.apply_opts(opts)
|
||||
bufs = bufs_from_lin(k)
|
||||
|
||||
prg = CompiledRunner(get_program(k.ast, k.opts, k.applied_opts))
|
||||
|
||||
for i in range(10):
|
||||
speed = prg(bufs, var_vals={}, wait=True)
|
||||
print(f"kernel time: {speed*1e3:.2f} ms")
|
||||
|
||||
# on M1 Max
|
||||
# 11ms before block 9b0859d71780fef5cf3831e317f74e53f2483229
|
||||
# 15ms after block cbcc1c20eb09a1342f6581cfbb99632bade982a8
|
||||
55
test/external/external_test_train_gpt2.py
vendored
55
test/external/external_test_train_gpt2.py
vendored
@@ -1,55 +0,0 @@
|
||||
# ruff: noqa: E501
|
||||
import unittest
|
||||
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from .search import Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.codegen.opt.kernel import Kernel
|
||||
|
||||
from test.external.fuzz_linearizer import run_linearizer
|
||||
|
||||
class TestTrainGpt2Kernel(unittest.TestCase):
|
||||
def test_1(self):
|
||||
# kernel 244
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(206045184), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 1), strides=(51511296, 50304, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3145728), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(786432, 768, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(38633472), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(0, 0, 768, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
|
||||
kernel = Kernel(ast)
|
||||
kernel.apply_opts(opts)
|
||||
run_linearizer(kernel)
|
||||
|
||||
def test_2(self):
|
||||
# kernel 254
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(3145728), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 1, 768), strides=(786432, 768, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(38633472), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(0, 0, 768, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(205852672), arg=2, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(51463168, 50257, 1, 0), offset=0, mask=((0, 4), (0, 1024), (0, 50257), (0, 768)), contiguous=False),)), src=()),)),)),)),)),))
|
||||
|
||||
opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)]
|
||||
kernel = Kernel(ast)
|
||||
kernel.apply_opts(opts)
|
||||
run_linearizer(kernel)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -42,7 +42,7 @@ def get_rewrites_for_renderer(opts:Renderer, optimize:bool=True, linearizer:bool
|
||||
|
||||
@functools.cache
|
||||
def _get_rewrites_for_renderer(opts:Renderer, optimize:bool, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
# ** lowerer **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
if optimize:
|
||||
|
||||
@@ -209,7 +209,7 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
|
||||
if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
|
||||
assert storage[1] != dtypes.bfloat16, "can't permute BF16"
|
||||
# TODO: find a nice way to support all shapetracker on disktensors
|
||||
# TODO: find a nice way to support all movement ops on disktensors
|
||||
ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
|
||||
|
||||
return ret.reshape(size)
|
||||
|
||||
Reference in New Issue
Block a user