AgglomerativeClustering with Cluster Centers

Tushar Sinha
Apr 17, 2021

Recently I was working on a project that required me to compare KMeans clustering to AgglomerativeClustering (using the sklearn package). I wanted to plot the clusters in 2D space and use the yellowbrick elbow visualiser to find the right number of clusters.

For those who don’t know yellowbrick, it is an awesome visualisation tool for machine learning algorithms. Check it out here:

However, AgglomerativeClustering does not have a predict method, which yellowbrick needs to compute the elbow plot. It also lacks the cluster_centers_ attribute, which I needed to plot the centroid of each cluster.

A cluster's center is simply the mean of all the rows assigned to that cluster, so it is easy to compute, but doing this repeatedly is a drag.
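As a minimal sketch of that computation (assuming X is a NumPy array, or something convertible to one, and labels is the array returned by fit_predict), the centroids can be computed with NumPy alone. The helper name cluster_centers is my own hypothetical name, not part of sklearn:

```python
import numpy as np


def cluster_centers(X, labels):
    """Hypothetical helper: one centroid per cluster, i.e. the mean of its rows.

    Rows of the result follow the sorted order of the unique labels.
    """
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    return np.array([X[labels == k].mean(axis=0) for k in np.unique(labels)])


# Tiny example: two obvious groups in 2D.
X = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [12.0, 10.0]])
labels = np.array([0, 0, 1, 1])
print(cluster_centers(X, labels))  # centroids (0, 1) and (11, 10)
```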

Since I needed to run the clustering algorithms several times with different metrics and methods, I chose to create a subclass with the required functions and attributes, as below.

Here, predict simply takes the output of the parent class (sklearn.cluster.AgglomerativeClustering), that is, the predicted cluster labels, and assigns it to the instance attribute labels_, which yellowbrick uses to draw the elbow plot.

I also override fit_predict so that, when it is called on an AgglomerativeClusteringWithPredict object, my version runs instead. It calls the parent class's fit_predict and then calls makeClusterCenterList, which computes the centroid of each cluster and stores the result in cluster_centers_.
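The notebook's code isn't reproduced here, so what follows is a hedged reconstruction from the description above, not the author's exact code. The names AgglomerativeClusteringWithPredict, predict, fit_predict and makeClusterCenterList come from the text; the method bodies are my assumptions. In particular, since agglomerative clustering cannot assign unseen points to existing clusters, this predict simply re-clusters whatever X it is given:

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering


class AgglomerativeClusteringWithPredict(AgglomerativeClustering):
    """AgglomerativeClustering plus a predict method and a cluster_centers_ attribute."""

    def fit_predict(self, X, y=None):
        # Let the parent class do the actual clustering (this sets labels_).
        labels = super().fit_predict(X, y)
        # Then derive the centroids from the labels it produced.
        self.makeClusterCenterList(X)
        return labels

    def predict(self, X):
        # Assumption: "predict" re-clusters X (via the override above)
        # and stores the resulting labels in labels_.
        self.labels_ = self.fit_predict(X)
        return self.labels_

    def makeClusterCenterList(self, X):
        # Centroid of a cluster = mean of the rows assigned to it.
        X = np.asarray(X, dtype=float)
        self.cluster_centers_ = np.array(
            [X[self.labels_ == k].mean(axis=0) for k in np.unique(self.labels_)]
        )
        return self.cluster_centers_


X = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [12.0, 10.0]])
model = AgglomerativeClusteringWithPredict(n_clusters=2)
labels = model.fit_predict(X)
print(labels)
print(model.cluster_centers_)
```

In this sketch only fit_predict (and therefore predict) fills in cluster_centers_; calling plain fit leaves it unset, matching the description above, which only mentions overriding fit_predict.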

This can then be used in the following manner:

The elbow plot from this code is below:
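With yellowbrick installed, the subclass would be handed to KElbowVisualizer (for example KElbowVisualizer(model, k=(2, 10)), followed by fit(X) and show()). To show what each point on such a plot represents, here is a dependency-free sketch that does not use yellowbrick at all: it refits plain AgglomerativeClustering for each k and scores the result by distortion, the sum of squared distances from each point to its cluster's centroid, which is my understanding of yellowbrick's default metric. The make_blobs data and the distortion helper are illustrative assumptions:

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs


def distortion(X, labels):
    """Sum of squared Euclidean distances from each row to its cluster mean."""
    total = 0.0
    for k in np.unique(labels):
        members = X[labels == k]
        total += ((members - members.mean(axis=0)) ** 2).sum()
    return total


# Synthetic data with four groups (an assumption for the demo).
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=0)

scores = {}
for k in range(2, 9):
    labels = AgglomerativeClustering(n_clusters=k).fit_predict(X)
    scores[k] = distortion(X, labels)

for k, score in scores.items():
    print(f"k={k}: distortion={score:.1f}")
```

With the default Ward linkage every k is a cut of the same nested tree, and merging two clusters can never lower the sum of squares, so the scores should not rise as k grows; the elbow is the k after which they stop dropping sharply.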

I also wrote a custom function to show the clusters in 2D:
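The original plotting function isn't shown here, so this is my own sketch of one way to do it. It assumes PCA (sklearn.decomposition.PCA) is used to compress the features to two dimensions (the post doesn't say which projection the original used), and the name plot_clusters_2d is hypothetical. The centroids are projected with the same fitted PCA so they land in the same 2D coordinates as the points:

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA


def plot_clusters_2d(X, labels, centers, ax=None):
    """Scatter X in 2D (via PCA) coloured by cluster, with centroids marked.

    Returns the axes plus the projected points and centroids.
    """
    X = np.asarray(X, dtype=float)
    pca = PCA(n_components=2).fit(X)
    X_2d = pca.transform(X)
    # PCA is an affine map, so a projected centroid equals the mean of its projected points.
    centers_2d = pca.transform(np.asarray(centers, dtype=float))

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(X_2d[:, 0], X_2d[:, 1], c=labels, cmap="viridis", s=15, alpha=0.7)
    ax.scatter(centers_2d[:, 0], centers_2d[:, 1], c="red", marker="X", s=150,
               label="centroids")
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.legend()
    return ax, X_2d, centers_2d


# Example: 5-feature synthetic data, clustered, then centroids taken as cluster means.
X, _ = make_blobs(n_samples=200, n_features=5, centers=3, random_state=42)
labels = AgglomerativeClustering(n_clusters=3).fit_predict(X)
centers = np.array([X[labels == k].mean(axis=0) for k in np.unique(labels)])
ax, X_2d, centers_2d = plot_clusters_2d(X, labels, centers)
plt.show()
```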

While I wrote the code for a 2D representation, it can easily be adapted to 3D. In fact, 3D may be preferable when there are many features: the fewer dimensions one compresses the data into, the more the projection distorts how the clusters appear.
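Assuming the 2D view comes from PCA (my assumption; the post doesn't name the projection), one rough way to gauge that trade-off is the share of total variance each projection keeps. A third component can only add retained variance, never remove it. The 10-feature make_blobs data below is purely illustrative:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.decomposition import PCA

# Illustrative 10-feature data; substitute your own feature matrix.
X, _ = make_blobs(n_samples=300, n_features=10, centers=4, random_state=0)

pca = PCA().fit(X)  # keep all components so the ratios sum to 1
cumulative = np.cumsum(pca.explained_variance_ratio_)
print(f"variance kept in 2D: {cumulative[1]:.1%}")
print(f"variance kept in 3D: {cumulative[2]:.1%}")
```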

In any case, I hope this is useful for folks looking to use yellowbrick with AgglomerativeClustering, or to get cluster centroids from AgglomerativeClustering. The full Jupyter notebook is available on my GitHub:

Private Equity professional, passionate about applying data science and machine learning to alt assets