Let’s start with linear regression. With established libraries such as scikit-learn, training a linear regression model is almost trivial. We can easily train the model on a few hundred megabytes of data on a laptop with its built-in CPU.
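To make "almost trivial" concrete, here is a minimal sketch of that laptop-scale workflow. The synthetic data, the weights in `true_w`, and the noise level are made up for illustration; only `LinearRegression` and its `fit` method come from scikit-learn.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical synthetic dataset: 1,000 rows, 3 features, known weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
true_w = np.array([2.0, -1.0, 0.5])   # assumed ground-truth weights
y = X @ true_w + 3.0 + rng.normal(scale=0.1, size=1000)

# Training is a single call; everything fits in memory on one machine.
model = LinearRegression()
model.fit(X, y)

print("weights:", model.coef_)
print("intercept:", model.intercept_)
```

The fitted weights should land close to `true_w` because the data is small, clean, and tabular. The rest of this post is about what happens when those assumptions stop holding.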
Now let’s think big.
First, what if we have much more data (terabytes, petabytes, or more) than we could store on a single laptop? There has been a data explosion [2,3] with the wide use of the internet, automatic data collection, and crowdsourced labeling. For example, ImageNet [4, 5] contains more than 14 million hand-labeled images; Google now processes over 40,000 search queries every second on average, or 3.5 billion searches per day [6]; YouTube has 300 hours of video uploaded every minute; and every minute on Facebook, 510,000 comments are posted, 293,000 statuses are updated, and 136,000 photos are uploaded [3]. User behavior and browsing history have become valuable assets for online and offline commerce such as Amazon and Walmart, as well as for content platforms such as YouTube and Netflix. In 2004, MapReduce [7] was introduced as a programming model for reliable and scalable data processing on large clusters of machines; it runs on top of a distributed file system rather than being one itself. Companies with large amounts of data used to run their own in-house data centers with the Hadoop Distributed File System (HDFS), and many are now moving to cloud storage solutions on commercial platforms such as Google Cloud Platform and Amazon Web Services.
Second, what if the CPU on our laptop is too slow? A CPU is very versatile in handling and coordinating various tasks. For model training tasks such as computing a convolution, however, we only need basic arithmetic and bitwise operations, but a great many of them. This is where a GPU becomes extremely useful, with its powerful parallel computing capacity for arithmetic operations. GPUs and, more recently, TPUs provide powerful computing hardware support for big data processing.
Third, scikit-learn provides many training algorithms, such as linear regression and decision trees, for tabular data. What if our training data is not a table but unstructured data such as images, text, and video? Deep neural networks have proven especially effective for this kind of data. The development of neural network architectures such as CNNs, RNNs, autoencoders, and attention models, accompanied by abstraction libraries such as TensorFlow and PyTorch, enables better training on more complex datasets.
Last, in classic gradient descent, we compute the gradient on all the data in each iteration. When we have a lot of data with a lot of features, computing the gradient on all the data (called batch gradient descent) is slow. That’s where stochastic gradient descent (SGD) [8] becomes useful. In SGD, we use only one random sample to compute the gradient in each iteration. These frequent model updates, however, can be computationally expensive and produce a noisy gradient signal, so the model may take longer to converge. In practice, mini-batch SGD with a small batch of samples (32, 128, 1024, etc.) is commonly used to balance model update frequency and computational efficiency.
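The difference between the three variants comes down to how many samples feed each gradient step. Below is a minimal NumPy sketch of mini-batch SGD for linear regression with a mean squared error loss. The function name `minibatch_sgd`, the learning rate, the epoch count, and the synthetic data are all assumptions chosen for illustration, not a reference implementation.

```python
import numpy as np

def minibatch_sgd(X, y, batch_size=32, lr=0.05, epochs=20, seed=0):
    """Fit y ~ X @ w + b by mini-batch SGD on the mean squared error."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(epochs):
        order = rng.permutation(n)              # reshuffle once per epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            Xb, yb = X[idx], y[idx]
            err = Xb @ w + b - yb               # residuals on this batch only
            grad_w = 2.0 * Xb.T @ err / len(idx)
            grad_b = 2.0 * err.mean()
            w -= lr * grad_w                    # one model update per batch
            b -= lr * grad_b
    return w, b

# Hypothetical synthetic data with known weights.
rng = np.random.default_rng(1)
X = rng.normal(size=(2000, 3))
true_w = np.array([2.0, -1.0, 0.5])
y = X @ true_w + 3.0 + rng.normal(scale=0.1, size=2000)

w, b = minibatch_sgd(X, y, batch_size=32)
print("weights:", w, "intercept:", b)
```

In this sketch, setting `batch_size=len(X)` gives batch gradient descent (one update per pass over the data), while `batch_size=1` gives plain SGD (one update per sample). Values in between trade gradient noise against the number of updates per pass.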
Driven by big data, computing units and libraries, powerful model architectures, and efficient algorithms, ML model training is blurring the boundary between machine learning and systems.
In May 2019, a group of scientists and researchers from academia and industry published the MLSys white paper [9], raising full-stack MLSys challenges. These include specialized software design and libraries for the end-to-end ML lifecycle; specialized hardware design for distributed data storage, model training, and serving; and optimization of metrics beyond prediction accuracy, such as cost, latency, privacy, and fairness. These challenges call for wide collaboration among experts in machine learning and systems.
Read my post “5 things you need to know about Machine Learning Systems” to learn more about MLSys.
In the next post, I will talk about distributed model training and how it works.
References
[1] cover image: https://datacenterfrontier.com/google-building-more-data-centers-for-massive-future-clouds/
[2] https://www.forbes.com/sites/bernardmarr/2018/05/21/how-much-data-do-we-create-every-day-the-mind-blowing-stats-everyone-should-read/#6ea76c7660ba
[3] https://blog.microfocus.com/how-much-data-is-created-on-the-internet-each-day/
[4] http://www.image-net.org/papers/imagenet_cvpr09.pdf
[5] http://image-net.org/about-overview
[6] https://www.internetlivestats.com/google-search-statistics/
[7] https://static.googleusercontent.com/media/research.google.com/en//archive/mapreduce-osdi04.pdf
[8] https://stats.stackexchange.com/questions/313681/who-invented-stochastic-gradient-descent
[9] https://arxiv.org/abs/1904.03257