diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 4e37205c2e..392a7dd92e 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -138,12 +138,11 @@ class ImageDType(PtrDType):
   # get list of (height, width) that do not require pitch padding
   @staticmethod
   def valid_dims(ptr:PtrDType) -> list[tuple[int,int]]:
-    ALIGN, MAXW = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384
-    if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW or (ptr.size if OSX else ptr.nbytes()) % ALIGN != 0: return []
-    if OSX and (ptr.size // 4) % ALIGN: return [] # OSX has stricter requirements for height=1 images
-    pxls: int = ptr.size // 4
-    return ([(1, pxls)] * (pxls < MAXW) + [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1)
-            if (pxls//ALIGN)%k == 0] if pxls//ALIGN else [])
+    ALIGN, MAXW, pxls = getenv("IMAGE_PITCH_ALIGN", 256 if OSX else 64), 16384, ptr.size // 4
+    if ptr.base not in (dtypes.half, dtypes.float) or ptr.size > 4*MAXW*MAXW: return []
+    # OSX has stricter requirements for height=1 images
+    if ptr.size % (ALIGN * 4) != 0: return [] if OSX or ptr.nbytes() % getenv("IMAGE_BASE_ALIGN", 64) != 0 else [(1, pxls)]
+    return [(pxls//ALIGN//k, ALIGN*k) for k in range(ceildiv(pxls//ALIGN, MAXW), min(pxls//ALIGN, MAXW//ALIGN)+1) if (pxls//ALIGN)%k == 0]
 
 class dtypes:
   @staticmethod