mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
move image dtype fixup [pr] (#7444)
* move image dtype fixup [pr]
* more work
* late dtype
* use base
This commit is contained in:
@@ -2,8 +2,8 @@ import sys
|
||||
from collections import defaultdict, deque
|
||||
from typing import Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.ops import UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, resolve
|
||||
from tinygrad.helpers import DEBUG, FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, prod, dedup, all_int, merge_dicts
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.device import Buffer
|
||||
@@ -165,15 +165,4 @@ def get_realizes(outs:List[LazyBuffer]) -> Tuple[List[List[Buffer]], Dict[Buffer
|
||||
raise RuntimeError(f"can't double realize in one schedule, Buffer is realizing both {dup} and {buf}")
|
||||
lazybufs_to_realize[buf.buffer] = buf
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(buf.buffer)
|
||||
|
||||
# make things that can't be images not images
|
||||
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
|
||||
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
|
||||
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
|
||||
buf.dtype = dtypes.float32
|
||||
# hack the underlying buffer too
|
||||
if buf.base is buf:
|
||||
assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
|
||||
buf.buffer.dtype = dtypes.float32
|
||||
buf.buffer.options = None
|
||||
return list(output_groups.values()), lazybufs_to_realize, assign_targets
|
||||
|
||||
@@ -49,6 +49,14 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st)
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
|
||||
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
|
||||
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}")
|
||||
# hack the underlying buffer too
|
||||
buf.dtype = buf.buffer.dtype = buf.dtype.base
|
||||
assert not buf.is_realized(), "can't fixup allocated buffer"
|
||||
buf.buffer.options = None
|
||||
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
|
||||
# consts are always fused and generated
|
||||
if buf.op is MetaOps.CONST:
|
||||
|
||||
Reference in New Issue
Block a user