mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
style cleanups in allocations.py [pr] (#15295)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, VIZ, pluralize
|
||||
from tinygrad.helpers import prod, DEBUG, VIZ, pluralize, all_int
|
||||
|
||||
@dataclass
|
||||
class AllocCtx:
|
||||
@@ -66,15 +66,14 @@ def replace_store_after_with_contig(u:UOp, src:UOp):
|
||||
while assigned_to.op in {Ops.BITCAST, Ops.AFTER}: assigned_to = assigned_to.src[0].base
|
||||
if assigned_to.op is not Ops.BUFFER: return src.contiguous(tag=u.tag)
|
||||
|
||||
def contiguous_mops_to_view(c:UOp):
|
||||
def contiguous_mops_to_view(c:UOp, src:UOp):
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
|
||||
src = c.src[0]
|
||||
buf = src.base
|
||||
if buf.op not in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
if src.op is Ops.RESHAPE and src.src[0].op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return None
|
||||
|
||||
# no symbolic shape
|
||||
if not all(isinstance(x, int) for x in c.shape): return None
|
||||
if not all_int(c.shape): return None
|
||||
|
||||
# check if view is supported
|
||||
if not isinstance(c.device, str): return None
|
||||
@@ -82,8 +81,7 @@ def contiguous_mops_to_view(c:UOp):
|
||||
if not hasattr(Device[c.device].allocator, "_offset"): return None
|
||||
|
||||
# see if this can be a view
|
||||
offset = src.contiguous_view_offset()
|
||||
if offset is None: return None
|
||||
if (offset := src.contiguous_view_offset()) is None: return None
|
||||
|
||||
# merge BUFFER_VIEWs
|
||||
if buf.op is Ops.BUFFER_VIEW: offset, buf = offset + buf.arg[1], buf.src[0]
|
||||
@@ -109,8 +107,7 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
|
||||
# if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
|
||||
# NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes
|
||||
rets = tuple(r.shrink(tuple((0, s) for s in rs.shape)) if any(isinstance(x, UOp) for x in rs.shape) else r
|
||||
for r,rs in zip(rets, resolved))
|
||||
rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved))
|
||||
|
||||
# return tuple if tuple
|
||||
return UOp.maketuple(*rets) if c.src[0].op is Ops.TUPLE else rets[0]
|
||||
@@ -121,7 +118,7 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||
(UPat(Ops.CALL, name="c"), transform_precompiled_call),
|
||||
|
||||
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="c"), contiguous_mops_to_view),
|
||||
|
||||
# add CONTIGUOUS to tagged UOps
|
||||
(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.AFTER, Ops.STORE}, name="x"),
|
||||
|
||||
Reference in New Issue
Block a user