Further nn/part2 progress
@@ -85,10 +85,14 @@ class AmbientMovement(ContinualAnimation):
 class ContinualUpdateFromFunc(ContinualAnimation):
     def __init__(self, mobject, func, **kwargs):
         self.func = func
+        self.func_arg_count = func.func_code.co_argcount
+        if self.func_arg_count > 2:
+            raise Exception("ContinualUpdateFromFunc function must take 1 or 2 args")
         ContinualAnimation.__init__(self, mobject, **kwargs)
 
     def update_mobject(self, dt):
-        self.func(self.mobject)
+        args = (self.mobject, dt)
+        self.func(*args[:self.func_arg_count])
 
 class ContinualMaintainPositionRelativeTo(ContinualAnimation):
     def __init__(self, mobject, tracked_mobject, **kwargs):
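Note on the hunk above: the updater now dispatches on the wrapped function's arity. A minimal usage sketch (the updater names and mobject are illustrative, not from this commit; `func_code` is the Python 2 spelling of `__code__`):

    # One-argument updaters are called as func(mobject); two-argument
    # updaters as func(mobject, dt). Anything beyond two raises immediately.
    def pulse(mob):
        mob.highlight(YELLOW)      # called each frame as pulse(mob)

    def drift(mob, dt):
        mob.shift(dt * RIGHT)      # called each frame as drift(mob, dt)

    # ContinualUpdateFromFunc(some_mobject, pulse)   # func_arg_count == 1
    # ContinualUpdateFromFunc(some_mobject, drift)   # func_arg_count == 2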
@@ -22,6 +22,7 @@ from nn.mnist_loader import load_data_wrapper
 
 NN_DIRECTORY = os.path.dirname(os.path.realpath(__file__))
 # PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_36")
+# PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases_ReLU")
 PRETRAINED_DATA_FILE = os.path.join(NN_DIRECTORY, "pretrained_weights_and_biases")
 IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
 # PRETRAINED_DATA_FILE = "/Users/grant/cs/manim/nn/pretrained_weights_and_biases_on_zero"
@@ -29,7 +30,7 @@ IMAGE_MAP_DATA_FILE = os.path.join(NN_DIRECTORY, "image_map")
 DEFAULT_LAYER_SIZES = [28**2, 16, 16, 10]
 
 class Network(object):
-    def __init__(self, sizes):
+    def __init__(self, sizes, non_linearity = "sigmoid"):
         """The list ``sizes`` contains the number of neurons in the
         respective layers of the network. For example, if the list
         was [2, 3, 1] then it would be a three-layer network, with the
@@ -45,11 +46,19 @@ class Network(object):
         self.biases = [np.random.randn(y, 1) for y in sizes[1:]]
         self.weights = [np.random.randn(y, x)
                         for x, y in zip(sizes[:-1], sizes[1:])]
+        if non_linearity == "sigmoid":
+            self.non_linearity = sigmoid
+            self.d_non_linearity = sigmoid_prime
+        elif non_linearity == "ReLU":
+            self.non_linearity = ReLU
+            self.d_non_linearity = ReLU_prime
+        else:
+            raise Exception("Invalid non_linearity")
 
     def feedforward(self, a):
         """Return the output of the network if ``a`` is input."""
         for b, w in zip(self.biases, self.weights):
-            a = sigmoid(np.dot(w, a)+b)
+            a = self.non_linearity(np.dot(w, a)+b)
         return a
 
     def get_activation_of_all_layers(self, input_a, n_layers = None):
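The new constructor branch assumes `sigmoid`, `sigmoid_prime`, `ReLU`, and `ReLU_prime` are module-level functions: the ReLU pair is added later in this diff, and the sigmoid pair already exists in the file (the `sigmoid_inverse` hunk below sits next to it). For reference, the standard definitions the code relies on are:

    # Logistic sigmoid and its derivative: sigma'(z) = sigma(z) * (1 - sigma(z)).
    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def sigmoid_prime(z):
        return sigmoid(z) * (1 - sigmoid(z))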
@@ -58,7 +67,7 @@ class Network(object):
         activations = [input_a.reshape((input_a.size, 1))]
         for bias, weight in zip(self.biases, self.weights)[:n_layers]:
             last_a = activations[-1]
-            new_a = sigmoid(np.dot(weight, last_a) + bias)
+            new_a = self.non_linearity(np.dot(weight, last_a) + bias)
             new_a = new_a.reshape((new_a.size, 1))
             activations.append(new_a)
         return activations
@@ -118,11 +127,11 @@ class Network(object):
         for b, w in zip(self.biases, self.weights):
             z = np.dot(w, activation)+b
             zs.append(z)
-            activation = sigmoid(z)
+            activation = self.non_linearity(z)
             activations.append(activation)
         # backward pass
         delta = self.cost_derivative(activations[-1], y) * \
-            sigmoid_prime(zs[-1])
+            self.d_non_linearity(zs[-1])
         nabla_b[-1] = delta
         nabla_w[-1] = np.dot(delta, activations[-2].transpose())
         # Note that the variable l in the loop below is used a little
@@ -133,7 +142,7 @@ class Network(object):
         # that Python can use negative indices in lists.
         for l in xrange(2, self.num_layers):
             z = zs[-l]
-            sp = sigmoid_prime(z)
+            sp = self.d_non_linearity(z)
             delta = np.dot(self.weights[-l+1].transpose(), delta) * sp
             nabla_b[-l] = delta
             nabla_w[-l] = np.dot(delta, activations[-l-1].transpose())
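Together, the last two hunks generalize backprop's error recurrence from sigma' to whatever derivative `self.d_non_linearity` holds. In numpy terms (a notation sketch, with g_prime the chosen derivative):

    # output layer:   delta = cost_derivative(a_L, y) * g_prime(z_L)
    # hidden layer l: delta = np.dot(W[l+1].T, delta) * g_prime(z_l)
    # gradients:      nabla_b[l] = delta; nabla_w[l] = np.dot(delta, a[l-1].T)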
@@ -170,6 +179,14 @@ def sigmoid_inverse(z):
         1.0, (np.true_divide(1.0, z) - 1)
     ))
 
+def ReLU(z):
+    result = np.array(z)
+    result[result < 0] = 0
+    return result
+
+def ReLU_prime(z):
+    return (np.array(z) > 0).astype('int')
+
 def get_pretrained_network():
     data_file = open(PRETRAINED_DATA_FILE)
     weights, biases = cPickle.load(data_file)
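Two details of the new helpers worth noting: `np.array(z)` makes a copy, so `ReLU` never mutates its argument, and `ReLU_prime` takes the subgradient at zero to be 0. A small behavior sketch (example values, not from the commit):

    z = np.array([-2.0, 0.0, 3.0])
    ReLU(z)         # -> array([0., 0., 3.]); z itself is unchanged
    ReLU_prime(z)   # -> array([0, 0, 1])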
nn/part2.py (29 changed lines)
@@ -206,7 +206,6 @@ class PreviewLearning(NetworkScene):
 
         reversed_delta_edges = VGroup(*it.chain(*reversed(delta_edge_groups)))
         reversed_delta_neurons = VGroup(*reversed(delta_neuron_groups))
-        edge_groups.save_state()
 
         self.play(
             LaggedStart(
@@ -223,6 +222,7 @@ class PreviewLearning(NetworkScene):
                 rate_func = None,
             )
         )
+        edge_groups.save_state()
         self.color_network_edges()
         self.remove(edge_groups)
         self.play(*it.chain(
@@ -395,15 +395,34 @@ class FunctionMinmization(GraphScene):
     }
     def construct(self):
         self.setup_axes()
         title = TextMobject("Finding minima")
         title.to_edge(UP)
         self.add(title)
 
         def func(x):
-            x -= 5
-            return 0.1*(x**3 - 9*x) + 4
+            x -= 4.5
+            return 0.03*(x**4 - 16*x**2) + 0.3*x + 4
         graph = self.get_graph(func)
         graph_label = self.get_graph_label(graph, "C(x)")
         self.add(graph, graph_label)
 
-        dot = Dot(color = YELLOW)
-        x =
+        dots = VGroup(*[
+            Dot().move_to(self.input_to_graph_point(x, graph))
+            for x in range(10)
+        ])
+        dots.gradient_highlight(YELLOW, RED)
+
+        def update_dot(dot, dt):
+            x = self.x_axis.point_to_number(dot.get_center())
+            slope = self.slope_of_tangent(x, graph)
+            x -= slope*dt
+            dot.move_to(self.input_to_graph_point(x, graph))
+
+        self.add(*[
+            ContinualUpdateFromFunc(dot, update_dot)
+            for dot in dots
+        ])
+        self.dither(10)
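The rewritten scene animates gradient descent itself: each frame, `update_dot` reads a dot's x-coordinate, samples the tangent slope of C there, and steps x <- x - C'(x)*dt, so every dot slides into the nearest local minimum of the plotted cost. A self-contained sketch of the same rule outside manim (the 1/60 frame time and starting point are assumptions; `C` is the `func` from the hunk):

    # Continuous-time gradient descent, x <- x - C'(x) * dt, as in update_dot.
    def C(x):
        x -= 4.5
        return 0.03*(x**4 - 16*x**2) + 0.3*x + 4

    def C_prime(x, h=1e-5):
        # central-difference slope, standing in for slope_of_tangent
        return (C(x + h) - C(x - h)) / (2 * h)

    x, dt = 9.0, 1.0 / 60
    for _ in range(600):        # roughly ten seconds of frames
        x -= C_prime(x) * dt
    # x now sits near a local minimum of C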
File diff suppressed because one or more lines are too long