diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 82a2ae11fd..2a25d6f361 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -2,8 +2,8 @@ import sys from collections import defaultdict, deque from typing import Tuple, List, Dict, DefaultDict from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, resolve -from tinygrad.helpers import DEBUG, FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts -from tinygrad.dtype import ImageDType, dtypes +from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts +from tinygrad.dtype import ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.lazy import LazyBuffer from tinygrad.device import Buffer @@ -165,15 +165,4 @@ def get_realizes(outs:List[LazyBuffer]) -> Tuple[List[List[Buffer]], Dict[Buffer raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}") lazybufs_to_realize[buf.buffer] = buf output_groups[reduce_for_op.get(buf, buf)].append(buf.buffer) - - # make things that can't be images not images - if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or - not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): - if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32") - buf.dtype = dtypes.float32 - # hack the underlying buffer too - if buf.base is buf: - assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer" - buf.buffer.dtype = dtypes.float32 - buf.buffer.options = None return list(output_groups.values()), lazybufs_to_realize, assign_targets diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ed20058d2e..05088d8d9e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -49,6 +49,14 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> if buf is not buf.base: cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st) return ret + # make things that can't be images not images + if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or + not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): + if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}") + # hack the underlying buffer too + buf.dtype = buf.buffer.dtype = buf.dtype.base + assert not buf.is_realized(), "can't fixup allocated buffer" + buf.buffer.options = None dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype # consts are always fused and generated if buf.op is MetaOps.CONST: