This commit is contained in:
George Hotz
2025-10-07 16:00:10 +08:00
parent 1d7a8b33c1
commit 51f3a5cbb4

View File

@@ -1,5 +1,5 @@
from tinygrad import Device, Tensor, Context
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp, KernelInfo
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.helpers import prod
from tinygrad.schedule.rangeify import pm_mops
@@ -61,27 +61,28 @@ if __name__ == "__main__":
a_reg = load(a_reg, gl_a, tg_id_y, k)
b_reg = load(b_reg, gl_b, k, tg_id_x)
d_reg = mma_AB(d_reg, a_reg, b_reg, k)
sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink()
sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink(arg=KernelInfo())
sink = graph_rewrite(sink, pm_mops, name="pm_mops")
from tinygrad.codegen.gpudims import pm_add_gpudims
sink = graph_rewrite(sink, pm_add_gpudims, ctx=Device.default.renderer, name="gpudims")
pm_lower_index_dtype_simple = PatternMatcher([
(UPat(GroupOp.All, dtype=dtypes.index, name="x"), lambda x: x.replace(dtype=dtypes.int))
])
sink = graph_rewrite(sink, pm_lower_index_dtype_simple, name="index_dtype")
#sink = graph_rewrite(sink, pm_lower_index_dtype, name="index_dtype")
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
lin = apply_rewrites(sink, rewrites_for_linearizer)
src = Device.default.renderer.render(lin.arg.lst)
print(src)
#exit(0)
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.renderer import ProgramSpec
ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst)
ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst, [1,1,1], [1,1,1])
run = CompiledRunner(ps)
a = Tensor.randn(N, N)