From d1284fa817a9c875388bd58315bb0cea3ec56a4d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 10 Nov 2020 16:10:14 -0800 Subject: [PATCH] stride tests and i32 --- test/test_ops.py | 4 ++-- tinygrad/opsgpu.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 0773c06436..976b2d9e8f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -83,11 +83,11 @@ class TestOps(unittest.TestCase): with self.subTest(stride := 2): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), gpu=self.gpu, forward_only=self.gpu) + lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), gpu=self.gpu) with self.subTest(stride := (2,1)): helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(), - lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu, forward_only=self.gpu) + lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu) def test_maxpool2d(self): for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: diff --git a/tinygrad/opsgpu.py b/tinygrad/opsgpu.py index d50a0c15c2..2b2cabb0ab 100644 --- a/tinygrad/opsgpu.py +++ b/tinygrad/opsgpu.py @@ -537,9 +537,8 @@ class Conv2D(Function): } """) - prg.convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x, grad_output, dw, np.int32(H), np.int32(W), np.int32(ctx.groups), - np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs)) - prg.convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w, grad_output, dx, np.int32(H), np.int32(W), np.int32(ctx.groups), - np.int32(rcout), np.int32(cin), np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix), np.int32(ys), np.int32(xs), np.int32(bs)) + conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs) + 
prg.convw(ctx.cl_queue, [ctx.groups*rcout*cin, H, W], None, x, grad_output, dw, *conv_args) + prg.convx(ctx.cl_queue, [bs, ctx.groups, cin], None, w, grad_output, dx, *conv_args) return dx, dw register('conv2d', Conv2D, gpu=True)