mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add sub and a few refactors
This commit is contained in:
@@ -50,9 +50,30 @@ class MetalBuffer:
|
||||
relu_shader = MetalPerformanceShaders.MPSImageThresholdToZero.alloc().initWithDevice_thresholdValue_linearGrayColorTransform_(device, 0, None)
|
||||
inv_relu_shader = MetalPerformanceShaders.MPSImageThresholdBinary.alloc().initWithDevice_thresholdValue_maximumValue_linearGrayColorTransform_(device, 0, 1, None)
|
||||
add_shader = MetalPerformanceShaders.MPSImageAdd.alloc().initWithDevice_(device)
|
||||
sub_shader = MetalPerformanceShaders.MPSImageSubtract.alloc().initWithDevice_(device)
|
||||
mul_shader = MetalPerformanceShaders.MPSImageMultiply.alloc().initWithDevice_(device)
|
||||
sum_shader = MetalPerformanceShaders.MPSImageReduceRowSum.alloc().initWithDevice_(device)
|
||||
|
||||
def unary_op(shader, input):
|
||||
out = MetalBuffer(input.shape, None)
|
||||
mtl_buffer = cmd_buffer()
|
||||
shader.encodeToCommandBuffer_sourceTexture_destinationTexture_(
|
||||
mtl_buffer, input.texture, out.texture
|
||||
)
|
||||
mtl_buffer.commit()
|
||||
return out
|
||||
|
||||
def binary_op(shader, x, y):
|
||||
ret = MetalBuffer(x.shape, None)
|
||||
mtl_buffer = cmd_buffer()
|
||||
shader.setPrimaryEdgeMode_(MetalPerformanceShaders.MPSImageEdgeModeClamp)
|
||||
shader.setSecondaryEdgeMode_(MetalPerformanceShaders.MPSImageEdgeModeClamp)
|
||||
shader.encodeToCommandBuffer_primaryTexture_secondaryTexture_destinationTexture_(
|
||||
mtl_buffer, x.texture, y.texture, ret.texture
|
||||
)
|
||||
mtl_buffer.commit()
|
||||
return ret
|
||||
|
||||
class Sum(Function):
|
||||
def forward(ctx, input, axis=None):
|
||||
assert axis is None
|
||||
@@ -68,54 +89,46 @@ class Sum(Function):
|
||||
def backward(ctx, grad_output):
|
||||
shape, axis = ctx.saved_tensors
|
||||
out = MetalBuffer(shape, None)
|
||||
ret = MetalBuffer(shape, None)
|
||||
mtl_buffer = cmd_buffer()
|
||||
add_shader.setPrimaryEdgeMode_(MetalPerformanceShaders.MPSImageEdgeModeClamp)
|
||||
add_shader.setSecondaryEdgeMode_(MetalPerformanceShaders.MPSImageEdgeModeClamp)
|
||||
add_shader.encodeToCommandBuffer_primaryTexture_secondaryTexture_destinationTexture_(
|
||||
mtl_buffer, out.texture, grad_output.texture, ret.texture
|
||||
)
|
||||
mtl_buffer.commit()
|
||||
return ret
|
||||
return binary_op(add_shader, out, grad_output)
|
||||
|
||||
class ReLU(Function):
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
out = MetalBuffer(input.shape, None)
|
||||
mtl_buffer = cmd_buffer()
|
||||
relu_shader.encodeToCommandBuffer_sourceTexture_destinationTexture_(
|
||||
mtl_buffer, input.texture, out.texture
|
||||
)
|
||||
mtl_buffer.commit()
|
||||
return out
|
||||
return unary_op(relu_shader, input)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
out = MetalBuffer(input.shape, None)
|
||||
mtl_buffer = mtl_queue.commandBuffer()
|
||||
inv_relu_shader.encodeToCommandBuffer_sourceTexture_destinationTexture_(
|
||||
mtl_buffer, input.texture, out.texture
|
||||
)
|
||||
# TODO: make in place work
|
||||
#mul_shader.encodeToCommandBuffer_inPlacePrimaryTexture_secondaryTexture_fallbackCopyAllocator_(
|
||||
# mtl_buffers, out.texture, grad_output.texture, None)
|
||||
ret = MetalBuffer(input.shape, None)
|
||||
mul_shader.encodeToCommandBuffer_primaryTexture_secondaryTexture_destinationTexture_(
|
||||
mtl_buffer, grad_output.texture, out.texture, ret.texture
|
||||
)
|
||||
mtl_buffer.commit()
|
||||
return ret
|
||||
return binary_op(mul_shader, unary_op(inv_relu_shader, input), grad_output)
|
||||
|
||||
|
||||
"""
|
||||
class Add(Function):
|
||||
def forward(ctx, x, y):
|
||||
#add_shader.
|
||||
pass
|
||||
ctx.save_for_backward(x, y)
|
||||
return binary_op(add_shader, x, y)
|
||||
|
||||
#ctx.save_for_backward(x.shape, y.shape)
|
||||
#return binary_op(ctx, 'a+b', x, y)
|
||||
"""
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
return grad_output, grad_output
|
||||
|
||||
class Sub(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return binary_op(sub_shader, x, y)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
out = MetalBuffer(y.shape, None)
|
||||
return grad_output, binary_op(sub_shader, out, grad_output)
|
||||
|
||||
class Mul(Function):
|
||||
def forward(ctx, x, y):
|
||||
ctx.save_for_backward(x, y)
|
||||
return binary_op(mul_shader, x, y)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
x,y = ctx.saved_tensors
|
||||
grad_x = binary_op(mul_shader, y, grad_output)
|
||||
grad_y = binary_op(mul_shader, x, grad_output)
|
||||
return grad_x, grad_y
|
||||
|
||||
if __name__ == "__main__":
|
||||
b1 = MetalBuffer(10, np.ones(10))
|
||||
|
||||
Reference in New Issue
Block a user