mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
EXPAND -> REPEAT
@@ -134,7 +134,7 @@ Buffer # class of memory on this d
 unary_op (RELU, EXP, LOG, NEG, SIGN)                # A -> A
 reduce_op (SUM, MAX)                                # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ)          # A + B -> C (all the same size)
-movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
+movement_op (RESHAPE, PERMUTE, SLICE, REPEAT, FLIP) # A -> B (different size)
 processing_op (CONV)                                # A + B -> C
 ```

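For context on the rename: the op formerly called EXPAND (now REPEAT) broadcasts size-1 axes of A out to a larger shape B without copying data. A minimal sketch of that semantics, assuming the numpy analogy holds:

```python
import numpy as np

# Illustration only: REPEAT (formerly EXPAND) broadcasts size-1 axes to a
# larger shape. numpy's broadcast_to does the same thing, returning a
# stride-0 view instead of copying the data.
a = np.arange(3).reshape(1, 3)
b = np.broadcast_to(a, (4, 3))  # (1, 3) -> (4, 3)
assert b.strides[0] == 0        # the repeated axis has stride 0, no copy
```
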
@@ -44,7 +44,7 @@ class CPUBuffer(np.ndarray):
   elif op == MovementOps.SLICE:
     padding = [(max(0, -p[0]), max(0, p[1]-x.shape[i])) for i,p in enumerate(arg)]
     return x.custompad(padding)[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0], None) for i,p in enumerate(arg))]
-  elif op == MovementOps.EXPAND: return x.expand(arg)
+  elif op == MovementOps.REPEAT: return x.expand(arg)
   elif op == MovementOps.STRIDED: return x.contiguous().as_strided([x[0] for x in arg], [x[1] for x in arg])

 def processing_op(x,op,w,C):

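The SLICE branch above takes (start, end) pairs that may fall outside the buffer: it zero-pads exactly as far as the slice overhangs, then takes an ordinary slice of the padded array. A standalone numpy sketch of that logic (`slice_like` is a hypothetical name; the `arg` format mirrors the diff):

```python
import numpy as np

# Hypothetical helper mirroring the SLICE branch above: arg is a (start, end)
# pair per axis, possibly out of range; zero-pad exactly as far as the slice
# overhangs, then index into the padded array with shifted bounds.
def slice_like(x, arg):
    padding = [(max(0, -p[0]), max(0, p[1] - x.shape[i])) for i, p in enumerate(arg)]
    padded = np.pad(x, padding)
    return padded[tuple(slice(p[0] + padding[i][0], p[1] + padding[i][0]) for i, p in enumerate(arg))]

x = np.arange(9).reshape(3, 3)
print(slice_like(x, [(-1, 2), (0, 3)]))  # a zero row on top, then rows 0 and 1 of x
```
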
@@ -152,7 +152,7 @@ class GPUBuffer:
     # reverse operation of expand, this validates inputs
     # generate loops with combined adjacent reduce axis
     acc = 1
-    for shp,stride in ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape).views[-1].shape_strides[::-1]:
+    for shp,stride in ShapeTracker(ret.shape).movement_op(MovementOps.REPEAT, bufs[0][1].shape).views[-1].shape_strides[::-1]:
       if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
       acc *= shp

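The loop above REPEATs the output shape back to the input shape and walks the resulting (shape, stride) view: axes with stride 0 are exactly the reduced axes, and each one becomes a C for-loop stepping the input index by the product of the sizes already consumed. A sketch of that generation step in isolation (`gen_reduce_loops` is a hypothetical name; the pairs are as in the diff):

```python
# Sketch of the loop generation above. Axes whose stride is 0 in the
# output-REPEATed-to-input view are the reduce axes; each becomes a C
# for-loop stepping the input index `idx` by `acc`, the running product
# of the axis sizes processed so far (innermost first).
def gen_reduce_loops(shape_strides):
    loop, acc = [], 1
    for shp, stride in shape_strides[::-1]:
        if stride == 0:
            i = len(loop)
            loop.append((f"for (int axis_{i} = 0; axis_{i} < {shp}; axis_{i}++) {{",
                         f"idx += {acc}; }} idx -= {shp*acc};"))
        acc *= shp
    return loop

# e.g. a (4, 3) input reduced to (1, 3): the repeated view is [(4, 0), (3, 1)]
print(gen_reduce_loops([(4, 0), (3, 1)]))
```
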
@@ -39,7 +39,7 @@ class Sum(Function):
     return input.reduce_op(ReduceOps.SUM, reduce_shape(input.shape, axis))

   def backward(ctx, grad_output):
-    return grad_output.movement_op(MovementOps.EXPAND, ctx.input_shape)
+    return grad_output.movement_op(MovementOps.REPEAT, ctx.input_shape)

 class Max(Function):
   def forward(ctx, input, axis=None):

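The Sum.backward change is mechanical, but the rule behind it is worth a line: every input element contributes to the sum with weight 1, so the backward pass is just the upstream gradient broadcast back over the reduced axes. In numpy terms (illustrative, not tinygrad API):

```python
import numpy as np

# Gradient of a reduce-SUM: d(sum x)/dx_j = 1 for every j, so the upstream
# gradient is simply broadcast back to the input's shape.
grad_output = np.array([[2.0, 3.0, 4.0]])          # grad of a (4, 3) -> (1, 3) sum
grad_input = np.broadcast_to(grad_output, (4, 3))  # each summand gets it once
```
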
@@ -51,14 +51,14 @@ class Max(Function):
     input, ret = ctx.saved_tensors

     # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = input.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.EXPAND, input.shape))
+    max_is_1s = input.binary_op(BinaryOps.CMPEQ, ret.movement_op(MovementOps.REPEAT, input.shape))

     # sum of locations, averaged
     div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape)
-    div = div.movement_op(MovementOps.EXPAND, input.shape)
+    div = div.movement_op(MovementOps.REPEAT, input.shape)
     max_is_amount = max_is_1s.binary_op(BinaryOps.DIV, div)

-    grad_output_expanded = grad_output.movement_op(MovementOps.EXPAND, input.shape)
+    grad_output_expanded = grad_output.movement_op(MovementOps.REPEAT, input.shape)
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)

 # ************* binary ops *************

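Max.backward above does three things: build a 0/1 mask where the max was attained (CMPEQ against the REPEATed result), count ties per reduced group, and split the upstream gradient evenly among the tied positions. A numpy rendering of the same arithmetic, purely illustrative:

```python
import numpy as np

# Illustration of Max's backward: a 0/1 mask at the max locations, with the
# gradient divided evenly among ties along the reduced axis.
x = np.array([[1.0, 5.0, 5.0]])
m = x.max(axis=1, keepdims=True)       # forward result, shape (1, 1)
mask = (x == m).astype(x.dtype)        # 1s where the max was chosen (two here)
div = mask.sum(axis=1, keepdims=True)  # number of ties per group
grad_output = np.array([[6.0]])
grad_input = (mask / div) * grad_output
print(grad_input)                      # [[0. 3. 3.]]
```
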
@@ -114,7 +114,7 @@ class Pow(Function):
 class Expand(Function):
   def forward(ctx, x, shape):
     ctx.input_shape = x.shape
-    return x.movement_op(MovementOps.EXPAND, shape)
+    return x.movement_op(MovementOps.REPEAT, shape)

   def backward(ctx, grad_output):
     return grad_output.reduce_op(ReduceOps.SUM, ctx.input_shape)

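REPEAT forward with SUM backward is the standard adjoint pairing: broadcasting and summation are transposes of each other, so ⟨repeat(x), y⟩ = ⟨x, sum(y)⟩ for any y of the expanded shape. A quick numerical check of that identity (illustrative):

```python
import numpy as np

# Why Expand's backward is a reduce-SUM: broadcast and sum are adjoint.
x = np.random.randn(1, 3)
y = np.random.randn(4, 3)
lhs = (np.broadcast_to(x, y.shape) * y).sum()   # <repeat(x), y>
rhs = (x * y.sum(axis=0, keepdims=True)).sum()  # <x, sum(y)>
assert np.isclose(lhs, rhs)
```
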
@@ -13,7 +13,7 @@ sys.setrecursionlimit(10000)
 UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
-MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP", "STRIDED"])
+MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "REPEAT", "FLIP", "STRIDED"])
 ProcessingOps = Enum("ProcessingOps", ["CONV"])
 LoadOps = Enum("LoadOps", ["FROMCPU"])

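These use Python's functional Enum API: members get values 1, 2, 3, ... in declaration order, so the rename changes only the member's name, not its slot:

```python
from enum import Enum

# Same declaration style as the diff; REPEAT takes over EXPAND's position.
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "REPEAT", "FLIP", "STRIDED"])
assert MovementOps.REPEAT.value == 4
assert MovementOps.REPEAT.name == "REPEAT"
```
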
@@ -266,7 +266,7 @@ class LazyBuffer:
                  (C.rcout, 0), (C.oy, C.sy*x.shape[3]), (C.ox, C.sx),
                  (C.cin, x.shape[2]*x.shape[3]), (C.H, C.dy*x.shape[3]), (C.W, C.dx)))
   w = w.movement_op(MovementOps.RESHAPE, (1, C.groups, C.rcout, 1, 1, C.cin, C.H, C.W)) \
-       .movement_op(MovementOps.EXPAND, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
+       .movement_op(MovementOps.REPEAT, (C.bs, C.groups, C.rcout, C.oy, C.ox, C.cin, C.H, C.W))
   #print(x.st.views, w.st.views)
   return x.binary_op(BinaryOps.MUL, w).reduce_op(ReduceOps.SUM, (C.bs, C.groups, C.rcout, C.oy, C.ox, 1, 1, 1)) \
           .movement_op(MovementOps.RESHAPE, (C.bs, C.cout, C.oy, C.ox))

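The hunk above is tinygrad's conv-as-movement-ops trick: x is viewed as strided windows, w is RESHAPEd then REPEATed to the same 8-axis shape, and a MUL plus a SUM over the (cin, H, W) axes finishes the convolution. A scalar-channel numpy sketch of the same idea, assuming bs = groups = cin = 1, stride 1, no dilation:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Illustrative only: convolution as broadcast-multiply-then-reduce.
x = np.random.randn(6, 8)                  # single-channel input
w = np.random.randn(3, 3)                  # one 3x3 filter
windows = sliding_window_view(x, w.shape)  # (oy, ox, H, W) strided view of x
out = (windows * w).sum(axis=(-2, -1))     # broadcast w, reduce over (H, W)
assert out.shape == (4, 6)                 # oy = 6-3+1, ox = 8-3+1
```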