# Clustering and neural network embeddings
## Background
Consider a dense neural network classifier, which we can think of as a composition of functions:
\[\mathcal{N}: \mathbb{R}^n \rightarrow \mathbb{R}^{n_1} \rightarrow \mathbb{R}^{n_2} \rightarrow \cdots \rightarrow \mathbb{R}^m\]
with one function for each layer of the network. The original set of feature vectors \(\{\vec{x}_i\}\) from a data set \(\mathcal{D}\) is embedded in \(\mathbb{R}^n\). As feature vectors propagate through the network, this embedding changes from \(\mathbb{R}^{n_i}\) to \(\mathbb{R}^{n_{i+1}}\) at each step.
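Concretely, such a composition can be written down layer by layer. The sketch below is illustrative only (it is not necessarily how the companion notebook builds its model) and uses Keras with made-up layer widths; the second model exposes every intermediate embedding so the per-layer geometry can be examined later.

```python
from tensorflow import keras

# A small dense classifier written as an explicit composition of layer maps,
#   N : R^64 -> R^32 -> R^16 -> R^8 -> R^3.
# All widths here are placeholder choices, not the notebook's values.
inputs = keras.Input(shape=(64,))                           # feature space R^n, n = 64
h1 = keras.layers.Dense(32, activation="relu")(inputs)      # R^{n_1}
h2 = keras.layers.Dense(16, activation="relu")(h1)          # R^{n_2}
h3 = keras.layers.Dense(8, activation="relu")(h2)           # R^{n_3} (penultimate layer)
outputs = keras.layers.Dense(3, activation="softmax")(h3)   # R^m, m = 3 classes

model = keras.Model(inputs, outputs)

# A second model sharing the same weights that returns every intermediate
# embedding at once, so we can watch the data's geometry change layer by layer.
embedding_model = keras.Model(inputs, [h1, h2, h3, outputs])
```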
The goal of this exercise is to investigate the geometry of what happens to the feature vectors at each step of this process. A commonly repeated heuristic is that, as it learns, a neural network classifier arranges the points into more refined clusters from layer to layer.
## The experiment
Follow this outline as you work through the companion Colab notebook.
- Choose a data set to work with. A few common ones are included in the notebook. To start, choose a small number of classes for your classification problem, say three. (One possible concrete choice is sketched after this list.)
- Train a dense neural network classifier. Use enough layers that you can track the progression of the embeddings over the depth of the network. The notebook uses a depth of five, but feel free to vary this as you experiment. Note: we are interested in how the network arranges the training data for itself, so do not worry about overfitting. If you can, train the network until it reaches 100% (or nearly 100%) accuracy on the training data. (A minimal training sketch follows the list.)
- Now visualize what is going on in each layer. Layer by layer, take the image of the training data (that is, extract the output of the \(i\)th layer), reduce its dimension to two, and graph the results. Do successive layers construct better embeddings? Note: there are several dimension-reduction algorithms to choose from. Explore these choices freely, but the one based on the SVD preserves the geometry of the data most accurately. (A plotting sketch appears after this list.)
In particular, I would like you to investigate the neural collapse phenomenon: although I am not an expert, the claim is that when a \(k\)-class neural network classifier reaches close to perfect accuracy on the training data, the embedding in the penultimate layer collapses onto \(k\) points. The role of the final layer is then just to figure out which point corresponds to which class.
- Pay particular attention when you examine the output of the penultimate layer. Does neural collapse occur? If so, what are the points used by the network? (A rough numerical check is sketched at the end of this section.)
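For the first step, one concrete choice (an assumption on my part; the notebook's built-in data sets may differ) is scikit-learn's digits data restricted to three classes, which also matches the 64-dimensional input used in the background sketch:

```python
import numpy as np
from sklearn.datasets import load_digits

# Keep only the classes 0, 1, 2 so we start with a three-class problem.
digits = load_digits()
mask = digits.target < 3
X = digits.data[mask].astype("float32") / 16.0   # 64 pixel features, scaled to [0, 1]
y = digits.target[mask]                          # labels in {0, 1, 2}
```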
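Training the model from the background sketch on this data might then look as follows; the optimizer, epoch count, and batch size are arbitrary placeholders, and there is deliberately no validation split because only the training data matters here.

```python
# Fit until the network (nearly) memorizes the training set; overfitting is
# not a concern in this exercise.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(X, y, epochs=300, batch_size=32, verbose=0)
print("training accuracy:", model.evaluate(X, y, verbose=0)[1])
```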
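For the layer-by-layer pictures, one option (again only a sketch, continuing from the code above) is to project each layer's output to two dimensions with PCA, the SVD-based reduction mentioned in the outline, and scatter-plot the result colored by class:

```python
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# One 2-D scatter plot per layer: the image of the training data after that
# layer, projected onto its top two principal directions.
layer_outputs = embedding_model.predict(X, verbose=0)
fig, axes = plt.subplots(1, len(layer_outputs), figsize=(4 * len(layer_outputs), 4))
for i, (ax, emb) in enumerate(zip(axes, layer_outputs)):
    pts = PCA(n_components=2).fit_transform(emb)
    ax.scatter(pts[:, 0], pts[:, 1], c=y, cmap="viridis", s=8)
    ax.set_title(f"layer {i + 1}")
plt.show()
```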
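Finally, a crude numerical check for neural collapse in the penultimate layer: compare how far points sit from their own class mean with how far the class means sit from one another. This is only a rough diagnostic, not one of the formal collapse metrics from the literature.

```python
import numpy as np

# Output of the penultimate layer for every training point.
penultimate = layer_outputs[-2]
classes = np.unique(y)
class_means = np.stack([penultimate[y == k].mean(axis=0) for k in classes])

# Average distance from a point to its own class mean ...
within = np.mean([np.linalg.norm(penultimate[y == k] - class_means[i], axis=1).mean()
                  for i, k in enumerate(classes)])
# ... versus the average distance between distinct class means.
pairwise = np.linalg.norm(class_means[:, None, :] - class_means[None, :, :], axis=-1)
between = pairwise[np.triu_indices(len(classes), k=1)].mean()

print("mean within-class distance :", within)
print("mean between-class distance:", between)
# Near-collapse shows up as the within-class distance being tiny relative
# to the between-class distance.
```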