
Takeaway 1: There's a deep relationship between grokking and phase changes. A phase change is an abrupt change in capabilities during training, like we see when training a 2L attn-only transformer to predict repeated subsequences (2/17)
Image
Phase changes turn into grokking if we regularise and add JUST enough data that it still generalises - here's the same problem on 512 training data points (3/17)
Image
Why? Phase changes happen because the general solution is "hard" to reach - something about the loss landscape makes models take a while to get there. But memorisation is easy to reach. So models memorise first. (4/17)
But the generalising solution is simpler, so regularisation favours it: the model slowly interpolates from memorising to generalising, until the phase change kicks in and progress speeds up. Test loss only improves once the memorisation is totally gone. (5/17)
I don't claim to fully understand phase changes, but I DO claim to have reduced my confusion about grokking to my confusion about phase changes! (6/17)
Takeaway 2: Transformers grok modular addition with an interpretable algorithm, using Discrete Fourier Transforms and trig identities. This can be reverse engineered from the weights post-grok. If we ablate all activations not captured by the algorithm, performance goes UP! (7/17)
Image
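The constructive-interference core of this algorithm can be sketched in a few lines of numpy. This is a minimal reimplementation, not the trained network: the frequency set below is a hypothetical choice (the real model learns its own small set of key frequencies), and the logit formula is the cos(w(a+b-c)) form the trig identities produce.

```python
import numpy as np

p = 113                    # modulus of the modular addition task
freqs = [14, 35, 41, 52]   # hypothetical key frequencies; the trained model picks its own

def mod_add_logits(a, b):
    """Logit for each candidate c is a sum of cosines cos(2*pi*k*(a+b-c)/p).
    Every term hits its maximum of 1 exactly when c == (a + b) % p, so
    constructive interference makes the correct answer the argmax."""
    c = np.arange(p)
    return sum(np.cos(2 * np.pi * k * (a + b - c) / p) for k in freqs)

# The argmax over logits recovers modular addition on every input pair
assert all(int(np.argmax(mod_add_logits(a, b))) == (a + b) % p
           for a in range(p) for b in range(p))
```

The model reaches cos(w(a+b)) from cos(wa), cos(wb), sin(wa), sin(wb) via the product identities, then reads off logits with the addition identity cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc) = cos(w(a+b-c)).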
To emphasise, I did not predict this algorithm! SGD found it entirely on its own, I discovered it after weeks of staring at the weights. Just another way that neural networks are smarter than me! (8/17)
Turns out the right way to think of the input space is not as 113 separate inputs but as waves of different frequencies - the model embeds a few waves and ignores the rest (9/17)
Image
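The "embeds a few waves and ignores the rest" claim can be illustrated with a toy stand-in for the learned embedding matrix: build an embedding from just two waves, take a DFT along the input dimension, and check that only those frequencies have any weight. The dimensions and frequencies here are hypothetical, chosen only to make the sparsity visible.

```python
import numpy as np

p = 113
a = np.arange(p)
# Toy embedding of the p one-hot inputs: 4 dims built from just two waves
# (frequencies 14 and 35 are hypothetical stand-ins for the model's choices)
W_E = np.stack([np.cos(2*np.pi*14*a/p), np.sin(2*np.pi*14*a/p),
                np.cos(2*np.pi*35*a/p), np.sin(2*np.pi*35*a/p)], axis=1)

# DFT over the input dimension, then the norm of each frequency's component
# across embedding dims: the embedding is sparse in frequency space
freq_norms = np.linalg.norm(np.fft.fft(W_E, axis=0), axis=1)
active = sorted(k for k in range(1, p // 2 + 1) if freq_norms[k] > 1e-6)
print(active)  # → [14, 35]: only the embedded frequencies survive
```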
Neuron activations are 2D periodic functions involving a single frequency - neurons neatly cluster by their frequency, and we recover the frequency with a 2D DFT. Here are a few: (10/17)
Image
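The frequency-recovery step can be sketched on a synthetic neuron: generate a single-frequency 2D periodic activation pattern over all (a, b) input pairs, take a 2D DFT, and read the frequency off the dominant component. The frequency 41 below is a hypothetical example, not one read from the real model.

```python
import numpy as np

p, k = 113, 41   # k is a hypothetical neuron frequency
a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
# Synthetic neuron: a 2D periodic function of (a, b) at a single frequency,
# standing in for a real post-grok MLP neuron's activation pattern
acts = np.cos(2*np.pi*k*a/p) * np.cos(2*np.pi*k*b/p)

# 2D DFT: the power concentrates at (+-k, +-k), so the location of the
# dominant component recovers the neuron's frequency
power = np.abs(np.fft.fft2(acts))
fa, fb = np.unravel_index(np.argmax(power), power.shape)
print(min(fa, p - fa), min(fb, p - fb))  # both equal k = 41
```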
Takeaway 3: Once we understand the circuits, we can see them smoothly develop during training. The model interpolates between memorising and generalising, it is NOT a random walk. Here's the embedding of waves 3K epochs pre-grokking (11/17)
Image
So we see toy models grok. Does any of this matter for large models? Probably! They have too much data to grok, but they likely DO see phase changes when learning a new capability. But many tiny phase changes average to a smooth loss curve so we don't notice. (12/17)
In contrast, these tiny models only seem weird because they're small enough for us to notice. 5 digit addition has a bumpy loss curve, with a phase change for each digit. (13/17)
Image
My weirdest result: log_softmax has a float32 precision error: outputs below 1.19e-7 round to 0. This gives cursed loss curves, but is fixed by casting logits to float64. This rounding creates @AggieInCA's Slingshots, and the Slingshot Mechanism is not needed to grok (14/17)
Image
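A minimal numpy reproduction of the precision issue, using my own re-implementation of the standard max-subtracted log_softmax (not the actual torch kernel): the true log-prob of a confident winning class falls below float32's machine epsilon (~1.19e-7), so the loss rounds to exactly zero.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max()                       # standard max-subtraction for stability
    return x - np.log(np.exp(x).sum())

logits = np.array([0.0, 20.0])
# True log-prob of the winning class is -log(1 + e^-20), roughly -2.06e-9,
# which is far below float32 machine epsilon (~1.19e-7):
print(log_softmax(logits.astype(np.float32))[1])  # 0.0 — the loss rounds to exactly zero
print(log_softmax(logits.astype(np.float64))[1])  # nonzero, roughly -2.06e-9
```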
I go into more detail about what's up with the Slingshot Mechanism in another thread (18/17)
Quote Tweet
Pulling out a specific unexpected result from my grokking work, what's up with @AggieInCA's Slingshot Mechanism? Turns out there's a float32 precision error in log_softmax - any output <1.19e-7 rounds down to 0 - try log_softmax([0, 16.63]) vs 16.64 🧵 (1/7) twitter.com/AggieInCA/stat…
Another thread! This time with animations plotting out exactly what the circuits look like as they develop - I find it fascinating to see what complex structure can emerge from simple gradient descent, and how smooth this emergence is.
Quote Tweet
1/ Another follow-up 🧵 on grokking - Now we understand the circuits, we can actually animate their development over training, with fascinating results. Here's the norm of the embedding of each wave over inputs - sparsity is smoothly learnt pre-grokking (Grokking is epoch 9-13K) twitter.com/NeelNanda5/sta…
Embedded video