FLOAT16 logic for IMAGE==1 goes back to image_conv2d (#15105)

This commit is contained in:
Christopher Milan
2026-03-03 02:37:57 -08:00
committed by GitHub
parent 529318259c
commit 5f6b610da1
2 changed files with 7 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, IMAGE, FLOAT16
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16
@dataclass
class AllocCtx:
@@ -61,13 +61,13 @@ def replace_assign_with_contig(u:UOp):
return u.src[1].contiguous(tag=u.tag)
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
x = src
while x is not src.base:
if (x:=src).op is Ops.CAST and x.dtype == dtypes.half and FLOAT16: x, contig = x.src[0], contig.cast(dtypes.float)
while x is not x.base:
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
else: return None
x = x.src[0]
ctx[src.base] = contig
ctx[x] = contig
def contiguous_mops_to_view(c:UOp):
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
@@ -95,17 +95,12 @@ def contiguous_mops_to_view(c:UOp):
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag)
def make_float16(assign:UOp, buf:UOp, val:UOp):
if IMAGE != 1 or not FLOAT16: return None
new_buf = buf.replace(dtype=dtypes.half, src=(buf.src[0].replace(dtype=dtypes.half), *buf.src[1:]) if buf.op is Ops.RESHAPE else buf.src)
return assign.replace(dtype=dtypes.half, src=(new_buf, val.cast(dtypes.half))).cast(dtypes.float)
pm_early_transform_tensor_graph = PatternMatcher([
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
# *** CONTIGUOUS replacement hack for openpilot ***
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
(UPat(Ops.CONTIGUOUS, src=(UPat((*GroupOp.Movement, Ops.CAST), name="src"),), name="contig"), found_contiguous),
# replace ALU sources with contiguous versions found above
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
@@ -127,8 +122,6 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
# early fixup const copy (TODO: is this wrong if there's a pad?)
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
# IMAGE FLOAT16: use the texture sampler to store as half and automatically cast float load/store
(UPat(Ops.ASSIGN, dtypes.float, src=(UPat.var("buf"), UPat(GroupOp.All-{Ops.COPY}, name="val")), name="assign"), make_float16),
])
def untag_and_append(ctx:AllocCtx, x:UOp):

View File

@@ -3643,7 +3643,8 @@ class Tensor(OpMixin):
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
x, w = x.contiguous(), w.contiguous()
if IMAGE == 1 and FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
else: x, w = x.contiguous(), w.contiguous()
if IMAGE == 1 and added_weight: w, H = w[:, :-added_weight, ...], H - added_weight