From 5f6b610da1e5adcb9d71f04f430936796eadc7fe Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Tue, 3 Mar 2026 02:37:57 -0800 Subject: [PATCH] FLOAT16 logic for IMAGE==1 goes back to image_conv2d (#15105) --- tinygrad/engine/allocations.py | 17 +++++------------ tinygrad/tensor.py | 3 ++- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 22e12416d8..7eb2470e83 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites from tinygrad.dtype import dtypes, ImageDType -from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, IMAGE, FLOAT16 +from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16 @dataclass class AllocCtx: @@ -61,13 +61,13 @@ def replace_assign_with_contig(u:UOp): return u.src[1].contiguous(tag=u.tag) def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): - x = src - while x is not src.base: + if (x:=src).op is Ops.CAST and x.dtype == dtypes.half and FLOAT16: x, contig = x.src[0], contig.cast(dtypes.float) + while x is not x.base: if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg)) elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape) else: return None x = x.src[0] - ctx[src.base] = contig + ctx[x] = contig def contiguous_mops_to_view(c:UOp): """CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range.""" @@ -95,17 +95,12 @@ def contiguous_mops_to_view(c:UOp): # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag) -def make_float16(assign:UOp, buf:UOp, val:UOp): - if IMAGE != 1 or not FLOAT16: return None - new_buf = buf.replace(dtype=dtypes.half, src=(buf.src[0].replace(dtype=dtypes.half), *buf.src[1:]) if buf.op is Ops.RESHAPE else buf.src) - return assign.replace(dtype=dtypes.half, src=(new_buf, val.cast(dtypes.half))).cast(dtypes.float) - pm_early_transform_tensor_graph = PatternMatcher([ # CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view), # *** CONTIGUOUS replacement hack for openpilot *** - (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous), + (UPat(Ops.CONTIGUOUS, src=(UPat((*GroupOp.Movement, Ops.CAST), name="src"),), name="contig"), found_contiguous), # replace ALU sources with contiguous versions found above (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), @@ -127,8 +122,6 @@ pm_early_transform_tensor_graph = PatternMatcher([ (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None), # early fixup const copy (TODO: is this wrong if there's a pad?) (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None), - # IMAGE FLOAT16: use the texture sampler to store as half and automatically cast float load/store - (UPat(Ops.ASSIGN, dtypes.float, src=(UPat.var("buf"), UPat(GroupOp.All-{Ops.COPY}, name="val")), name="assign"), make_float16), ]) def untag_and_append(ctx:AllocCtx, x:UOp): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e9d43736bd..fe9f1a8e93 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3643,7 +3643,8 @@ class Tensor(OpMixin): # contiguous creates the image, and early realize static weights (TODO: test for the static weight) if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) - x, w = x.contiguous(), w.contiguous() + if IMAGE == 1 and FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float) + else: x, w = x.contiguous(), w.contiguous() if IMAGE == 1 and added_weight: w, H = w[:, :-added_weight, ...], H - added_weight