diff --git a/extra/thunder/clone.py b/extra/thunder/clone.py index 8153025f13..72453533d4 100644 --- a/extra/thunder/clone.py +++ b/extra/thunder/clone.py @@ -1,5 +1,5 @@ from tinygrad import Device, Tensor, Context -from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp +from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp, KernelInfo from tinygrad.dtype import dtypes, AddrSpace from tinygrad.helpers import prod from tinygrad.schedule.rangeify import pm_mops @@ -61,27 +61,28 @@ if __name__ == "__main__": a_reg = load(a_reg, gl_a, tg_id_y, k) b_reg = load(b_reg, gl_b, k, tg_id_x) d_reg = mma_AB(d_reg, a_reg, b_reg, k) - sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink() + sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink(arg=KernelInfo()) sink = graph_rewrite(sink, pm_mops, name="pm_mops") + from tinygrad.codegen.gpudims import pm_add_gpudims + sink = graph_rewrite(sink, pm_add_gpudims, ctx=Device.default.renderer, name="gpudims") + pm_lower_index_dtype_simple = PatternMatcher([ (UPat(GroupOp.All, dtype=dtypes.index, name="x"), lambda x: x.replace(dtype=dtypes.int)) ]) sink = graph_rewrite(sink, pm_lower_index_dtype_simple, name="index_dtype") - - #sink = graph_rewrite(sink, pm_lower_index_dtype, name="index_dtype") - from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites lin = apply_rewrites(sink, rewrites_for_linearizer) src = Device.default.renderer.render(lin.arg.lst) print(src) + #exit(0) from tinygrad.engine.realize import CompiledRunner, ExecItem from tinygrad.renderer import ProgramSpec - ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst) + ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst, [1,1,1], [1,1,1]) run = CompiledRunner(ps) a = Tensor.randn(N, N)