I've spent the past few months exploring the grokking result through the lens of mechanistic interpretability. I fully reverse-engineered the modular addition model, and looked at what it does during training. So what's up with grokking? A 🧵... (1/17) alignmentforum.org/posts/N6WM6hs7
Takeaway 1: There's a deep relationship between grokking and phase changes. Phase changes are an abrupt change in capabilities during training, like we see when training a 2L attn-only transformer to predict repeated subsequences (2/17)
Phase changes turn into grokking if we regularise and add JUST enough data that it still generalises - here's the same problem on 512 training data points (3/17)
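A minimal sketch of this data-limited setup, using the thread's modular-addition task rather than the subsequence task (the 512-point split is from the tweet; p = 113 and everything else here is an assumption for illustration):

```python
import itertools
import random

p = 113                                                # modulus used in the thread
pairs = list(itertools.product(range(p), repeat=2))    # all p*p inputs (a, b)
labels = {(a, b): (a + b) % p for (a, b) in pairs}

random.seed(0)
random.shuffle(pairs)
train, test = pairs[:512], pairs[512:]                 # JUST enough data to generalise

print(len(train), len(test))
```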
Why? Phase changes happen because the general solution is "hard" to reach - something about the loss landscape makes models take a while to get there. But memorisation is easy to reach. So models memorise first. (4/17)
But the generalising solution is simpler, so regularisation favours it. The model slowly interpolates from memorising to generalising, until the phase change kicks in and progress speeds up. Test loss only improves once the memorisation is totally gone. (5/17)
I don't claim to fully understand phase changes, but I DO claim to have reduced my confusion about grokking to my confusion about phase changes! (6/17)
Takeaway 2: Transformers grok modular addition with an interpretable algorithm, using Discrete Fourier Transforms and trig identities. This can be reverse-engineered from the weights post-grok. If we ablate all activations not captured by the algorithm, performance goes UP! (7/17)
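As a sanity check that a trig-identity algorithm of this shape really computes modular addition, here's a numpy sketch (the particular frequencies are made up; the real model picks its own handful):

```python
import numpy as np

p = 113
freqs = [14, 35, 41, 52, 71]            # hypothetical "key frequencies"
ws = 2 * np.pi * np.array(freqs) / p

def mod_add_logits(a, b):
    """Logit for each candidate answer c, via cos(w(a+b-c)) =
    cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc)."""
    c = np.arange(p)
    # build cos/sin of w(a+b) from cos/sin of wa and wb (angle-sum identities)
    cos_ab = np.cos(ws * a) * np.cos(ws * b) - np.sin(ws * a) * np.sin(ws * b)
    sin_ab = np.sin(ws * a) * np.cos(ws * b) + np.cos(ws * a) * np.sin(ws * b)
    # sum over frequencies: peaks exactly when c = (a + b) mod p
    return cos_ab @ np.cos(np.outer(ws, c)) + sin_ab @ np.sin(np.outer(ws, c))

print(mod_add_logits(57, 93).argmax())  # (57 + 93) % 113 = 37
```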
To emphasise, I did not predict this algorithm! SGD found it entirely on its own, I discovered it after weeks of staring at the weights. Just another way that neural networks are smarter than me! (8/17)
Turns out the right way to think of the input space is not as 113 separate inputs but as waves of different frequencies - the model embeds a few waves and ignores the rest (9/17)
Neuron activations are 2D periodic functions of a single frequency - neurons cluster neatly by frequency, and we can recover each neuron's frequency with a 2D DFT. Here are a few: (10/17)
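A toy version of that frequency-recovery step: build a synthetic single-frequency neuron over the (a, b) input grid and read the frequency back off the 2D DFT (the frequency and the product form are assumptions, purely for illustration):

```python
import numpy as np

p = 113
k_true = 17                          # hypothetical frequency for one neuron
x = np.arange(p)
wave = np.cos(2 * np.pi * k_true * x / p)
# synthetic activation over the full (a, b) grid: a 2D periodic
# function of a single frequency, like the clustered neurons post-grok
acts = np.outer(wave, wave)

power = np.abs(np.fft.fft2(acts))    # 2D discrete Fourier transform
ka, kb = np.unravel_index(power.argmax(), power.shape)
# a real cosine puts equal power at index k and p - k
print(min(ka, p - ka), min(kb, p - kb))  # recovers 17 17
```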
Takeaway 3: Once we understand the circuits, we can see them smoothly develop during training. The model interpolates between memorising and generalising, it is NOT a random walk. Here's the embedding of waves 3K epochs pre-grokking (11/17)
So we see toy models grok. Does any of this matter for large models? Probably! They have too much data to grok, but they likely DO see phase changes when learning a new capability. But many tiny phase changes average to a smooth loss curve so we don't notice. (12/17)
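A quick numerical illustration of that averaging argument (the number of capabilities, their onsets, and the sharpness are all made up):

```python
import numpy as np

t = np.linspace(0, 1, 500)                  # training progress
rng = np.random.default_rng(0)
onsets = rng.uniform(0.1, 0.9, size=50)     # 50 hypothetical capabilities

# each capability's loss is a sharp step (steep reversed sigmoid)
# at its own point in training - an individual phase change
per_capability = 1.0 / (1.0 + np.exp((t[:, None] - onsets) / 0.005))
total = per_capability.sum(axis=1)

# any single column drops from ~1 to ~0 almost instantly,
# but their sum is a smooth, steadily decreasing loss curve
```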
In contrast, these tiny models only seem weird because they're small enough for the phase changes to stand out. 5-digit addition has a bumpy loss curve, with a phase change for each digit. (13/17)
My weirdest result: log_softmax has a float32 precision error - outputs below 1.19e-7 in magnitude round to exactly 0. This gives cursed loss curves, but is fixed by casting logits to float64. This rounding creates @AggieInCA's Slingshots, and the Slingshot Mechanism is not needed to grok (14/17)
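The bug is easy to reproduce without a model - a hand-rolled log_softmax at both precisions (a numpy stand-in for the PyTorch call; the 16.64 threshold is from the thread):

```python
import numpy as np

def log_softmax(logits, dtype):
    x = np.asarray(logits, dtype=dtype)
    shifted = x - x.max()                        # standard max-subtraction trick
    return shifted - np.log(np.exp(shifted).sum())

# exp(-16.64) ~ 5.9e-8 is under half an ulp of 1.0 in float32, so
# 1 + exp(-16.64) rounds to exactly 1.0 and the top log-prob becomes exactly 0
lp32 = log_softmax([0.0, 16.64], np.float32)
lp64 = log_softmax([0.0, 16.64], np.float64)
print(lp32[1])   # 0.0 - the gradient signal from this example vanishes
print(lp64[1])   # ~ -5.9e-8 - small but nonzero
```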
To me, this work is a proof of concept of the Circuits agenda. Networks CAN be understood. If we put in the work to reverse engineer them, we can dissolve confusing phenomena like grokking. So let's go do it! (15/17)
distill.pub/2020/circuits/
If you want to join me on this, I end with a list of promising future directions. If you're excited about any of them, let's chat! (16/17)
alignmentforum.org/posts/N6WM6hs7
Credit to for leading the early stages of this project, thanks to my former colleagues at for the Transformer Circuits work inspiring this, & to for finding this mystery! (END)
transformer-circuits.pub
I go into more detail about what's up with the Slingshot Mechanism in another thread (18/17)
Quote Tweet
Pulling out a specific unexpected result from my grokking work, what's up with @AggieInCA's Slingshot Mechanism? Turns out there's a float32 precision error in log_softmax - any output <1.19e-7 rounds down to 0 - try log_softmax([0, 16.63]) vs 16.64
(1/7) twitter.com/AggieInCA/stat…
Another thread! This time with animations plotting out exactly what the circuits look like as they develop - I find it fascinating to see what complex structure can emerge from simple gradient descent, and how smooth this emergence is.
Quote Tweet
1/ Another follow-up on grokking - now we understand the circuits, we can actually animate their development over training, with fascinating results. Here's the norm of the embedding of each wave over inputs - sparsity is smoothly learnt pre-grokking (grokking is epochs 9-13K) twitter.com/NeelNanda5/sta…
I've also released most of the code and a more detailed technical write-up, for reproducibility and to test my pet hypothesis that all ML papers would be better as Colab notebooks - it includes interactive figures, especially animations during training! (20/17)