mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix UOp ALU bound (#5844)
* fix UOp ALU bound: root cause of the resnet bug — the ALU bound is only correct for scalar, not vectorized * it can be nan...
This commit is contained in:
@@ -419,5 +419,26 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.UNROLL, axis=2, amt=0)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
|
||||
def test_failure_46(self):
|
||||
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
|
||||
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(1,), src=(
|
||||
LazyOp(UnaryOps.NEG, arg=None, src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(UnaryOps.CAST, arg=dtypes.float, src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),))), src=()),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=4, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),
|
||||
LazyOp(UnaryOps.RECIP, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=5, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))), src=()),)),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Tuple, Any, Set, cast, List, Union, DefaultDict, Callable, Dict
|
||||
import functools, itertools
|
||||
import functools, itertools, math
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
@@ -95,9 +95,9 @@ class UOp:
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@functools.cached_property
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None else self.const(dtypes.min(cast(DType, self.dtype)))
|
||||
def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.const(dtypes.min(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None else self.const(dtypes.max(cast(DType, self.dtype)))
|
||||
def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.const(dtypes.max(cast(DType, self.dtype)))
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
@@ -106,7 +106,7 @@ class UOp:
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU:
|
||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
|
||||
return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
|
||||
@@ -115,6 +115,7 @@ class UOp:
|
||||
# handle the case where at least one is non-negative
|
||||
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
|
||||
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
|
||||
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
|
||||
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
|
||||
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
|
||||
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
|
||||
|
||||
Reference in New Issue
Block a user