Mirror of https://github.com/tinygrad/tinygrad.git
move IMAGE FLOAT16 logic to allocations (#15095)
* FLOAT16 logic in allocations
* cleanup
* separate that
* only apply when IMAGE == 1
* test passing now
* create image buffers earlier
Committed by GitHub
Parent: d483e4153a
Commit: c70e8af068
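Note: IMAGE and FLOAT16 are ContextVars from tinygrad.helpers, so everything this commit moves into the allocation rewrite stays opt-in per scope. A minimal sketch of how a user would exercise the path (shapes are invented here, and the image route needs a QCOM/CL backend):

from tinygrad import Tensor
from tinygrad.helpers import Context

# inside this scope the allocation rewrite below may retype float buffers to half images
with Context(IMAGE=1, FLOAT16=1):
  x, w = Tensor.rand(4, 64), Tensor.rand(64, 64)
  out = (x @ w).realize()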
@@ -796,7 +796,6 @@ class TestSchedule(unittest.TestCase):
     self.assertIsInstance(out.uop.base.realized.dtype, ImageDType)
 
   @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
-  @unittest.expectedFailure
   def test_image_dot_f16_fusion(self):
     with Context(FLOAT16=1):
       def cnt():
@@ -809,7 +808,8 @@ class TestSchedule(unittest.TestCase):
       with Context(IMAGE=1): cnt1 = cnt()
       with Context(IMAGE=2): cnt2 = cnt()
 
-      self.assertEqual(cnt1, cnt2)
+      self.assertEqual(cnt1, 5)
+      self.assertEqual(cnt2, 5)
 
   @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
   @unittest.expectedFailure
@@ -41,15 +41,15 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
   # split ranges
   sink = graph_rewrite(sink, pm_split_ranges+pm_flatten_range, ctx={}, name="split ranges")
 
+  # create image buffers
+  if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
+
   # symbolic (NOTE: this is a requirement for pm_simplify_ranges to be correct)
   sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
 
   # optimize (schedule) the AST
   sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
 
-  # create image buffers
-  if IMAGE == 1 and ren.device in {"QCOM", "CL"}: sink = graph_rewrite(sink, pm_make_images, name="create image buffers", bottom_up=True)
-
   # do postrange optimization, BEAM or hand_coded_optimizations
   sink = apply_opts(sink, ren)
 
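Each stage of this pipeline is a graph_rewrite of the UOp graph with a PatternMatcher, and moving the "create image buffers" pass before "initial symbolic" is just a reordering of those calls. For readers new to the API, a toy pass in the same style (the x*1 fold and all names below are ours, not part of the commit):

from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, graph_rewrite

# fold x*1 -> x; real passes like pm_make_images are PatternMatchers built the same way
pm_toy = PatternMatcher([
  (UPat(Ops.MUL, src=(UPat.var("x"), UPat(Ops.CONST, arg=1))), lambda x: x),
])

a = UOp.variable("a", 0, 10)
assert graph_rewrite(a*1, pm_toy, name="toy fold") is a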
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
-from tinygrad.dtype import ImageDType
-from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize
+from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, IMAGE, FLOAT16
 
 @dataclass
 class AllocCtx:
@@ -95,6 +95,11 @@ def contiguous_mops_to_view(c:UOp):
   # NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
   return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag)
 
+def make_float16(assign:UOp, buf:UOp, val:UOp):
+  if IMAGE != 1 or not FLOAT16: return None
+  new_buf = buf.replace(dtype=dtypes.half, src=(buf.src[0].replace(dtype=dtypes.half), *buf.src[1:]) if buf.op is Ops.RESHAPE else buf.src)
+  return assign.replace(dtype=dtypes.half, src=(new_buf, val.cast(dtypes.half))).cast(dtypes.float)
+
 pm_early_transform_tensor_graph = PatternMatcher([
   # CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
   (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
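Conceptually, make_float16 rewrites a float ASSIGN into a half ASSIGN bracketed by casts, so only the storage changes dtype. A hand-sketched before/after of the graph (notation ours, not tinygrad's printer):

# before: buffer and stored value are float32
#   ASSIGN(float, src=(BUFFER(float), val))
# after: the texture stores 16-bit, the value is cast down on store,
#        and the result is cast back up so consumers still see float32
#   CAST(float, ASSIGN(half, src=(BUFFER(half), CAST(half, val))))
# the RESHAPE branch retypes the underlying buffer too, since the ASSIGN
# target can be RESHAPE(BUFFER) rather than a bare BUFFER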
@@ -122,6 +127,8 @@ pm_early_transform_tensor_graph = PatternMatcher([
   (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
   # early fixup const copy (TODO: is this wrong if there's a pad?)
   (UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
+  # IMAGE FLOAT16: use the texture sampler to store as half and automatically cast float load/store
+  (UPat(Ops.ASSIGN, dtypes.float, src=(UPat.var("buf"), UPat(GroupOp.All-{Ops.COPY}, name="val")), name="assign"), make_float16),
 ])
 
 def untag_and_append(ctx:AllocCtx, x:UOp):
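The name= bindings in the UPat are how make_float16 receives assign, buf, and val: PatternMatcher passes each named match as a keyword argument to the rewrite function. A hypothetical standalone demo of that dispatch (the ADD pattern and show function are invented for illustration):

from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, graph_rewrite

def show(add:UOp, x:UOp, y:UOp):
  print(add.op, x.arg, y.arg)  # prints something like: Ops.ADD ('x', 0, 10) ('y', 0, 10)
  return None  # None means "no rewrite", the same early-out make_float16 uses

pm_demo = PatternMatcher([(UPat(Ops.ADD, src=(UPat.var("x"), UPat.var("y")), name="add"), show)])
graph_rewrite(UOp.variable("x", 0, 10)+UOp.variable("y", 0, 10), pm_demo, name="binding demo")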
@@ -3643,8 +3643,7 @@ class Tensor(OpMixin):
 
     # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
-    if IMAGE == 1 and FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
-    else: x, w = x.contiguous(), w.contiguous()
+    x, w = x.contiguous(), w.contiguous()
 
     if IMAGE == 1 and added_weight: w, H = w[:, :-added_weight, ...], H - added_weight
 
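With the inline FLOAT16 cast dance gone from conv2d, the tensor-level code stays dtype-agnostic and the half-storage decision is made once, in the allocation rewrite above. A sketch of the unchanged user-facing call (layer sizes invented; the image path requires QCOM/CL):

from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(IMAGE=1, FLOAT16=1):
  x = Tensor.rand(1, 8, 32, 32)   # NCHW activations
  w = Tensor.rand(16, 8, 3, 3)    # cout, cin, H, W
  y = x.conv2d(w, padding=1).realize()  # weights/activations may land in half image textures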