mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
FLOAT16 logic for IMAGE==1 goes back to image_conv2d (#15105)
This commit is contained in:
committed by
GitHub
parent
529318259c
commit
5f6b610da1
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, identity_element, track_rewrites
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, IMAGE, FLOAT16
|
||||
from tinygrad.helpers import prod, DEBUG, argsort, VIZ, pluralize, FLOAT16
|
||||
|
||||
@dataclass
|
||||
class AllocCtx:
|
||||
@@ -61,13 +61,13 @@ def replace_assign_with_contig(u:UOp):
|
||||
return u.src[1].contiguous(tag=u.tag)
|
||||
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
x = src
|
||||
while x is not src.base:
|
||||
if (x:=src).op is Ops.CAST and x.dtype == dtypes.half and FLOAT16: x, contig = x.src[0], contig.cast(dtypes.float)
|
||||
while x is not x.base:
|
||||
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
|
||||
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
|
||||
else: return None
|
||||
x = x.src[0]
|
||||
ctx[src.base] = contig
|
||||
ctx[x] = contig
|
||||
|
||||
def contiguous_mops_to_view(c:UOp):
|
||||
"""CONTIGUOUS(MOPS(BUFFER)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to a contiguous range."""
|
||||
@@ -95,17 +95,12 @@ def contiguous_mops_to_view(c:UOp):
|
||||
# NOTE: this contiguous is removed because this BUFFER_VIEW/RESHAPE has_buffer_identity
|
||||
return UOp(Ops.BUFFER_VIEW, src.dtype, (buf,), (size, offset)).reshape(src.shape).contiguous(tag=c.tag)
|
||||
|
||||
def make_float16(assign:UOp, buf:UOp, val:UOp):
|
||||
if IMAGE != 1 or not FLOAT16: return None
|
||||
new_buf = buf.replace(dtype=dtypes.half, src=(buf.src[0].replace(dtype=dtypes.half), *buf.src[1:]) if buf.op is Ops.RESHAPE else buf.src)
|
||||
return assign.replace(dtype=dtypes.half, src=(new_buf, val.cast(dtypes.half))).cast(dtypes.float)
|
||||
|
||||
pm_early_transform_tensor_graph = PatternMatcher([
|
||||
# CONTIGUOUS(MOPS(BUFFER/BUFFER_VIEW)) → CONTIGUOUS(BUFFER_VIEW) when movement ops collapse to contiguous range
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement),), name="c"), contiguous_mops_to_view),
|
||||
|
||||
# *** CONTIGUOUS replacement hack for openpilot ***
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat((*GroupOp.Movement, Ops.CAST), name="src"),), name="contig"), found_contiguous),
|
||||
# replace ALU sources with contiguous versions found above
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
|
||||
@@ -127,8 +122,6 @@ pm_early_transform_tensor_graph = PatternMatcher([
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x._shape is not None and x.size == 0 else None),
|
||||
# early fixup const copy (TODO: is this wrong if there's a pad?)
|
||||
(UPat(Ops.COPY, src=(UPat.var("s"), UPat()), name="c"), lambda c,s: c.const_like(ss.arg) if (ss:=s.base).op is Ops.CONST else None),
|
||||
# IMAGE FLOAT16: use the texture sampler to store as half and automatically cast float load/store
|
||||
(UPat(Ops.ASSIGN, dtypes.float, src=(UPat.var("buf"), UPat(GroupOp.All-{Ops.COPY}, name="val")), name="assign"), make_float16),
|
||||
])
|
||||
|
||||
def untag_and_append(ctx:AllocCtx, x:UOp):
|
||||
|
||||
@@ -3643,7 +3643,8 @@ class Tensor(OpMixin):
|
||||
|
||||
# contiguous creates the image, and early realize static weights (TODO: test for the static weight)
|
||||
if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
|
||||
x, w = x.contiguous(), w.contiguous()
|
||||
if IMAGE == 1 and FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
|
||||
else: x, w = x.contiguous(), w.contiguous()
|
||||
|
||||
if IMAGE == 1 and added_weight: w, H = w[:, :-added_weight, ...], H - added_weight
|
||||
|
||||
|
||||
Reference in New Issue
Block a user