move image dtype fixup [pr] (#7444)

* move image dtype fixup [pr]

* more work

* late dtype

* use base
This commit is contained in:
qazal
2024-10-31 13:51:46 +02:00
committed by GitHub
parent f579693ec9
commit 38b1790575
2 changed files with 10 additions and 13 deletions

View File

@@ -2,8 +2,8 @@ import sys
from collections import defaultdict, deque
from typing import Tuple, List, Dict, DefaultDict
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, resolve
from tinygrad.helpers import DEBUG, FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.device import Buffer
@@ -165,15 +165,4 @@ def get_realizes(outs:List[LazyBuffer]) -> Tuple[List[List[Buffer]], Dict[Buffer
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
lazybufs_to_realize[buf.buffer] = buf
output_groups[reduce_for_op.get(buf, buf)].append(buf.buffer)
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32
# hack the underlying buffer too
if buf.base is buf:
assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
buf.buffer.dtype = dtypes.float32
buf.buffer.options = None
return list(output_groups.values()), lazybufs_to_realize, assign_targets

View File

@@ -49,6 +49,14 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st)
return ret
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
# hack the underlying buffer too
buf.dtype = buf.buffer.dtype = buf.dtype.base
assert not buf.is_realized(), "can't fixup allocated buffer"
buf.buffer.options = None
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
# consts are always fused and generated
if buf.op is MetaOps.CONST: