Finished WalkThroughTwoExample in nn/part3

This commit is contained in:
Grant Sanderson
2017-10-25 14:21:12 -07:00
parent fae819695a
commit 4ea7a29628
2 changed files with 216 additions and 23 deletions

View File

@@ -97,13 +97,13 @@ class VMobject(Mobject):
def get_fill_color(self):
try:
self.fill_rgb = np.clip(self.fill_rgb, 0, 1)
self.fill_rgb = np.clip(self.fill_rgb, 0.0, 1.0)
return Color(rgb = self.fill_rgb)
except:
return Color(WHITE)
def get_fill_opacity(self):
return self.fill_opacity
return np.clip(self.fill_opacity, 0, 1)
def get_stroke_color(self):
try:
@@ -113,7 +113,7 @@ class VMobject(Mobject):
return Color(WHITE)
def get_stroke_width(self):
return self.stroke_width
return max(0, self.stroke_width)
def get_color(self):
if self.fill_opacity == 0:

View File

@@ -537,7 +537,6 @@ class ShowAveragingCost(PreviewLearning):
self.curr_image = image
class WalkThroughTwoExample(ShowAveragingCost):
CONFIG = {
"random_seed" : 0,
@@ -548,12 +547,10 @@ class WalkThroughTwoExample(ShowAveragingCost):
self.setup_bases()
def construct(self):
self.force_skipping()
self.setup_network()
self.setup_diff_words()
self.show_single_example()
# self.single_example_influencing_weights()
self.single_example_influencing_weights()
self.expand_last_layer()
self.show_activation_formula()
self.three_ways_to_increase()
@@ -591,7 +588,6 @@ class WalkThroughTwoExample(ShowAveragingCost):
run_time = 1,
)
self.play(
two.next_to, edge_groups[0].get_corner(DOWN+RIGHT), DOWN,
adjust_edge_group_anim(edge_groups[0])
@@ -630,6 +626,7 @@ class WalkThroughTwoExample(ShowAveragingCost):
mover.save_state()
mover.generate_target()
mover.target.scale_in_place(2)
neurons[2].save_state()
neurons.target.to_edge(DOWN, MED_LARGE_BUFF)
output_labels.target.next_to(neurons.target, RIGHT, MED_SMALL_BUFF)
@@ -733,6 +730,30 @@ class WalkThroughTwoExample(ShowAveragingCost):
)
self.dither()
#Show changing activations
anims = []
def get_decimal_update(start, end):
return lambda a : interpolate(start, end, a)
for i in range(10):
target = 1.0 if i == 2 else 0.01
anims += [neurons[i].set_fill, WHITE, target]
decimal = self.decimals[i]
anims.append(ChangingDecimal(
decimal,
get_decimal_update(decimal.number, target),
num_decimal_points = 1
))
anims.append(UpdateFromFunc(
self.decimals[i],
lambda m : m.set_fill(WHITE if m.number < 0.8 else BLACK)
))
self.play(
*anims,
run_time = 3,
rate_func = there_and_back
)
two_rect = rects[2]
eight_rect = rects[8].copy()
non_two_rects = VGroup(*[r for r in rects if r is not two_rect])
@@ -928,7 +949,7 @@ class WalkThroughTwoExample(ShowAveragingCost):
self.play(LaggedStart(
ApplyFunction, edges,
lambda edge : (
lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
lambda m : m.rotate_in_place(np.pi/12).set_stroke(YELLOW),
edge
),
rate_func = wiggle
@@ -944,6 +965,10 @@ class WalkThroughTwoExample(ShowAveragingCost):
ReplacementTransform(bright_neurons[0].copy(), a_terms[0]),
ShowCreation(terms_rect)
)
self.dither()
for x in range(2):
self.play(LaggedStart(ShowCreationThenDestruction, bright_edges))
self.play(LaggedStart(ShowCreation, bright_edges))
self.play(FadeOut(terms_rect))
self.dither()
self.play(
@@ -1074,18 +1099,17 @@ class WalkThroughTwoExample(ShowAveragingCost):
positive_arrows = VGroup()
negative_arrows = VGroup()
all_arrows = VGroup()
positive_edges = VGroup()
negative_edges = VGroup()
positive_neurons = VGroup()
negative_neurons = VGroup()
for neuron, edge in zip(prev_neurons, edges):
value = edge.get_stroke_width()
if Color(edge.get_stroke_color()) == Color(self.negative_edge_color):
value *= -1
arrow = Vector(0.25*value*UP, color = edge.get_color())
arrow.stretch_to_fit_height(neuron.get_height())
value = self.get_edge_value(edge)
arrow = self.get_neuron_nudge_arrow(edge)
arrow.move_to(neuron.get_left())
arrow.shift(SMALL_BUFF*LEFT)
all_arrows.add(arrow)
if value > 0:
positive_arrows.add(arrow)
positive_edges.add(edge)
@@ -1148,10 +1172,11 @@ class WalkThroughTwoExample(ShowAveragingCost):
)
self.dither()
self.play(Write(added_words, run_time = 1))
self.play(prev_neurons.set_stroke, WHITE, 2)
self.set_variables_as_attrs(
in_proportion_to_w = added_words,
prev_neuron_arrows = VGroup(positive_arrows, negative_arrows),
prev_neuron_arrows = all_arrows,
)
def only_keeping_track_of_changes(self):
@@ -1162,7 +1187,6 @@ class WalkThroughTwoExample(ShowAveragingCost):
words = TextMobject("No direct influence")
words.next_to(rect, UP)
self.revert_to_original_skipping_status()
self.play(ShowCreation(rect))
self.play(Write(words))
self.dither()
@@ -1171,16 +1195,185 @@ class WalkThroughTwoExample(ShowAveragingCost):
def show_other_output_neurons(self):
two_neuron = self.two_neuron
two_decimal = self.two_decimal
two_arrow = self.two_arrow
two_label = self.two_label
two_edges = two_neuron.edges_in
prev_neurons = self.network_mob.layers[-2].neurons
neurons = self.network_mob.layers[-1].neurons
prev_neuron_arrows = self.prev_neuron_arrows
arrows_to_fade = VGroup(prev_neuron_arrows)
output_labels = self.network_mob.output_labels
quads = zip(neurons, self.decimals, self.arrows, output_labels)
self.play(
two_neuron.restore,
two_decimal.scale, 0.5,
two_decimal.move_to, two_neuron.saved_state,
two_arrow.scale, 0.5,
two_arrow.next_to, two_neuron.saved_state, RIGHT, 0.5*SMALL_BUFF,
two_label.scale, 0.5,
two_label.next_to, two_neuron.saved_state, RIGHT, 1.5*SMALL_BUFF,
FadeOut(VGroup(self.lhs, self.rhs)),
*[e.restore for e in two_edges]
)
for neuron, decimal, arrow, label in quads[:2]:
plusses = VGroup()
new_arrows = VGroup()
for edge, prev_arrow in zip(neuron.edges_in, prev_neuron_arrows):
plus = TexMobject("+").scale(0.5)
plus.move_to(prev_arrow)
plus.shift(2*SMALL_BUFF*LEFT)
new_arrow = self.get_neuron_nudge_arrow(edge)
new_arrow.move_to(plus)
new_arrow.shift(2*SMALL_BUFF*LEFT)
plusses.add(plus)
new_arrows.add(new_arrow)
self.play(
FadeIn(VGroup(neuron, decimal, arrow, label)),
LaggedStart(ShowCreation, neuron.edges_in),
)
self.play(
ReplacementTransform(neuron.edges_in.copy(), new_arrows),
Write(plusses, run_time = 2)
)
arrows_to_fade.add(new_arrows, plusses)
prev_neuron_arrows = new_arrows
all_dots_plus = VGroup()
for arrow in prev_neuron_arrows:
dots_plus = TexMobject("\\cdots +")
dots_plus.scale(0.5)
dots_plus.move_to(arrow.get_center(), RIGHT)
dots_plus.shift(2*SMALL_BUFF*LEFT)
all_dots_plus.add(dots_plus)
arrows_to_fade.add(all_dots_plus)
self.play(
LaggedStart(
FadeIn, VGroup(*it.starmap(VGroup, quads[-7:])),
),
LaggedStart(
FadeIn, VGroup(*[n.edges_in for n in neurons[-7:]])
),
Write(all_dots_plus),
run_time = 3,
)
self.dither(2)
def squish(p):
return p[1]*UP
self.play(
arrows_to_fade.apply_function, squish,
arrows_to_fade.move_to, prev_neurons,
)
def show_recursion(self):
pass
network_start = VGroup(*it.chain(
self.network_mob.edge_groups[1],
self.network_mob.layers[1],
self.network_mob.edge_groups[0],
self.network_mob.layers[0],
))
words_to_fade = VGroup(
self.increase_words,
self.in_proportion_to_w,
self.in_proportion_to_a,
)
self.play(
FadeOut(words_to_fade),
LaggedStart(FadeIn, network_start, run_time = 3)
)
self.dither()
for i in 1, 0:
edges = self.network_mob.edge_groups[i]
self.play(LaggedStart(
ApplyFunction, edges,
lambda edge : (
lambda m : m.rotate_in_place(np.pi/12).highlight(YELLOW),
edge
),
rate_func = wiggle
))
self.dither()
####
def get_neuron_nudge_arrow(self, edge):
value = self.get_edge_value(edge)
height = np.sign(value)*0.1 + 0.1*value
arrow = Vector(height*UP, color = edge.get_color())
return arrow
def get_edge_value(self, edge):
value = edge.get_stroke_width()
if Color(edge.get_stroke_color()) == Color(self.negative_edge_color):
value *= -1
return value
class NotANeuroScientist(TeacherStudentsScene):
def construct(self):
quote = TextMobject("``Neurons that fire together wire together''")
quote.to_edge(UP)
self.add(quote)
asterisks = TextMobject("***")
asterisks.next_to(quote.get_corner(UP+RIGHT), RIGHT, SMALL_BUFF)
asterisks.highlight(BLUE)
brain = SVGMobject(file_name = "brain")
brain.scale_to_fit_height(1.5)
self.add(brain)
double_arrow = DoubleArrow(LEFT, RIGHT)
double_arrow.next_to(brain, RIGHT)
q_marks = TextMobject("???")
q_marks.next_to(double_arrow, UP)
network = NetworkMobject(Network(sizes = [6, 4, 4, 5]))
network.scale_to_fit_height(1.5)
network.next_to(double_arrow, RIGHT)
group = VGroup(brain, double_arrow, q_marks, network)
group.next_to(self.students, UP, buff = 1.5)
self.add(group)
self.add(ContinualEdgeUpdate(network))
rect = SurroundingRectangle(group)
no_claim_words = TextMobject("No claims here...")
no_claim_words.next_to(rect, UP)
no_claim_words.highlight(YELLOW)
brain_outline = brain.copy()
brain_outline.set_fill(opacity = 0)
brain_outline.set_stroke(BLUE, 3)
brain_anim = ShowCreationThenDestruction(brain_outline)
words = TextMobject("Definitely not \\\\ a neuroscientist")
words.next_to(self.teacher, UP, buff = 1.5)
words.shift_onto_screen()
arrow = Arrow(words.get_bottom(), self.teacher.get_top())
self.play(
Write(words),
GrowArrow(arrow),
self.teacher.change, "guilty", words,
run_time = 1,
)
self.change_student_modes(*3*["sassy"])
self.play(
ShowCreation(rect),
Write(no_claim_words, run_time = 1),
brain_anim
)
self.dither()
self.play(brain_anim)
self.play(Write(asterisks, run_time = 1))
for x in range(2):
self.play(brain_anim)
self.dither()