Guest Post: Vinod Bakthavachalam, Senior Data Scientist, Coursera
This November 29-30, Vinod will be presenting his most recent research at the Applied AI Summit in Houston. Additional confirmed speakers include Rumman Chowdhury, Senior Principal at Accenture; Hao Yi Ong, Research Scientist at Lyft; Daniel Ellis, Tech Lead at Reddit; and many more. Register before October 5 to save 25% with Early Bird discounted passes.
Machine learning and artificial intelligence have revolutionized data science by allowing the creation of automated models that learn intricate patterns from data on their own. Countless data products, ranging from classical recommendation systems to complex facial recognition software, have turned to machine learning to improve their performance.
A lot of the recent buzz in this space has been around deep learning and neural networks, and not without justification: many of the headline improvements in the products above have been powered by this class of models.
We can see just how popular these terms have become in recent years in the Google Trends graph below. As recently as 2016, worldwide interest in deep learning (in red) and neural networks (in yellow) was on par with interest in linear regression (in blue). Now, however, interest in both deep learning and neural networks has soared above that of linear regression.
When we talk about the cutting edge of machine learning and artificial intelligence, it is all about the power of deep learning and neural networks to achieve new, previously unseen levels of performance. One would think that with this buzz the simple linear regression model would fall by the wayside. It turns out, though, that there is still much value to be had from simple machine learning models. In fact, at Coursera we often prefer simple linear regression or logistic regression to deep learning models.
When people discuss the benefits of deep learning vs. other machine learning methods like regression, they typically focus on a graph like this:
It highlights the fact that, as we add more data, the performance of traditional machine learning methods plateaus while the performance of deep learning does not. What this graph is missing, though, is the magnitude of the increase in performance on the y-axis.
First, the improvement from using deep learning instead of simpler methods is often economically insignificant, since regression can get 80%-90% of the way there. For some companies and problems, that last 10%-20% can mean millions of dollars, but for many startups and smaller companies, the increased training and optimization time is not worth that small boost in performance.
Second, when it comes to extracting insights, deep learning models make it hard to understand how they arrive at their predictions and are often a black box. This can limit their ability to inform downstream applications.
As an example, at Coursera we use machine learning to identify at-risk degree students and power automated interventions to ensure these students stay on track in their coursework. We could probably use a neural network to achieve greater predictive accuracy, but the output of that model would give little insight into the reasons behind a student’s struggles, and those reasons are what we need to choose the best intervention to get them back on track.
Knowing that a student is at risk does not, in and of itself, tell us how to help him or her. Is the problem that the student has been absent for the last week and might have a life circumstance interfering with the course schedule? Or is it that the student has been struggling with programming assignments and might need more attention from a TA? These two situations are addressed in very different ways (in this case, even by different people at our partner institutions), and it’s important that our at-risk predictions capture this nuance.
Using simpler regression models allows us to easily output the relative importance of each feature (or set of features) driving the prediction, so we can identify the right intervention for each at-risk student and achieve true personalization.
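To make the idea concrete, here is a minimal sketch of how a logistic regression can expose what is driving an individual prediction. It is an illustration under stated assumptions, not Coursera's actual model: the data is synthetic, the feature names `days_inactive` and `failed_assignments` are hypothetical, and "relative importance" is approximated as each learned coefficient multiplied by the student's standardized feature value.

```python
import math
import random

# Hypothetical feature names, chosen for illustration only.
FEATURES = ["days_inactive", "failed_assignments"]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def make_data(n=400, seed=0):
    """Generate synthetic, already-standardized features and at-risk labels."""
    rng = random.Random(seed)
    rows, labels = [], []
    for _ in range(n):
        x = [rng.gauss(0, 1), rng.gauss(0, 1)]
        # Assumed ground truth: both features raise the risk of falling behind.
        p = sigmoid(1.5 * x[0] + 1.0 * x[1] - 0.5)
        rows.append(x)
        labels.append(1 if rng.random() < p else 0)
    return rows, labels

def fit_logistic(rows, labels, lr=0.5, epochs=300):
    """Fit logistic regression with plain full-batch gradient descent."""
    w = [0.0] * len(rows[0])
    b = 0.0
    n = len(rows)
    for _ in range(epochs):
        grad_w = [0.0] * len(w)
        grad_b = 0.0
        for x, y in zip(rows, labels):
            err = sigmoid(b + sum(wi * xi for wi, xi in zip(w, x))) - y
            for j, xj in enumerate(x):
                grad_w[j] += err * xj
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b

def explain(w, b, x):
    """Return the predicted risk, per-feature contributions, and top driver."""
    contributions = {name: wi * xi for name, wi, xi in zip(FEATURES, w, x)}
    risk = sigmoid(b + sum(contributions.values()))
    top_driver = max(contributions, key=contributions.get)
    return risk, contributions, top_driver

rows, labels = make_data()
w, b = fit_logistic(rows, labels)

# Two at-risk students with different underlying problems.
absent_student = [2.5, 0.0]      # long inactivity, average assignment record
struggling_student = [0.0, 2.5]  # active, but failing many assignments

risk_a, contrib_a, driver_a = explain(w, b, absent_student)
risk_s, contrib_s, driver_s = explain(w, b, struggling_student)
print("absent student:", round(risk_a, 2), driver_a)
print("struggling student:", round(risk_s, 2), driver_s)
```

Both students are flagged as at risk, but the top contributing feature differs, which is the signal that could route one student to outreach about scheduling and the other to extra TA support. A neural network offers no comparably direct readout of what is driving a single prediction.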
Often the best machine learning algorithm to use is the simplest one, as it gets most of the way there and lets us easily understand the predictions that power downstream applications. It might no longer be buzzworthy, but regression can still carry the day.