mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
gpu
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from tinygrad import Device, Tensor, Context
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, AxisType, PatternMatcher, UPat, pm_lower_index_dtype, GroupOp, KernelInfo
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.schedule.rangeify import pm_mops
|
||||
@@ -61,27 +61,28 @@ if __name__ == "__main__":
|
||||
a_reg = load(a_reg, gl_a, tg_id_y, k)
|
||||
b_reg = load(b_reg, gl_b, k, tg_id_x)
|
||||
d_reg = mma_AB(d_reg, a_reg, b_reg, k)
|
||||
sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink()
|
||||
sink = store(gl_d, d_reg, tg_id_y, tg_id_x).sink(arg=KernelInfo())
|
||||
|
||||
sink = graph_rewrite(sink, pm_mops, name="pm_mops")
|
||||
|
||||
from tinygrad.codegen.gpudims import pm_add_gpudims
|
||||
sink = graph_rewrite(sink, pm_add_gpudims, ctx=Device.default.renderer, name="gpudims")
|
||||
|
||||
pm_lower_index_dtype_simple = PatternMatcher([
|
||||
(UPat(GroupOp.All, dtype=dtypes.index, name="x"), lambda x: x.replace(dtype=dtypes.int))
|
||||
])
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype_simple, name="index_dtype")
|
||||
|
||||
|
||||
#sink = graph_rewrite(sink, pm_lower_index_dtype, name="index_dtype")
|
||||
|
||||
from tinygrad.codegen import rewrites_for_linearizer, apply_rewrites
|
||||
lin = apply_rewrites(sink, rewrites_for_linearizer)
|
||||
src = Device.default.renderer.render(lin.arg.lst)
|
||||
print(src)
|
||||
#exit(0)
|
||||
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
|
||||
ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst)
|
||||
ps = ProgramSpec("test", src, Device.DEFAULT, sink, lin.arg.lst, [1,1,1], [1,1,1])
|
||||
run = CompiledRunner(ps)
|
||||
|
||||
a = Tensor.randn(N, N)
|
||||
|
||||
Reference in New Issue
Block a user