This is a guest post from the Stanford University Computer Science Department. We thank Daniel Kang, Deepti Raghavan and Peter Bailis of Stanford University for their contributions.
Machine learning (ML) models are increasingly used in a wide range of business applications. Organizations deploy hundreds of ML models to predict customer churn, optimal pricing, fraud and more. Many of these models are deployed in situations where humans can’t verify every prediction because the data volumes are simply too large. As a result, monitoring these models is becoming crucial to applying ML successfully and accurately.
In this blog post, we’ll show why monitoring models is critical and describe the catastrophic errors that can occur without it. Our solution is model assertions, a simple yet effective tool for monitoring ML models that we developed at Stanford University (published at MLSys 2020). We’ll also describe how to use our open-source Python library, model_assertions, to detect errors in real ML models.
Why we need monitoring
Let’s consider a simple example: estimating housing prices in Boston (the dataset shipped with scikit-learn until it was removed in version 1.2). This example uses a publicly available dataset and is representative of standard industry use cases. A data scientist might fit a linear regression model that predicts the price from features such as the average number of rooms; such models are standard in practice. Aggregate statistics such as RMSE suggest that the model is performing reasonably well:
```
The model performance for test set
--------------------------------------
Root Mean Squared Error: 4.93
R^2: 0.67
```
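For reference, both aggregate metrics can be computed by hand. The sketch below uses a few made-up prices (not the Boston data) and plain Python; `rmse` and `r_squared` are illustrative helper names, not part of any library.

```python
import math

def rmse(y_true, y_pred):
    """Root mean squared error: square root of the mean squared residual."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - (residual sum of squares / total sum of squares)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Made-up prices (in thousands of dollars), for illustration only
y_true = [20.0, 30.0, 40.0, 50.0]
y_pred = [22.0, 28.0, 43.0, 47.0]

print(f"Root Mean Squared Error: {rmse(y_true, y_pred):.2f}")
print(f"R^2: {r_squared(y_true, y_pred):.2f}")
```

Both numbers summarize the whole test set in a single value, which is exactly why they can hide individual bad predictions.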
Unfortunately, while this model performs well on average, it makes some critical mistakes.
In particular, the model predicts negative housing prices for some of the data. Using this model to set housing prices would mean paying customers to purchase a house! If we only look at the aggregate metrics for our models, we miss errors like these.
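A row-level check of this kind takes only a few lines. The following sketch uses made-up predictions (not the actual model output) to show how an aggregate metric and a per-row check can tell different stories; `find_negative_predictions` is a hypothetical helper.

```python
import math

def find_negative_predictions(y_pred):
    """Return the indices of predictions that are below zero."""
    return [i for i, p in enumerate(y_pred) if p < 0]

# Made-up true prices and predictions (in thousands of dollars)
y_true = [24.0, 21.6, 34.7, 7.0, 5.0]
y_pred = [27.0, 24.0, 31.0, -1.5, 8.0]

# Aggregate error over all rows
rmse = math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))
print(f"RMSE: {rmse:.2f}")

# Row-level check that the aggregate metric does not reveal
bad_rows = find_negative_predictions(y_pred)
print(f"Rows with a negative predicted price: {bad_rows}")
```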
While seemingly simple, these kinds of errors are ubiquitous when using ML models. In our full paper, we also describe how to apply model assertions to autonomous vehicle and vision data (including an example about predicting attributes of TV news anchors).
In the examples above, we see that ML models widely used in practice can produce inconsistent or nonsensical results. As a first step toward addressing these issues, we’ve developed an API called model assertions.
Model assertions let data scientists, developers and domain experts specify when errors in ML models may be occurring. A model assertion takes the inputs and outputs of a model and returns records containing potential errors.
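Conceptually, a model assertion is just a function over model inputs and outputs. The sketch below illustrates that idea in plain Python. It is not the `model_assertions` API, and the names `run_assertion` and `price_is_non_positive` are made up for illustration.

```python
from typing import Any, Callable, List, Sequence, Tuple

# An assertion receives one input and its model output, and returns True
# when the output looks like a potential error.
Assertion = Callable[[Any, Any], bool]

def run_assertion(
    assertion: Assertion,
    inputs: Sequence[Any],
    outputs: Sequence[Any],
) -> List[Tuple[int, Any, Any]]:
    """Return (index, input, output) for every record the assertion flags."""
    return [
        (i, inp, out)
        for i, (inp, out) in enumerate(zip(inputs, outputs))
        if assertion(inp, out)
    ]

def price_is_non_positive(_inp: Any, out: float) -> bool:
    """Flag predictions that are zero or negative."""
    return out <= 0

# Made-up inputs (average number of rooms) and predicted prices
inputs = [6.5, 4.1, 7.2]
outputs = [24.3, -2.1, 35.0]

flagged = run_assertion(price_is_non_positive, inputs, outputs)
print(flagged)
```

Returning the flagged records, rather than raising an exception, lets the model keep serving predictions while the suspicious ones are collected for review.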
Let’s look at an example with the housing price prediction model above. As a simple sanity check, a data scientist specifies that housing price predictions must be positive. Once the assertion is specified and registered, it flags potentially erroneous data points:
```python
import pandas as pd

from model_assertions.checker import Checker
from model_assertions.per_row import PerRowAssertion

# Define the prediction function in a standard way
def pred_fn(df, model=None):
    X = df.values
    y_pred = model.predict(X)
    return pd.DataFrame(y_pred, columns=['Price'])

# Define the assertion that outputs should be positive
def output_pos(_inp, out):
    return out
```
- Per-row assertions (e.g., that the output should be positive).
- Identifier consistency assertions that specify attributes of the same identifier should agree.
- Time consistency assertions that specify entities should not appear and disappear too many times in a time window.
And we plan on adding more!
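To make the identifier and time consistency assertions more concrete, here is a library-independent sketch. It does not use the `model_assertions` API; the function names, the record layout and the `max_transitions` threshold are all assumptions chosen for illustration.

```python
from collections import defaultdict
from typing import Dict, Hashable, Sequence, Set, Tuple

def identifier_inconsistencies(
    records: Sequence[Tuple[Hashable, str]],
) -> Dict[Hashable, Set[str]]:
    """Return identifiers whose predicted attribute takes more than one value.

    Each record is (identifier, predicted_attribute).
    """
    seen: Dict[Hashable, Set[str]] = defaultdict(set)
    for identifier, attribute in records:
        seen[identifier].add(attribute)
    return {k: v for k, v in seen.items() if len(v) > 1}

def flickers(present: Sequence[bool], max_transitions: int = 2) -> bool:
    """Return True if an entity appears and disappears too often in a window.

    present[t] says whether the entity was detected at time step t.
    """
    transitions = sum(1 for a, b in zip(present, present[1:]) if a != b)
    return transitions > max_transitions

# Made-up predictions: the same identifier is assigned two different attribute values
records = [("anchor_1", "brown"), ("anchor_2", "blond"), ("anchor_1", "black")]
print(identifier_inconsistencies(records))

# Made-up detections of one object over seven time steps
print(flickers([True, False, True, False, True, True, True]))
print(flickers([False, True, True, True, True, True, False]))
```

Neither check needs ground-truth labels: both compare the model’s outputs against each other, which is what makes them usable at deployment time.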
In our full paper, we show other examples of how to use model assertions, including in autonomous vehicles, video analytics and ECG applications. We also describe how to use model assertions to select training data: instead of selecting data at random or via uncertainty, selecting “hard” data points (i.e., data points with errors or ones that trigger model assertions) can be more informative. Selecting training data this way can be up to 40% cheaper than standard methods.
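The intuition behind assertion-driven data selection can be sketched as follows. This is a deliberately simplified ranking heuristic, not the selection algorithm from the paper; `select_for_labeling`, the two example assertions and the 100 threshold are hypothetical.

```python
from typing import Any, Callable, List, Sequence

Assertion = Callable[[Any, Any], bool]

def select_for_labeling(
    inputs: Sequence[Any],
    outputs: Sequence[Any],
    assertions: Sequence[Assertion],
    budget: int,
) -> List[int]:
    """Pick up to `budget` row indices, preferring rows that trigger the most assertions."""
    counts = [
        sum(1 for check in assertions if check(inp, out))
        for inp, out in zip(inputs, outputs)
    ]
    # Sort by descending count; sorted() is stable, so ties keep their original order
    ranked = sorted(range(len(counts)), key=lambda i: -counts[i])
    return ranked[:budget]

# Two illustrative assertions on a predicted price (in thousands of dollars)
def non_positive(_inp: Any, out: float) -> bool:
    return out <= 0

def implausibly_large(_inp: Any, out: float) -> bool:
    return out > 100  # assumed threshold, for illustration only

# Made-up inputs (average number of rooms) and predicted prices
inputs = [6.5, 4.1, 7.2, 5.0]
outputs = [24.3, -2.1, 1500.0, 30.0]

to_label = select_for_labeling(inputs, outputs, [non_positive, implausibly_large], budget=2)
print(to_label)
```

With a labeling budget of two, the rows that triggered an assertion are sent for labeling ahead of the rows that did not.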
Try the notebooks: