mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
contiguous, and no strided for matmul
This commit is contained in:
@@ -235,11 +235,16 @@ class LazyBuffer:
|
||||
|
||||
if NOCONV or not getattr(x.dbuffer, "processing_op", False):
|
||||
# universal conv, just mul and reduce
|
||||
# TODO: is there any way to replace strided with other movement ops?
|
||||
x = x.movement_op(MovementOps.STRIDED, (
|
||||
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
|
||||
(1, 1), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
|
||||
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
|
||||
# TODO: is there any way to replace strided with other movement ops? answer: not really
|
||||
if C.sy == 1 and C.sx == 1 and C.H == 1 and C.W == 1:
|
||||
# TODO: this doesn't belong here, ShapeTracker or lazy should be able to infer this from STRIDED
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.oy, C.ox, 1, C.H, C.W))
|
||||
x = x.movement_op(MovementOps.PERMUTE, (0,1,5,3,4,2,6,7))
|
||||
else:
|
||||
x = x.movement_op(MovementOps.STRIDED, (
|
||||
(C.bs, C.groups*C.cin*x.shape[2]*x.shape[3]), (C.groups, C.cin*x.shape[2]*x.shape[3]),
|
||||
(1, 1), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
|
||||
(C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
|
||||
#if C.H <= 3 and C.W <= 3: # max 9x the RAM overhead, this is im2col
|
||||
# x = x.contiguous_op()
|
||||
x = x.movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
|
||||
|
||||
@@ -2,6 +2,10 @@ from tinygrad.helpers import prod, argsort, reduce_shape, get_conv_args
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
|
||||
from tinygrad.tensor import Function
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x): return x.contiguous_op()
|
||||
def backward(self, grad_output): return grad_output
|
||||
|
||||
# ************* unary ops *************
|
||||
|
||||
class ReLU(Function):
|
||||
|
||||
Reference in New Issue
Block a user