Training on the device
As you’ve no doubt noticed, machine learning on mobile is a thing now. Apple mentioned it about 100 times in the WWDC 2017 keynote. With all the hype, it’s no surprise that developers are scrambling to add ML to their apps (I can help out).
However… most of these machine learning models are only used for inference: they make predictions using a fixed set of knowledge. Despite the term “machine learning”, no actual learning happens on the device, and the knowledge inside the model never improves from using it.
A big reason is that training a model takes a lot of computational power and mobile phones are just not fast enough. It’s more practical to train offline on a server farm and include any model improvements in an app update.
That said, training on the device does make sense for certain apps — and I believe that it’s only a matter of time before training models on the device becomes just as normal as using them for inference.
In this blog post I want to explore the possibilities. Let’s see if we can put “learning” into machine learning.
Today: inference only
The most common use of machine/deep learning in apps right now is probably computer vision for analyzing photos and videos. That makes sense because the iPhone is the most popular camera in the world.
But ML is not limited to images; it’s also used for audio, language, time series, and many other types of data. A modern phone has a dozen different sensors plus fast internet access, so there is lots of data to feed into our models.
iOS itself uses several kinds of on-device deep learning models, such as the face detection in the Camera and Photos apps, listening for the “Hey Siri” phrase, and handwriting recognition for Chinese characters.
But none of these models learn from the user.
Pretty much all mobile machine learning APIs (MPSCNN, TensorFlow Lite, Caffe2) only support inference. That is, you can use models to make predictions based on the user’s data or behavior, but you cannot make these models learn new things from that data.
At the moment, training typically happens on a massive server with lots of GPUs. It’s a slow process that needs a lot of data. A convolutional neural network, for example, is trained on thousands or even millions of images. To train a modern CNN from scratch takes days on a powerful multi-GPU server, weeks on a desktop computer, and an eternity on a mobile device — definitely not something you can do on a single battery charge.
Training on a high-end server is a perfectly fine strategy for when model updates happen infrequently and each user always uses the exact same model. The app only gets model updates whenever it is updated in the App Store, or by periodically downloading new parameters from the cloud.
But just because training large models on the device isn’t feasible today doesn’t mean it will be impossible forever. Also, not all models need to be large. And most importantly: one-model-for-everyone may not be the best we can do.
Why learn on the device?
There are advantages to training models on the device:
- The app can learn from the user’s data (or behavior).
- The data never needs to leave the device (good for privacy).
- Anything you can do on-device saves money. Running servers is expensive.
- You can always be learning and update the model continuously.
It’s not appropriate for all situations, but there are definitely applications of on-device learning that make sense. I think the main benefit is that it allows you to tailor the model to the individual user.
The following apps already do learning on iOS devices:
The phone’s predictive keyboard learns from what you’re typing and makes suggestions for the next word in the sentence. This model is trained specifically for you, not other users. Since learning happens on the device, whatever you type is not sent to some cloud server where other people can snoop on it.
The Photos app automatically organizes images into a People album. I’m not exactly sure how this works, but it uses the face detection API to find any faces in the photos and then groups similar faces together. Maybe this is just an unsupervised clustering task — however, some learning must be involved, since the app lets you correct its mistakes and presumably it improves from that feedback. Regardless of how the algorithm works, this is a great example of customizing the app experience based on the user’s own data.
Touch ID / Face ID learn from your fingerprint or face. Face ID keeps learning over time, so if you grow a beard or start wearing glasses it will still recognize your face.
Activity detection. The Apple Watch learns your habits, such as what your heart rate is during different activities. Again, I’m not sure how this works but it’s clear that some kind of learning takes place.
The Clarifai Mobile SDK allows users to create custom image classification models by taking photos on their devices and tagging them. Unlike training a classification model from scratch — which takes thousands of images — their SDK can learn from just a few examples. I’m not sure how much learning happens on the device vs. their cloud service, but the ability to create image classifiers from your own photos (without having to be an expert in machine learning) has many practical uses.
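If the Photos People album really is an unsupervised clustering task, as guessed above, the core idea can be pictured with a tiny greedy clustering sketch. This is not Apple’s algorithm. It assumes each face has already been turned into a numeric embedding vector (the vectors below are made up), and it adds a face to the most similar existing cluster if that cluster’s centroid passes a cosine-similarity threshold, otherwise it starts a new cluster. The `threshold` value is an arbitrary assumption.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

def cluster_faces(embeddings, threshold=0.9):
    """Greedy clustering: join the most similar cluster if its centroid is
    similar enough, otherwise start a new cluster. Returns lists of indices."""
    clusters = []    # each cluster is a list of face indices
    centroids = []   # running mean embedding per cluster
    for i, emb in enumerate(embeddings):
        best, best_sim = None, threshold
        for c, centroid in enumerate(centroids):
            sim = cosine(emb, centroid)
            if sim >= best_sim:
                best, best_sim = c, sim
        if best is None:
            clusters.append([i])
            centroids.append(list(emb))
        else:
            clusters[best].append(i)
            n = len(clusters[best])
            # Update the running mean with the new member.
            centroids[best] = [m + (x - m) / n for m, x in zip(centroids[best], emb)]
    return clusters

# Made-up 3-D "embeddings": faces 0 and 2 look alike, faces 1 and 3 look alike.
faces = [[1.0, 0.1, 0.0], [0.0, 1.0, 0.2], [0.9, 0.2, 0.0], [0.1, 0.9, 0.3]]
print(cluster_faces(faces))  # [[0, 2], [1, 3]]
```

A user correction (“this is not the same person”) could then be handled by moving a face to another cluster and recomputing the centroids, which is one plausible reading of how such feedback might help.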
Some of these tasks are simpler than others. Often “learning” is just a matter of remembering the last thing the user did. For many apps this is good enough and it doesn’t require any fancy machine learning algorithms.
The model for the predictive keyboard is simple enough that training can happen on the device in real time. The Photos app’s face-learning task is much slower and uses a lot of power, which is why it only runs when the device is plugged in. Most practical uses of on-device learning will sit somewhere between these two extremes.
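To see why keyboard-style learning can run in real time, here is a deliberately tiny stand-in: a bigram model whose entire “training step” is incrementing a counter. Apple’s actual keyboard model is not public, so this is only a sketch of the general idea under that assumption; the class and method names are hypothetical.

```python
from collections import Counter, defaultdict

class NextWordPredictor:
    """Bigram model: counts which word follows which, updated as the user types."""

    def __init__(self):
        self.counts = defaultdict(Counter)  # previous word -> Counter of next words

    def learn(self, sentence):
        words = sentence.lower().split()
        # Each adjacent word pair is one training example; learning is a counter increment.
        for prev, nxt in zip(words, words[1:]):
            self.counts[prev][nxt] += 1

    def suggest(self, prev_word, k=3):
        """Return up to k most frequent followers of prev_word (no side effects)."""
        followers = self.counts.get(prev_word.lower(), Counter())
        return [w for w, _ in followers.most_common(k)]

kb = NextWordPredictor()
kb.learn("see you at the gym")
kb.learn("see you at lunch")
kb.learn("see you tomorrow")
print(kb.suggest("you"))  # ['at', 'tomorrow']
```

Because each update touches only a handful of counters, it can happen after every sentence without any noticeable delay.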
Other examples of existing software that learns from you are spam detection (your email client refines its idea of what spam is based on which emails you classify as junk), spelling and grammar correction (it learns the common mistakes you make while typing and fixes them), and smart calendars such as Google Now (which learn to recognize repeated actions that you perform).
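Spam filtering is a classic case where every user action is a labeled training example. A multinomial naive Bayes classifier is one common textbook choice (the text above does not say which algorithm any particular email client uses), and it absorbs each new label with a few counter updates. A minimal sketch, with Laplace smoothing and made-up training emails:

```python
import math
from collections import Counter

class SpamFilter:
    """Multinomial naive Bayes with add-one smoothing, updated one email at a time."""

    def __init__(self):
        self.word_counts = {"spam": Counter(), "ham": Counter()}
        self.doc_counts = {"spam": 0, "ham": 0}

    def learn(self, text, label):
        # Called whenever the user marks an email as junk ("spam") or not ("ham").
        self.word_counts[label].update(text.lower().split())
        self.doc_counts[label] += 1

    def _log_score(self, words, label):
        total_docs = sum(self.doc_counts.values())
        vocab = set(self.word_counts["spam"]) | set(self.word_counts["ham"])
        total_words = sum(self.word_counts[label].values())
        # Smoothed log prior plus the log likelihood of each word
        # (the extra +1 in the denominator reserves a slot for unseen words).
        score = math.log((self.doc_counts[label] + 1) / (total_docs + 2))
        for w in words:
            score += math.log((self.word_counts[label][w] + 1) / (total_words + len(vocab) + 1))
        return score

    def is_spam(self, text):
        words = text.lower().split()
        return self._log_score(words, "spam") > self._log_score(words, "ham")

f = SpamFilter()
f.learn("win a free prize now", "spam")
f.learn("free money click now", "spam")
f.learn("meeting notes for tomorrow", "ham")
f.learn("lunch tomorrow with the team", "ham")
print(f.is_spam("free prize"))             # True
print(f.is_spam("team meeting tomorrow"))  # False
```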
How far can we take this?
If the goal of training on the device is to adapt the machine learning model to the needs or usage patterns of specific users, then what sort of things can we do with this?
Here is a fun toy example: a neural net that turns gestures into Emoji. It asks you to draw a few different shapes and then it trains the model to detect these strokes.
This is implemented as a Swift Playground, and playgrounds are not exactly known for being speedy. Even so, it doesn’t take very long to train this neural network: on a device it only takes a few seconds. (If you’re curious, here is a great description of how this model works.)
So if your model is not too complex — like this 2-layer neural net — training on the device is already in reach right now.
Note: On iPhone X, developers have access to a low-resolution 3D model of the user’s face from the Face ID sensors. You could use this data to train a similar model that chooses an emoji — or some action in your app — based on the user’s facial expressions.
Here are some future possibilities that are a bit more advanced:
Smart Reply is a model from Google that analyzes an incoming text or email message and suggests an appropriate reply. The model is not currently trained on the device, so it recommends the same kinds of replies to every user, but in theory it could be trained on the user’s own words. That would be better, since not everyone uses language in the same way.
Handwriting recognition that is trained on your own handwriting, especially useful on the iPad Pro with the Pencil. Turning handwriting into text isn’t new, but if your writing is as bad as mine, chances are that a standard model makes lots of mistakes while a custom model would learn the nuances of your style.
Speech recognition that gets more accurate as it is fine-tuned on your own voice over time, as we all sound different and speak in different dialects.
Sleep tracking / fitness apps. Before these apps can make recommendations on how to improve your health, they first need to learn about you. For privacy reasons, you may not want this data to leave your device.
Personalized conversational models. It remains to be seen whether chatbots will take off, but the advantage of talking to a bot is that it can adapt to you. When you “talk” to a chatbot, your device can learn about how you speak and like to be spoken to, and tailor the responses of the chatbot to suit your personality and conversational style. In other words, your device will instruct the chatbot how to converse with you. (Example: Siri could learn to be less of a wise-ass.)
Better advertising. No one likes ads, but machine learning can make them less annoying and intrusive, yet more profitable for advertisers. For example, an advertising SDK could learn how often you look at ads and tap on them, and which ads work better for you. The app could learn a local model, which is then used to request only ads that work for this particular user.
Making recommendations is a common use of ML. A podcast player could learn from the programs you have listened to before and recommend what to listen to next. Apps that do this currently are cloud-based (Spotify, etc.), but there’s no reason why apps without a cloud component cannot learn to make recommendations.
For people with disabilities, apps could learn about that person’s environment to help them navigate it or understand it better. This is outside my area of expertise, but I can imagine that someone with a disability could benefit from apps that are tuned to their particular surroundings and the things they need help with (using the camera to tell apart their medicine bottles, for example).
These are just some ideas. Since everyone is different, it makes sense that the machine learning models we use will get tweaked to suit our specific needs and desires. Instead of building one-size-fits-all models, training on the device lets us build a unique model for every unique user.
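To make one of these ideas concrete: the podcast recommender from the list above does not need a server if it uses a simple content-based approach. The sketch below assumes each show in a made-up catalog comes with a few topic tags; it builds a tag profile from the user’s listening history and ranks unheard shows by cosine similarity to that profile. The show names and tags are hypothetical, and real apps may use very different methods.

```python
import math
from collections import Counter

def profile(history, catalog):
    """Sum the tag counts of every show the user listened to."""
    p = Counter()
    for show in history:
        p.update(catalog[show])
    return p

def cosine(a, b):
    """Cosine similarity between two sparse tag-count vectors (Counters)."""
    dot = sum(a[k] * b.get(k, 0) for k in a)
    na = math.sqrt(sum(v * v for v in a.values()))
    nb = math.sqrt(sum(v * v for v in b.values()))
    return dot / (na * nb) if na and nb else 0.0

def recommend(history, catalog, k=2):
    user = profile(history, catalog)
    candidates = [s for s in catalog if s not in history]
    ranked = sorted(candidates, key=lambda s: cosine(user, Counter(catalog[s])), reverse=True)
    return ranked[:k]

# Hypothetical catalog: show name -> topic tags.
catalog = {
    "Swift News": ["programming", "ios"],
    "ML Weekly": ["machine-learning", "programming"],
    "Stats Hour": ["machine-learning", "statistics"],
    "Cooking Hour": ["food"],
    "iOS Chat": ["ios", "programming"],
}
print(recommend(["Swift News", "ML Weekly"], catalog))  # ['iOS Chat', 'Stats Hour']
```

All of the “learning” here is the user profile, which lives entirely on the device.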
Different scenarios for training models
Before you can deploy a machine learning model you first need to train it. And afterwards, you can keep training to refine the model. I believe the big benefit of training on the device is that you can customize the model to each user, and the key idea there is to train the model on that user’s data rather than with a generic dataset (or at least in addition to it).
These are the different options for learning from the user:
Don’t learn from user data at all. Collect your own data or use a publicly available dataset and build a one-size-fits-all model. Whenever you improve the model, release an app update or make the app download the new parameters. This is what the current crop of ML-enabled apps does: train offline, use the model for inference only. The point behind this blog post is to move on from that.
Central learning. If your app or service already requires that the data from the user (or about the user) is stored on your servers (non-encrypted, so you can read it), then it makes sense to do training on the server as well. Send the user’s data to the server and keep it there to learn things from it, possibly specific to that user or for all users in general. This is what platforms like Facebook do.
This setup has issues with privacy (there is none), security (all the data is in one place), scaling (more users means you need a bigger server), etc. Those are the situations we want to avoid by training on the device.
Note: There are other ways to avoid the privacy issue, such as what Apple does with their “differential privacy” approach to gathering user data, but that has its own shortcomings.
Collaborative learning. This is mostly just a way to move the cost of training to the users instead. Training happens on the device and each user trains a small part of the model. These partial model updates are shared with other users so they can also learn from your data, and you from theirs. But it’s still a one-size-fits-all model, as everyone still ends up with the same learned parameters.
The main benefit is that the training is decentralized and instead gets distributed over users’ devices. In theory this is better for privacy, but research shows it may actually be worse.
Each user trains their own model. This is the option that I’m personally most interested in, as it lets us customize machine learning for each individual user. The model can be learned from scratch (such as in the gesture-to-emoji example) or it can be a pre-trained model that is fine-tuned on your own data. In both cases, we can keep refining the model over time. For example, the predictive keyboard starts with a generic model trained in a specific language but over time it learns to predict the kinds of sentences that you write.
The downside of this train-your-own-model approach is that other users cannot benefit from the things the app has learned from you. So this really only makes sense for apps that use data that is relatively unique for each user.
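The collaborative learning option above is what is usually called federated learning (a term the original text does not use). Its simplest aggregation rule, federated averaging, combines the parameters each device sends back as a weighted mean, weighting each device by how many local examples it trained on. A minimal sketch, assuming parameters are plain lists of floats and ignoring the communication, security, and privacy machinery a real system needs:

```python
def federated_average(updates):
    """Combine per-device parameter vectors, weighting each by how many
    local examples produced it (the federated averaging rule)."""
    total = sum(n for _, n in updates)
    dim = len(updates[0][0])
    return [sum(params[i] * n for params, n in updates) / total for i in range(dim)]

# Each tuple: (parameters after local training, number of local examples).
device_updates = [
    ([0.2, 1.0], 10),
    ([0.4, 0.0], 30),
]
print(federated_average(device_updates))  # [0.35, 0.25]
```

Every device then receives the same averaged parameters, which is exactly why this remains a one-size-fits-all model.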
How to actually do on-device training?
One thing to keep in mind is that learning from an individual user’s data is different from training on a large dataset. The initial model for the predictive keyboard may have been trained on a standard corpus (such as all text from Wikipedia) but a text message or email may not have the same writing style as a typical Wikipedia article. And this writing style will be different from one user to the next. The model must allow for these kinds of variations.
There is also the problem that our best training methods for (deep) models are brute-force and rather inefficient. As I’ve pointed out, training an image classifier can take days or weeks. The bigger the computer the better — with 1024 GPUs you can train an ImageNet classifier in 11 minutes.
It takes so long because the training process, Stochastic Gradient Descent (SGD), needs to take small steps. There are typically a million or so images in the dataset and the neural network looks at each image about 100 times.
Obviously, this training method is not feasible on a mobile device.
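Putting the figures above together shows the scale of the problem. The batch size of 256 is an assumption (a common choice, not stated in the text), and each “image pass” means a forward and backward pass through the whole network:

```python
def sgd_cost(num_images, epochs, batch_size):
    """Return (image passes, gradient steps) for a full SGD training run."""
    image_passes = num_images * epochs
    steps = -(-image_passes // batch_size)  # ceiling division
    return image_passes, steps

passes, steps = sgd_cost(num_images=1_000_000, epochs=100, batch_size=256)
print(passes, steps)  # 100000000 390625
```

A hundred million full forward and backward passes through a deep network is far beyond what a phone can do on a single battery charge.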
To be fair, you often don’t need to train a model from scratch. Most people take a pre-trained model and then use transfer learning to make it fit their own dataset. But these smaller datasets typically still consist of thousands of images, and so even transfer learning is rather slow.
It’s safe to say that with our current training methods, fine-tuning deep learning models on mobile is still a ways off.
However, not all is lost. For simple models it’s already possible to train them on the device. The gestures-to-emoji neural network we’ve seen is a basic feed-forward network with one hidden layer. It is trained using SGD with momentum and only takes a few seconds to complete.
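A network of this size is small enough that even an unoptimized implementation trains quickly. The sketch below is not the gesture-to-emoji playground’s code: it is a generic one-hidden-layer network trained with per-example SGD with momentum, using the XOR toy problem instead of gesture strokes for brevity. The learning rate, momentum, hidden size, and epoch count are arbitrary assumptions.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

class TinyNet:
    """2-layer net: inputs -> tanh hidden layer -> single sigmoid output."""

    def __init__(self, n_in, n_hidden, seed=0):
        rng = random.Random(seed)
        self.W1 = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_hidden)]
        self.b1 = [0.0] * n_hidden
        self.W2 = [rng.uniform(-1, 1) for _ in range(n_hidden)]
        self.b2 = 0.0
        # Momentum buffers ("velocities"), same shapes as the parameters.
        self.vW1 = [[0.0] * n_in for _ in range(n_hidden)]
        self.vb1 = [0.0] * n_hidden
        self.vW2 = [0.0] * n_hidden
        self.vb2 = 0.0

    def forward(self, x):
        h = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(self.W1, self.b1)]
        y = sigmoid(sum(w * hi for w, hi in zip(self.W2, h)) + self.b2)
        return h, y

    def sgd_step(self, x, target, lr=0.1, momentum=0.9):
        h, y = self.forward(x)
        dz2 = y - target  # gradient of cross-entropy loss w.r.t. the output pre-activation
        for j in range(len(h)):
            # Backpropagate into hidden unit j (uses W2[j] before it is updated).
            dz1 = dz2 * self.W2[j] * (1 - h[j] ** 2)
            self.vW2[j] = momentum * self.vW2[j] - lr * dz2 * h[j]
            self.W2[j] += self.vW2[j]
            for i in range(len(x)):
                self.vW1[j][i] = momentum * self.vW1[j][i] - lr * dz1 * x[i]
                self.W1[j][i] += self.vW1[j][i]
            self.vb1[j] = momentum * self.vb1[j] - lr * dz1
            self.b1[j] += self.vb1[j]
        self.vb2 = momentum * self.vb2 - lr * dz2
        self.b2 += self.vb2

    def loss(self, data):
        """Mean binary cross-entropy over the dataset."""
        eps = 1e-12
        total = 0.0
        for x, t in data:
            _, y = self.forward(x)
            total -= t * math.log(y + eps) + (1 - t) * math.log(1 - y + eps)
        return total / len(data)

def train(net, data, epochs=2000, seed=0):
    rng = random.Random(seed)
    data = list(data)
    for _ in range(epochs):
        rng.shuffle(data)  # visit examples in random order: the "stochastic" in SGD
        for x, t in data:
            net.sgd_step(x, t)

xor = [([0, 0], 0), ([0, 1], 1), ([1, 0], 1), ([1, 1], 0)]
net = TinyNet(n_in=2, n_hidden=4)
before = net.loss(xor)
train(net, xor)
after = net.loss(xor)
print(f"loss before {before:.3f}, after {after:.3f}")
```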
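Classical machine learning models, such as logistic regression, decision trees, or naive Bayes, are typically very quick to train as well. Naive Bayes and decision trees are fit by counting and greedy splitting rather than by iterative gradient descent, and logistic regression usually converges in relatively few iterations with optimizers such as L-BFGS (a quasi-Newton method) or conjugate gradient. Even a basic recurrent neural network should be within the realm of possibilities.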
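The speed of curvature-aware optimizers is easiest to see with plain Newton’s method, the simplest second-order method; L-BFGS approximates the same curvature information without forming the Hessian, so this is a swapped-in illustration rather than L-BFGS itself. A minimal sketch for a one-feature logistic regression, assuming the classes are not perfectly separable (otherwise the weights grow without bound) and using made-up data:

```python
import math

def fit_logistic_newton(xs, ts, iterations=20):
    """Fit p(t=1|x) = sigmoid(w0 + w1*x) with Newton's method."""
    w0, w1 = 0.0, 0.0
    for _ in range(iterations):
        g0 = g1 = h00 = h01 = h11 = 0.0
        for x, t in zip(xs, ts):
            p = 1.0 / (1.0 + math.exp(-(w0 + w1 * x)))
            # Gradient of the negative log-likelihood.
            g0 += p - t
            g1 += (p - t) * x
            # Hessian entries: curvature information that plain SGD does not use.
            s = p * (1 - p)
            h00 += s
            h01 += s * x
            h11 += s * x * x
        det = h00 * h11 - h01 * h01
        # Newton step w <- w - H^-1 g, with the 2x2 inverse written out.
        w0 -= (h11 * g0 - h01 * g1) / det
        w1 -= (-h01 * g0 + h00 * g1) / det
    return w0, w1

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
ts = [0, 0, 1, 0, 1, 1]
w0, w1 = fit_logistic_newton(xs, ts)
print(f"w0={w0:.3f}, w1={w1:.3f}")
```

On a small, non-separable problem like this, a handful of Newton steps typically drives the gradient to essentially zero, whereas SGD would need many passes over the data.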
For models like the predictive keyboard, some kind of “online” training method might work. Here, you do a single training pass after every X characters or words that the user types. Likewise for models using the accelerometer and motion data, where the data comes in as a constant stream of numbers. Since these models are trained on only a small piece of data at a time, each training update is fast.
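A minimal sketch of this online idea for streaming sensor data, assuming each window of motion data has already been reduced to a small feature vector and given a label (for example, when the user confirms an activity). The single feature used here, something like the variance of the acceleration magnitude, and all of its values are made up for illustration. Each update costs only a few multiplications:

```python
import math

class OnlineClassifier:
    """Logistic regression updated with one SGD step per incoming sample."""

    def __init__(self, n_features, lr=0.1):
        self.w = [0.0] * n_features
        self.b = 0.0
        self.lr = lr

    def predict_proba(self, x):
        z = sum(wi * xi for wi, xi in zip(self.w, x)) + self.b
        return 1.0 / (1.0 + math.exp(-z))

    def update(self, x, label):
        # One gradient step on this single sample; the cost is O(n_features),
        # so it can run every time a new window of sensor data arrives.
        err = self.predict_proba(x) - label
        self.w = [wi - self.lr * err * xi for wi, xi in zip(self.w, x)]
        self.b -= self.lr * err

# Made-up stream of alternating "still" (0) and "walking" (1) windows.
stream = []
for i in range(1000):
    stream.append(([0.1 + 0.05 * (i % 3)], 0))
    stream.append(([1.5 + 0.1 * (i % 4)], 1))

clf = OnlineClassifier(n_features=1)
for x, label in stream:
    clf.update(x, label)  # learn from each window as it comes in
print(clf.predict_proba([0.15]), clf.predict_proba([1.6]))
```

The model never needs to see the whole history at once, which keeps both memory use and per-update latency small.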
So if your model is small and you have relatively little data, then the training time will be on the order of seconds on a modern device. If you do it in a background thread, no one will notice.
But if your model is not small, or you have a lot of data to process, then you need to get creative. A model that wants to learn the faces of the people in your photo library simply has a lot of data to go through, and so you’ll need to find a balance between the speed and accuracy of your learning algorithm.
These are some of the issues to overcome:
Large models. For deep neural networks, the current training methods are too slow and require too much data. There’s a lot of research going on right now in how to train deep models with very little data (such as a single photo) and with a small number of steps, known as “one-shot” / “few-shot” learning. I’m sure that any progress here will quickly translate into more people doing on-device training.
Multiple devices. You probably have more than one device — phone, tablet, computer. One problem that needs to be solved is how to share models (and the data used to train them) between these devices. For example, Photos on iOS 10 did not share what it had learned about people’s faces, so it had to re-learn this on all devices separately.
App updates. If your app comes with a pre-trained model that it customizes based on user data or behavior, then what happens when you improve the pre-trained model in an app update? You’ll need to apply the user’s local improvements again too (which probably means retraining the whole thing on the user’s data again).
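One way to handle app updates, following the retraining approach described above, is to keep the user’s training examples on the device and replay fine-tuning whenever a new base model arrives. A minimal sketch, assuming a plain linear model with squared-error loss; the class and method names are hypothetical, the base weights are made up, and storing raw user data brings its own privacy and storage trade-offs:

```python
class PersonalizedModel:
    """Linear model: shipped base weights plus on-device fine-tuning on stored user data."""

    def __init__(self, base_weights, lr=0.05, epochs=200):
        self.base = list(base_weights)
        self.w = list(base_weights)
        self.user_data = []  # (features, target) pairs kept on the device
        self.lr, self.epochs = lr, epochs

    def predict(self, x):
        return sum(wi * xi for wi, xi in zip(self.w, x))

    def add_user_example(self, x, y):
        self.user_data.append((x, y))
        self._fine_tune(start=self.w)  # keep refining the current personalized weights

    def _fine_tune(self, start):
        # Per-example SGD on squared error over all stored user examples.
        w = list(start)
        for _ in range(self.epochs):
            for x, y in self.user_data:
                err = sum(wi * xi for wi, xi in zip(w, x)) - y
                w = [wi - self.lr * err * xi for wi, xi in zip(w, x)]
        self.w = w

    def apply_app_update(self, new_base_weights):
        # The old personalization was learned on top of the old base, so
        # start from the new base and replay fine-tuning on the stored data.
        self.base = list(new_base_weights)
        self._fine_tune(start=self.base)

model = PersonalizedModel(base_weights=[1.0, 0.0])  # features: [x, 1.0]
for x in [0.0, 1.0, 2.0]:
    model.add_user_example([x, 1.0], 2.0 * x + 1.0)  # this user's data follows y = 2x + 1
model.apply_app_update([0.5, 0.5])  # hypothetical improved base model
print(model.predict([3.0, 1.0]))
```

The cost of this approach grows with the amount of stored user data, which is one reason app updates are listed here as an open problem.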
It’s still early days for training on the device, but in my opinion it’s an inevitable technology — and one that will become important in the design of software.
I wrote this blog post as a way for me to think through what is already possible and where things may be heading. I hope you found it a useful exercise. As always, I look forward to hearing your thoughts!
Source: Training on the device