From 837b06c60948c27fc45de968d0d5b388c216f980 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 16 Mar 2026 05:45:24 -0400 Subject: [PATCH] style cleanups in allocations.py [pr] (#15295) --- tinygrad/engine/allocations.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 1125deb7e7..7fa477156c 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites from tinygrad.dtype import dtypes, ImageDType -from tinygrad.helpers import prod, DEBUG, VIZ, pluralize +from tinygrad.helpers import prod, DEBUG, VIZ, pluralize, all_int @dataclass class AllocCtx: @@ -66,15 +66,14 @@ def replace_store_after_with_contig(u:UOp, src:UOp): while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag) -def contiguous_mops_to_view(c:UOp): +def contiguous_mops_to_view(c:UOp, src:UOp): """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range.""" - src = c.src[0] buf = src.base if buf.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None # no symbolic shape - if not all(isinstance(x, int) for x in c.shape): return None + if not all_int(c.shape): return None # check if view is supported if not isinstance(c.device, str): return None @@ -82,8 +81,7 @@ def contiguous_mops_to_view(c:UOp): if not hasattr(Device[c.device].allocator, "_offset"): return None # see if this can be a view - offset = src.contiguous_view_offset() - if offset is None: return None + if (offset := src.contiguous_view_offset()) is None: return None # merge BUFFER_VIEWs if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0] @@ -109,8 +107,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None: # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape # NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes - rets = tuple(r.shrink(tuple((0, s) for s in rs.shape)) if any(isinstance(x, UOp) for x in rs.shape) else r - for r,rs in zip(rets, resolved)) + rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved)) # return tuple if tuple return UOp.maketuple(*rets) if c.src[0].op is Ops.TUPLE else rets[0] @@ -121,7 +118,7 @@ pm_early_transform_tensor_graph = PatternMatcher([ (UPat(Ops.CALL, name="c"), transform_precompiled_call), # CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range - (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view), + (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view), # add CONTIGUOUS to tagged UOps (UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.AFTER, Ops.STORE}, name="x"),