put load early to make pointers match (#11524)

This commit is contained in:
George Hotz
2025-08-05 20:04:32 -07:00
committed by GitHub
parent 92175626e3
commit cf66df0ea6
5 changed files with 15 additions and 11 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.dtype import ImageDType, AddrSpace
from tinygrad.helpers import all_same, colored, ansilen, dedup, prod, round_up, to_function_name, unwrap, argfix, DEBUG, TC_SELECT, TC_OPT, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import strides_for_shape, get_contraction
from tinygrad.opt.swizzler import view_left
from tinygrad.opt.swizzler import view_left, view_left_through_load
class OptOps(Enum):
TC = auto(); UPCAST = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
@@ -503,4 +503,4 @@ class Kernel:
self.finalized = True
fixed_ast = fixup_ast(self.ast)
del fixup_ast
return graph_rewrite(fixed_ast, view_left, name="fixup optimized AST")
return graph_rewrite(fixed_ast, view_left+view_left_through_load, name="fixup optimized AST")

View File

@@ -44,12 +44,18 @@ def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
view_left = merge_views+PatternMatcher([
# view before elementwise and buffer ops
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
(UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
# if there's ones added after reduce, put this before the reduce
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
])
# Rewrite rules for a VIEW placed directly on top of a LOAD.
# Kept as its own PatternMatcher so it can be combined with other matchers via `+` where needed.
view_left_through_load = PatternMatcher([
# view before load
# pattern: a VIEW (bound as "view") whose src tuple holds a single LOAD (bound as "e")
# rewrite: returns the LOAD with every one of its sources wrapped in .view(view.st)
(UPat(Ops.VIEW, src=(UPat(Ops.LOAD, name="e"),), name="view"),
lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
])
def apply_swizzle(u:UOp) -> UOp:
  """Run `graph_rewrite` over `u` using the `view_left` patterns and return the rewritten UOp."""
  rewritten = graph_rewrite(u, view_left, name="Sub View Left")
  return rewritten
# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
@@ -96,7 +102,7 @@ view_right = merge_views+PatternMatcher([
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.LOAD, Ops.STORE}, name="root"), elementwise_view_right),
(UPat(GroupOp.All-{Ops.SINK, Ops.REDUCE_AXIS, Ops.STORE}, name="root"), elementwise_view_right),
# merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
@@ -112,9 +118,7 @@ def check_load_st(glbl:UOp, view:UOp):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
fix_kernel_ops = PatternMatcher([
# add the LOAD
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: x.replace(tag=None).view(x.st).load() if x.tag is not None else None),
fix_kernel_ops = view_left_through_load+PatternMatcher([
# STORE (except for meta ops)
(UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), lambda sink:
UOp.sink(*[UOp.store(UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(s.st.real_size()), (), i).view(s.st), s) for i,x in enumerate(sink.src)])),

View File

@@ -3,7 +3,7 @@ from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND,
from tinygrad.shape.shapetracker import ShapeTracker
ALWAYS_CONTIGUOUS = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL}
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL, Ops.LOAD}
# **** Grouper decides which of the UOps to realize

View File

@@ -151,7 +151,7 @@ create_kernels = PatternMatcher([
early_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x), tag=1)),
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).view(x.st).load()),
# no SINK for meta ops
(UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
])

View File

@@ -154,8 +154,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
sz = cast(PtrDType, self.dtype).size
return ShapeTracker.from_shape((sz,)) if sz > 0 else None
# hack for PTX, CASTing the ptr loses the shape. even worse hack with tag
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL and self.src[0].tag is None: return None
# hack for PTX, CASTing the ptr loses the shape
if self.op is Ops.CAST and self.src[0].op is Ops.DEFINE_GLOBAL: return None
# otherwise we get the shape from sources
if not (src_sts := [x.st for x in self.src if x.st is not None]): return None