fix the last bug, and make HLOP the default

George Hotz
2023-02-28 17:04:28 -08:00
parent fde6c2d62b
commit 4c4d88aad4
3 changed files with 13 additions and 8 deletions

View File

@@ -4,6 +4,7 @@ import numpy as np
 import unittest
 from tinygrad.tensor import Tensor, Device
 from tinygrad.helpers import getenv
+from tinygrad.lazy import IMAGE
 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3):
@@ -131,9 +132,9 @@ class TestOps(unittest.TestCase):
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
   def test_matmul_simple(self):
-    helper_test_op([(2), (2,2)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+    helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
   def test_matmul(self):
-    helper_test_op([(65), (65,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+    helper_test_op([(64), (64,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
   def test_gemm(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
   def test_broadcastdot(self):
@@ -337,7 +338,7 @@ class TestOps(unittest.TestCase):
     cin = 2
     helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
   def test_medium_grouped_conv2d(self):
     bs = 1
@@ -346,7 +347,7 @@ class TestOps(unittest.TestCase):
     cin = 2
     helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=IMAGE>=2)
   def test_depthwise_conv2d(self):
     bs = 1
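
A note on the test changes above: they track the IMAGE=2 image-backend mode. The matmul shapes move from 2 and 65 to multiples of 4 (the image backend packs floats four to a texel, so dimensions that aren't multiples of 4 are presumably unsupported there), and the 1x1 grouped convs skip the backward check under it. A hedged sketch of what a helper_test_op-style harness does with forward_only (a hypothetical simplification, not the real helper):

import numpy as np
import torch
from tinygrad.tensor import Tensor

def check_op(shapes, torch_fxn, tiny_fxn, atol=1e-6, forward_only=False):
  # same random inputs for both frameworks
  data = [np.random.uniform(size=s).astype(np.float32) for s in shapes]
  torch_ts = [torch.tensor(d, requires_grad=True) for d in data]
  tiny_ts = [Tensor(d, requires_grad=True) for d in data]
  # the forward pass must always agree
  torch_out, tiny_out = torch_fxn(*torch_ts), tiny_fxn(*tiny_ts)
  np.testing.assert_allclose(torch_out.detach().numpy(), tiny_out.numpy(), atol=atol)
  # the backward check is skipped when the caller passes forward_only (e.g. IMAGE>=2)
  if not forward_only:
    torch_out.mean().backward()
    tiny_out.mean().backward()
    for t, tt in zip(torch_ts, tiny_ts):
      np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), atol=1e-4)

The real helper also takes rtol, gradient tolerances, and optional fixed vals; only the forward/backward split matters for this diff.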

View File

@@ -80,9 +80,13 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
 def merge_views(vm2:View, vm1:View) -> Optional[View]:
   new_strides = []
+  new_offset = vm2.expr_node(Variable.num(vm1.offset))
+  assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
   for s,st in zip(vm1.shape, vm1.strides):
-    this_dim = vm2.expr_node(Variable('idx', 0, s-1)*st)
-    if isinstance(this_dim, NumNode) and this_dim.b == 0:
+    this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
+    if s == 1:
+      new_strides.append(0) # all shape 1 can have stride 0
+    elif isinstance(this_dim, NumNode) and this_dim.b == 0:
       new_strides.append(0)
     elif isinstance(this_dim, Variable):
       new_strides.append(1)
@@ -91,7 +95,7 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
     else:
       if DEBUG >= 3: print("can't simplify", s, this_dim.render())
       break
-  return View(vm1.shape, tuple(new_strides), vm2.offset + vm1.offset) if len(new_strides) == len(vm1.strides) else None
+  return View(vm1.shape, tuple(new_strides), new_offset.b) if len(new_strides) == len(vm1.strides) else None
 class ShapeTracker:
   __slots__ = ('views')
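
The "last bug" of the commit message is here: the old code returned vm2.offset + vm1.offset, but vm1's offset has to be pushed through vm2's index expression rather than added raw. The fix computes new_offset by sending vm1.offset through vm2 once, then probes each dimension against an offset-free View(vm2.shape, vm2.strides) so the offset isn't counted twice, and lets every size-1 dimension take stride 0. A brute-force numeric sketch of the property being decided (hypothetical SimpleView/composed names; the real code reasons symbolically with Variable and NumNode instead of enumerating indices):

from itertools import product
from typing import Optional, Tuple

class SimpleView:
  def __init__(self, shape:Tuple[int, ...], strides:Tuple[int, ...], offset:int=0):
    self.shape, self.strides, self.offset = shape, strides, offset
  def expr(self, idxs:Tuple[int, ...]) -> int:
    # affine map from a multi-index to a flat buffer position
    return self.offset + sum(i*st for i,st in zip(idxs, self.strides))
  def unflatten(self, flat:int) -> Tuple[int, ...]:
    # invert C-order flattening over self.shape
    out = []
    for s in reversed(self.shape):
      flat, i = divmod(flat, s)
      out.append(i)
    return tuple(reversed(out))

def composed(vm2:SimpleView, vm1:SimpleView, idxs:Tuple[int, ...]) -> int:
  # vm1 maps the user's index to a flat position, which vm2 then
  # reinterprets as a multi-index into its own shape
  return vm2.expr(vm2.unflatten(vm1.expr(idxs)))

def merge_views(vm2:SimpleView, vm1:SimpleView) -> Optional[SimpleView]:
  nd = len(vm1.shape)
  new_offset = composed(vm2, vm1, (0,)*nd)  # vm1.offset pushed through vm2, once
  new_strides = []
  for d in range(nd):
    if vm1.shape[d] == 1:
      new_strides.append(0)                 # all shape 1 can have stride 0
    else:                                   # probe a unit step along axis d
      step = tuple(1 if dd == d else 0 for dd in range(nd))
      new_strides.append(composed(vm2, vm1, step) - new_offset)
  cand = SimpleView(vm1.shape, tuple(new_strides), new_offset)
  # merging is only sound if the affine candidate agrees everywhere
  for idxs in product(*[range(s) for s in vm1.shape]):
    if cand.expr(idxs) != composed(vm2, vm1, idxs): return None
  return cand

With a contiguous vm2 the composition is the identity on flat positions and any vm1 merges unchanged; stack a flatten on a transposed vm2 and the affine check fails, returning None where the symbolic version hits its "can't simplify" break.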

View File

@@ -6,7 +6,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer
-HLOP = getenv("HLOP", 0)
+HLOP = getenv("HLOP", 1)
 from tinygrad.image import image_conv2d_decorator
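
And the change the title promises: HLOP, which appears to select the high-level-op path where ops like matmul and conv2d are built from movement ops plus elementwise mul and sum, now defaults to on; HLOP=0 still opts out. A minimal sketch of the getenv pattern these flags (HLOP, IMAGE, FORWARD_ONLY, DEBUG) share, assuming the helper casts to the default's type:

import os

def getenv(key:str, default=0):
  # read an env var, fall back to the default, keep the default's type
  return type(default)(os.getenv(key, default))

HLOP = getenv("HLOP", 1)  # on unless the user sets HLOP=0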