contiguous, and no strided for matmul

George Hotz
2022-11-09 16:56:26 -08:00
parent 1271f19a2b
commit bff47e9dc1
2 changed files with 14 additions and 5 deletions


@@ -235,11 +235,16 @@ class LazyBuffer:
 if NOCONV or not getattr(x.dbuffer, "processing_op", False):
   # universal conv, just mul and reduce
-  # TODO: is there any way to replace strided with other movement ops?
-  x = x.movement_op(MovementOps.STRIDED, (
-    (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
-    (1, 1), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
-    (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
+  # TODO: is there any way to replace strided with other movement ops? answer: not really
+  if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1:
+    # TODO: this doesn't belong here, ShapeTracker or lazy should be able to infer this from STRIDED
+    x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.oy, C.ox, 1, C.H, C.W))
+    x = x.movement_op(MovementOps.PERMUTE, (0,1,5,3,4,2,6,7))
+  else:
+    x = x.movement_op(MovementOps.STRIDED, (
+      (C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
+      (1, 1), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
+      (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
   #if C.H <= 3 and C.W <= 3: # max 9x the RAM overhead, this is im2col
   #  x = x.contiguous_op()
   x = x.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
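
With a 1x1 kernel and unit stride (the matmul case), the new RESHAPE + PERMUTE pair builds the same logical view as the removed STRIDED call, without committing the buffer to an explicit stride layout. A quick NumPy check of that equivalence (a sketch with made-up sizes: bs, groups, cin, iy, ix stand in for the C.* conv args and x.shape[2:], and as_strided plays the role of MovementOps.STRIDED):

import numpy as np
from numpy.lib.stride_tricks import as_strided

bs, groups, cin, iy, ix = 2, 3, 4, 5, 6
H = W = sy = sx = dy = dx = 1
oy, ox = iy, ix  # 1x1 kernel, stride 1: output spatial dims equal input dims
x = np.arange(bs * groups * cin * iy * ix, dtype=np.float32).reshape(bs, groups * cin, iy, ix)

e = x.itemsize
# the view built by the removed STRIDED branch: shape (bs, groups, 1, oy, ox, cin, H, W),
# strides copied from the (size, stride) pairs in the diff, converted to bytes
strided = as_strided(x, shape=(bs, groups, 1, oy, ox, cin, H, W),
                     strides=(groups * cin * iy * ix * e, cin * iy * ix * e, e,
                              sy * ix * e, sx * e, iy * ix * e, dy * ix * e, dx * e))
# the view built by the added RESHAPE + PERMUTE branch
reshaped = x.reshape(bs, groups, cin, oy, ox, 1, H, W).transpose(0, 1, 5, 3, 4, 2, 6, 7)
assert np.array_equal(strided, reshaped)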


@@ -2,6 +2,10 @@ from tinygrad.helpers import prod, argsort, reduce_shape, get_conv_args
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 from tinygrad.tensor import Function
 
+class Contiguous(Function):
+  def forward(self, x): return x.contiguous_op()
+  def backward(self, grad_output): return grad_output
+
 # ************* unary ops *************
 
 class ReLU(Function):
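
The identity backward is the point of Contiguous: contiguous_op copies the data into a fresh contiguous buffer but leaves every value unchanged, so the gradient of the output with respect to the input is the identity and grad_output passes through as-is. A minimal analogue of the same pattern as a PyTorch custom autograd Function (illustration only, not tinygrad code; ContiguousDemo is a made-up name):

import torch

class ContiguousDemo(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    return x.contiguous()  # copy into a contiguous buffer; values unchanged
  @staticmethod
  def backward(ctx, grad_output):
    return grad_output  # a layout copy is the identity on values, so the gradient is too

x = torch.randn(3, 4).t().requires_grad_()  # a transposed, non-contiguous view
y = ContiguousDemo.apply(x)
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # gradient flowed through unchanged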