move IMAGE FLOAT16 logic to allocations (#15095)

* FLOAT16 logic in allocations

* cleanup

* separate that

* only apply when IMAGE == 1

* test passing now

* create image buffers earlier
This commit is contained in:
Christopher Milan
2026-03-02 19:00:05 -08:00
committed by GitHub
parent d483e4153a
commit c70e8af068
4 changed files with 15 additions and 9 deletions

View File

@@ -796,7 +796,6 @@ class TestSchedule(unittest.TestCase):
self.assertIsInstance(out.uop.base.realized.dtype, ImageDType)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@unittest.expectedFailure
def test_image_dot_f16_fusion(self):
with Context(FLOAT16=1):
def cnt():
@@ -809,7 +808,8 @@ class TestSchedule(unittest.TestCase):
with Context(IMAGE=1): cnt1 = cnt()
with Context(IMAGE=2): cnt2 = cnt()
self.assertEqual(cnt1, cnt2)
self.assertEqual(cnt1, 5)
self.assertEqual(cnt2, 5)
@unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
@unittest.expectedFailure

View File

@@ -41,15 +41,15 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
# split ranges
sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
# create image buffers
if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
# symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
# optimize (schedule) the AST
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
# create image buffers
if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
# do postrange optimization, BEAM or hand_coded_optimizations
sink = apply_opts(sink, ren)

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
from tinygrad.dtype import ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, IMAGE, FLOAT16
@dataclass
class AllocCtx:
@@ -95,6 +95,11 @@ def contiguous_mops_to_view(c:UOp):
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag)
def make_float16(assign:UOp, buf:UOp, val:UOp):
if IMAGE != 1 or not FLOAT16: return None
new_buf = buf.replace(dtype=dtypes.half, src=(buf.src[0].replace(dtype=dtypes.half), *buf.src[1:]) if buf.op is Ops.RESHAPE else buf.src)
return assign.replace(dtype=dtypes.half, src=(new_buf, val.cast(dtypes.half))).cast(dtypes.float)
pm_early_transform_tensor_graph = PatternMatcher([
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
@@ -122,6 +127,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
# early fixup const copy (TODO: is this wrong if there's a pad?)
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
# IMAGE FLOAT16: use the texture sampler to store as half and automatically cast float load/store
(UPat(Ops.ASSIGN, dtypes.float, src=(UPat.var("buf"), UPat(GroupOp.All-{Ops.COPY}, name="val")), name="assign"), make_float16),
])
def untag_and_append(ctx:AllocCtx, x:UOp):

View File

@@ -3643,8 +3643,7 @@ class Tensor(OpMixin):
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
if IMAGE == 1 and FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
else: x, w = x.contiguous(), w.contiguous()
x, w = x.contiguous(), w.contiguous()
if IMAGE == 1 and added_weight: w, H = w[:, :-added_weight, ...], H - added_weight