From 8a508682648cb21075f0989b9da744b2b7bca8ca Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:07:00 -0800 Subject: [PATCH] touchup function.py [pr] (#8220) * touchup function.py [pr] * remove ALLOWED_READ_IMAGE * eh, keep it, just change it --- .github/workflows/test.yml | 2 +- tinygrad/function.py | 5 +++-- tinygrad/viz/serve.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4feb97df87..0ddced9cb7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -295,7 +295,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model kernel count and gate usage run: | - PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2131 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx + PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2138 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx - if: ${{ matrix.task == 'optimage' }} name: Test openpilot alt model correctness (float32) run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx diff --git a/tinygrad/function.py b/tinygrad/function.py index 42b46f60be..ca66696c95 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -41,7 +41,7 @@ class Sin(Function): class Relu(Function): def forward(self, x:UOp) -> UOp: - self.ret = x.maximum(0) + self.ret = (x>0).where(x, 0) return self.ret def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output @@ -79,7 +79,8 @@ class Sigmoid(Function): return (self.ret * (1 - self.ret)) * grad_output class Sign(Function): - def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + # NOTE: the x*0 is to match torch behavior without function.py + def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0 # backward always return 0 to match torch def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index be968cbac9..2aab64bb89 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -62,7 +62,8 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} for u in x.toposort: if u.op is Ops.CONST: continue - label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" + argst = ("\n".join([f"{v.shape} / {v.strides}"+(f" / {v.offset}" if v.offset else "") for v in u.arg.views])) if u.op is Ops.VIEW else str(u.arg) + label = f"{str(u.op).split('.')[1]}{(' '+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" for idx,x in enumerate(u.src): if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}" graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))