From ef1100fdff8c1d5d78266ef35d1808c31af862a0 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 19 Jul 2022 09:30:06 -0700 Subject: [PATCH] touchups --- .github/workflows/test.yml | 2 +- tinygrad/helpers.py | 8 ++------ tinygrad/llops/ops_cpu.py | 2 +- tinygrad/nn.py | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 791f903008..451f003d24 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -97,7 +97,7 @@ jobs: - name: Install Dependencies run: pip install -e '.[gpu,testing]' - name: Run Pytest (default) - run: OPT=2 GPU=1 python -m pytest -s -v + run: OPT=1 GPU=1 python -m pytest -s -v testopencl: name: OpenCL Tests diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index e786f9c63b..9ad4c1077a 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -2,12 +2,8 @@ from collections import namedtuple import os, math def prod(x): return math.prod(x) - -# https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python -def argsort(x): return sorted(range(len(x)), key=x.__getitem__) - -def reduce_shape(shape, axis): - return tuple(1 if i in axis else shape[i] for i in range(len(shape))) +def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python +def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape))) ConvArgs = namedtuple('ConvArgs', ['H', 'W', 'groups', 'rcout', 'cin', 'oy', 'ox', 'iy', 'ix', 'sy', 'sx', 'bs', 'cout', 'py', 'py_', 'px', 'px_', 'dy', 'dx', 'out_shape']) def get_conv_args(x_shape, w_shape, stride=1, groups=1, padding=0, dilation=1, out_shape=None): diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index f55fd00827..f5695aa28b 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -42,7 +42,7 @@ class CPUBuffer(np.ndarray): elif op == MovementOps.PERMUTE: return x.permute(arg) elif op == MovementOps.FLIP: return x.flip(arg) elif op == MovementOps.PAD: return x.custompad(arg) - elif op == MovementOps.SHRINK: return x[tuple(slice(p[0], p[1], None) for i,p in enumerate(arg))] + elif op == MovementOps.SHRINK: return x[tuple(slice(p[0], p[1], None) for p in arg)] elif op == MovementOps.EXPAND: return x.expand(arg) elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg]) diff --git a/tinygrad/nn.py b/tinygrad/nn.py index ebaaf2479e..9bae6f060e 100644 --- a/tinygrad/nn.py +++ b/tinygrad/nn.py @@ -33,7 +33,7 @@ class BatchNorm2D: return batch_normalize(x, self.weight, self.bias, batch_mean, batch_var, self.eps) - return batch_normalize(x, self.weight, self.bias, self.running_mean, self.running_var,self.eps) + return batch_normalize(x, self.weight, self.bias, self.running_mean, self.running_var, self.eps) class Conv2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):