fix UOp ALU bound (#5844)

* fix UOp ALU bound

root cause of resnet bug, the ALU bound is only correct for scalar, not vectorized

* it can be nan...
This commit is contained in:
chenyu
2024-07-31 15:19:31 -04:00
committed by GitHub
parent 5eedd9e3ad
commit 4fe5b95568
2 changed files with 26 additions and 4 deletions

View File

@@ -419,5 +419,26 @@ class TestLinearizerFailures(unittest.TestCase):
opts = [Opt(op=OptOps.UNROLL, axis=2, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_46(self):
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(ReduceOps.SUM, arg=(1,), src=(
LazyOp(UnaryOps.NEG, arg=None, src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(UnaryOps.CAST, arg=dtypes.float, src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),))), src=()),
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),
LazyOp(UnaryOps.RECIP, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=()),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
import functools, itertools
import functools, itertools, math
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
@@ -95,9 +95,9 @@ class UOp:
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@functools.cached_property
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None else self.const(dtypes.min(cast(DType, self.dtype)))
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.const(dtypes.min(cast(DType, self.dtype)))
@functools.cached_property
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None else self.const(dtypes.max(cast(DType, self.dtype)))
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.const(dtypes.max(cast(DType, self.dtype)))
@functools.cached_property
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
# NOTE: returned UOp is assumed to be CONST
@@ -106,7 +106,7 @@ class UOp:
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU:
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
@@ -115,6 +115,7 @@ class UOp:
# handle at lease one is non-negative
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: