CLV Part 1: Customer Lifetimes (Python)

Calculating the Probability of Future Customer Engagement

In non-subscription retail models, customers come and go with no long-term commitments, making it very difficult to determine whether a customer will return in the future. Determining the probability that a customer will re-engage is critical to the design of effective marketing campaigns. Different messaging and promotions may be required to incentivize customers who have likely dropped out to return to our stores. Engaged customers may be more responsive to marketing that encourages them to expand the breadth and scale of purchases with us. Understanding where our customers land with regard to the probability of future engagement is critical to tailoring our marketing efforts to them.

The Buy 'til You Die (BTYD) models popularized by Peter Fader and others leverage two basic customer metrics, i.e. the recency of a customer's last engagement and the frequency of repeat transactions over a customer's lifetime, to derive a probability of future re-engagement. This is done by fitting customer history to curves describing the distribution of purchase frequencies and engagement drop-off following a prior purchase. The math behind these models is fairly complex but thankfully it's been encapsulated in the lifetimes library, making it much easier for traditional enterprises to employ. The purpose of this notebook is to examine how these models may be applied to customer transaction history and how they may be deployed for integration in marketing processes.

Step 1: Setup the Environment

To run this notebook, you need to attach it to a cluster running the Databricks ML Runtime, version 6.5 or higher. This runtime provides access to many of the pre-configured libraries used here. Still, there are additional Python libraries which you will need to install and attach to your cluster. These are:

  • xlrd
  • lifetimes==0.10.1
  • nbconvert

To install these libraries in your Databricks workspace, please follow these steps, using the PyPI library source together with the library names bulleted above. Once installed, please be sure to attach these libraries to the cluster with which you are running this notebook.
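On more recent Databricks runtimes that support notebook-scoped libraries, a quicker alternative is to install the packages directly from the notebook with the %pip magic command. This is only a sketch of that option; the workspace library approach described above is what the remainder of this notebook assumes:

%pip install xlrd lifetimes==0.10.1 nbconvert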

With the libraries installed, let's load a sample dataset with which we can examine the BTYD models. The dataset we will use is the Online Retail Data Set available from the UCI Machine Learning Repository. This dataset is made available as a Microsoft Excel workbook (XLSX). Having downloaded this XLSX file to our local system, we can load it into our Databricks environment by following the steps provided here. Please note that when performing the file import, you don't need to select the Create Table with UI or the Create Table in Notebook options to complete the import process. Also, the name of the XLSX file will be modified upon import as it includes an unsupported space character. As a result, we will need to programmatically locate the new name assigned to the file by the import process.

Assuming we've uploaded the XLSX to /FileStore/tables/online_retail/, we can access it as follows:

import pandas as pd
import numpy as np

# identify name of xlsx file (which will change when uploaded)
xlsx_filename = dbutils.fs.ls('file:///dbfs/FileStore/tables/online_retail')[0][0]

# schema of the excel spreadsheet data range
orders_schema = {
  'InvoiceNo':np.str,
  'StockCode':np.str,
  'Description':np.str,
  'Quantity':np.int64,
  'InvoiceDate':np.datetime64,
  'UnitPrice':np.float64,
  'CustomerID':np.str,
  'Country':np.str  
  }

# read spreadsheet to pandas dataframe
# the xlrd library must be installed for this step to work 
orders_pd = pd.read_excel(
  xlsx_filename, 
  sheet_name='Online Retail',
  header=0, # first row is header
  dtype=orders_schema
  )

# display first few rows from the dataset
orders_pd.head(10)
Out[1]:

The data in the workbook are organized as a range in the Online Retail spreadsheet. Each record represents a line item in a sales transaction. The fields included in the dataset are:

  • InvoiceNo - A 6-digit integral number uniquely assigned to each transaction
  • StockCode - A 5-digit integral number uniquely assigned to each distinct product
  • Description - The product (item) name
  • Quantity - The quantity of each product (item) per transaction
  • InvoiceDate - The invoice date and time in mm/dd/yy hh:mm format
  • UnitPrice - The per-unit product price in pound sterling (£)
  • CustomerID - A 5-digit integral number uniquely assigned to each customer
  • Country - The name of the country where each customer resides

Of these fields, the ones of particular interest for our work are InvoiceNo which identifies the transaction, InvoiceDate which identifies the date of that transaction, and CustomerID which uniquely identifies the customer across multiple transactions. (In a separate notebook, we will examine the monetary value of the transactions through the UnitPrice and Quantity fields.)

Step 2: Explore the Dataset

To enable the exploration of the data using SQL statements, let's flip the pandas DataFrame into a Spark DataFrame and persist it as a temporary view:

# convert pandas DF to Spark DF
orders = spark.createDataFrame(orders_pd)

# present Spark DF as queriable view
orders.createOrReplaceTempView('orders') 

Examining the transaction activity in our dataset, we can see the first transaction occurs on December 1, 2010 and the last on December 9, 2011, making this dataset a little more than one year in duration. The daily transaction count shows there is quite a bit of volatility in daily activity for this online retailer:

%sql -- unique transactions by date

SELECT 
  TO_DATE(InvoiceDate) as InvoiceDate,
  COUNT(DISTINCT InvoiceNo) as Transactions
FROM orders
GROUP BY TO_DATE(InvoiceDate)
ORDER BY InvoiceDate;
[Chart: unique transactions by date (InvoiceDate vs. Transactions)]

We can smooth this out a bit by summarizing activity by month. It's important to keep in mind that December 2011 includes only 9 days of data, so the sales decline graphed for the last month should most likely be ignored:

NOTE We will hide the SQL behind each of the following result sets for ease of viewing. To view this code, simply click the Show code item above each of the following charts.

Show code
[Chart: unique transactions by month (InvoiceMonth vs. Transactions)]

For the little more than 1-year period for which we have data, we see over four-thousand unique customers. These customers generated about twenty-two thousand unique transactions:

Show code
Unique customers: 4372     Unique transactions: 22190

A little quick math may lead us to estimate that, on average, each customer is responsible for about 5 transactions, but this would not provide an accurate representation of customer activity.

Instead, if we count the unique transactions by customer and then examine the frequency of these values, we see that many of the customers have engaged in only a single transaction. The distribution of the count of repeat purchases declines from there in a manner we might describe as a negative binomial distribution (which is the basis of the NBD acronym included in the name of most BTYD models):

Show code
[Chart: distribution of per-customer transaction counts (Transactions vs. Occurrences)]
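For reference, the hidden query behind the chart above likely resembles the following sketch, which counts distinct invoices per customer (assuming the orders view registered earlier) and then tallies how many customers fall at each transaction count:

# sketch: distribution of unique transaction counts per customer
transaction_dist = spark.sql('''
  SELECT
    Transactions,
    COUNT(*) as Occurrences
  FROM (
    SELECT
      CustomerID,
      COUNT(DISTINCT InvoiceNo) as Transactions
    FROM orders
    WHERE CustomerID IS NOT NULL
    GROUP BY CustomerID
    ) x
  GROUP BY Transactions
  ORDER BY Transactions
  ''')

display(transaction_dist)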

If we alter our last analysis to group a customer's transactions that occur on the same date into a single transaction - a pattern that aligns with metrics we will calculate later - we see that a few more customers are identified as non-repeat customers but the overall pattern remains the same:

Show code
[Chart: distribution of per-customer purchase-date counts (Transactions vs. Occurrences)]

Focusing on customers with repeat purchases, we can examine the distribution of the days between purchase events. What's important to note here is that most customers return to the site within 2 to 3 months of a prior purchase. Longer gaps do occur, but significantly fewer customers have longer gaps between returns. This is important to understand in the context of our BTYD models in that the time since we last saw a customer is a critical factor in determining whether they will ever come back, with the probability of return dropping as more and more time passes since a customer's last purchase event:

Show code
[Chart: density of average days between purchases (AvgDaysBetween vs. Density)]
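A sketch of how the gap between purchase dates might be derived for repeat customers is shown below; the hidden cell may differ in its details, but a window-function query against the orders view is a common way to express it:

# sketch: average days between purchase dates for customers with repeat purchases
avg_gap = spark.sql('''
  SELECT
    CustomerID,
    AVG(DATEDIFF(transaction_at, prior_transaction_at)) as AvgDaysBetween
  FROM (
    SELECT
      CustomerID,
      transaction_at,
      LAG(transaction_at) OVER (PARTITION BY CustomerID ORDER BY transaction_at) as prior_transaction_at
    FROM (
      SELECT DISTINCT CustomerID, TO_DATE(InvoiceDate) as transaction_at
      FROM orders
      WHERE CustomerID IS NOT NULL
      ) y
    ) x
  WHERE prior_transaction_at IS NOT NULL
  GROUP BY CustomerID
  ''')

display(avg_gap)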


Step 3: Calculate Customer Metrics

The dataset with which we are working consists of raw transactional history. To apply the BTYD models, we need to derive several per-customer metrics:

  • Frequency - the number of dates on which a customer made a purchase subsequent to the date of the customer's first purchase
  • Age (T) - the number of time units, e.g. days, since the date of a customer's first purchase to the current date (or last date in the dataset)
  • Recency - the age of the customer (as previously defined) at the time of their last purchase

It's important to note that when calculating metrics such as customer age, we need to consider when our dataset terminates. Calculating these metrics relative to today's date can lead to erroneous results. Given this, we will identify the last date in the dataset and define that as today's date for all calculations.
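As a quick illustration of these definitions, consider a hypothetical customer who first purchases on January 1, 2011 and makes repeat purchases on January 10 and February 1, with the dataset ending on March 1. The lifetimes utility introduced below should report a frequency of 2, a recency of 31 days and an age (T) of 59 days for this customer:

import pandas as pd
from lifetimes.utils import summary_data_from_transaction_data

# hypothetical toy history to illustrate frequency, recency and T
toy = pd.DataFrame({
  'CustomerID': ['A', 'A', 'A'],
  'InvoiceDate': pd.to_datetime(['2011-01-01', '2011-01-10', '2011-02-01'])
  })

summary_data_from_transaction_data(
  toy,
  customer_id_col='CustomerID',
  datetime_col='InvoiceDate',
  observation_period_end=pd.to_datetime('2011-03-01'),
  freq='D'
  )
# expected: frequency=2.0, recency=31.0, T=59.0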

To get started with these calculations, let's take a look at how they are performed using the built-in functionality of the lifetimes library:

import lifetimes

# set the last transaction date as the end point for this historical dataset
current_date = orders_pd['InvoiceDate'].max()

# calculate the required customer metrics
metrics_pd = (
  lifetimes.utils.summary_data_from_transaction_data(
    orders_pd,
    customer_id_col='CustomerID',
    datetime_col='InvoiceDate',
    observation_period_end = current_date, 
    freq='D'
    )
  )

# display first few rows
metrics_pd.head(10)
Out[3]:

The lifetimes library, like many Python libraries, is single-threaded. Using this library to derive customer metrics on larger transactional datasets may overwhelm your system or simply take too long to complete. For this reason, let's examine how these metrics can be calculated using the distributed capabilities of Apache Spark.

As SQL is frequently employed for complex data manipulation, we'll start with a Spark SQL statement. In this statement, we first assemble each customer's order history consisting of the customer's ID, the date of their first purchase (first_at), the date on which a purchase was observed (transaction_at) and the current date (using the last date in the dataset for this value). From this history, we can count the number of repeat transaction dates (frequency), the days between the last and first transaction dates (recency), and the days between the current date and first transaction (T) on a per-customer basis:

# sql statement to derive summary customer stats
sql = '''
  SELECT
    a.customerid as CustomerID,
    CAST(COUNT(DISTINCT a.transaction_at) - 1 as float) as frequency,
    CAST(DATEDIFF(MAX(a.transaction_at), a.first_at) as float) as recency,
    CAST(DATEDIFF(a.current_dt, a.first_at) as float) as T
  FROM ( -- customer order history
    SELECT DISTINCT
      x.customerid,
      z.first_at,
      TO_DATE(x.invoicedate) as transaction_at,
      y.current_dt
    FROM orders x
    CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
    INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
      ON x.customerid=z.customerid
    WHERE x.customerid IS NOT NULL
    ) a
  GROUP BY a.customerid, a.current_dt, a.first_at
  ORDER BY CustomerID
  '''

# capture stats in dataframe 
metrics_sql = spark.sql(sql)

# display stats
display(metrics_sql)  
[Output table: CustomerID, frequency, recency, T per customer]

Showing the first 1000 rows.

Of course, Spark SQL does not require the DataFrame to be accessed exclusively using a SQL statement. We may derive this same result using the Programmatic SQL API, which may align better with some Data Scientists' preferences. The code in the next cell is purposely assembled to mirror the structure of the previous SQL statement for the purposes of comparison:

from pyspark.sql.functions import to_date, datediff, max, min, countDistinct, count, sum, when
from pyspark.sql.types import *

# valid customer orders
x = orders.where(orders.CustomerID.isNotNull())

# calculate last date in dataset
y = (
  orders
    .groupBy()
    .agg(max(to_date(orders.InvoiceDate)).alias('current_dt'))
  )

# calculate first transaction date by customer
z = (
  orders
    .groupBy(orders.CustomerID)
    .agg(min(to_date(orders.InvoiceDate)).alias('first_at'))
  )

# combine customer history with date info 
a = (x
    .crossJoin(y)
    .join(z, x.CustomerID==z.CustomerID, how='inner')
    .select(
      x.CustomerID.alias('customerid'), 
      z.first_at, 
      to_date(x.InvoiceDate).alias('transaction_at'), 
      y.current_dt
      )
     .distinct()
    )

# calculate relevant metrics by customer
metrics_api = (a
           .groupBy(a.customerid, a.current_dt, a.first_at)
           .agg(
             (countDistinct(a.transaction_at)-1).cast(FloatType()).alias('frequency'),
             datediff(max(a.transaction_at), a.first_at).cast(FloatType()).alias('recency'),
             datediff(a.current_dt, a.first_at).cast(FloatType()).alias('T')
             )
           .select('customerid','frequency','recency','T')
           .orderBy('customerid')
          )

display(metrics_api)
[Output table: CustomerID, frequency, recency, T per customer]

Showing the first 1000 rows.

Let's take a moment to compare the data in these different metrics datasets, just to confirm the results are identical. Instead of doing this record by record, let's calculate summary statistics across each dataset to verify their consistency:

NOTE You may notice means and standard deviations vary slightly in the hundred-thousandths and millionths decimal places. This is a result of slight differences in data types between the pandas and Spark DataFrames, but these differences do not affect our results in a meaningful way.

# summary data from lifetimes
metrics_pd.describe()
Out[6]:
# summary data from SQL statement
metrics_sql.toPandas().describe()
Out[7]:
# summary data from pyspark.sql API
metrics_api.toPandas().describe()
Out[8]:

The metrics we've calculated represent summaries of a time series of data. To support model validation and avoid overfitting, a common pattern with time series data is to train models on an earlier portion of the time series (known as the calibration period) and validate against a later portion of the time series (known as the holdout period). In the lifetimes library, the derivation of per-customer metrics using calibration and holdout periods is done through a simple method call. Because our dataset consists of a limited range of data, we will instruct this library method to use the last 90 days of data as the holdout period. A simple parameter, implemented as a widget on the Databricks platform, makes this setting easily configurable:

NOTE To change the number of days in the holdout period, look for the textbox widget by scrolling to the top of your Databricks notebook after running this next cell

# define a notebook parameter making holdout days configurable (90-days default)
dbutils.widgets.text('holdout days', '90')
from datetime import timedelta

# set the last transaction date as the end point for this historical dataset
current_date = orders_pd['InvoiceDate'].max()

# define end of calibration period
holdout_days = int(dbutils.widgets.get('holdout days'))
calibration_end_date = current_date - timedelta(days = holdout_days)

# calculate the required customer metrics
metrics_cal_pd = (
  lifetimes.utils.calibration_and_holdout_data(
    orders_pd,
    customer_id_col='CustomerID',
    datetime_col='InvoiceDate',
    observation_period_end = current_date,
    calibration_period_end=calibration_end_date,
    freq='D'    
    )
  )

# display first few rows
metrics_cal_pd.head(10)
Out[10]:

As before, we may leverage Spark SQL to derive this same information. Again, we'll examine this through both a SQL statement and the programmatic SQL API.

To understand the SQL statement, first recognize that it's divided into two main parts. In the first, we calculate the core metrics, i.e. recency, frequency and age (T), per customer for the calibration period, much like we did in the previous query example. In the second part of the query, we calculate the number of purchase dates in the holdout period for each customer. This value (frequency_holdout) represents the incremental value to be added to the frequency for the calibration period (frequency_cal) when we examine a customer's entire transaction history across both calibration and holdout periods.

To simplify our logic, a common table expression (CTE) named CustomerHistory is defined at the top of the query. This query extracts the relevant dates that make up a customer's transaction history and closely mirrors the logic at the center of the last SQL statement we examined. The only difference is that we include the number of days in the holdout period (duration_holdout):

sql = '''
WITH CustomerHistory 
  AS (
    SELECT  -- nesting req'ed b/c can't SELECT DISTINCT on widget parameter
      m.*,
      getArgument('holdout days') as duration_holdout
    FROM (
      SELECT DISTINCT
        x.customerid,
        z.first_at,
        TO_DATE(x.invoicedate) as transaction_at,
        y.current_dt
      FROM orders x
      CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
      INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
        ON x.customerid=z.customerid
      WHERE x.customerid IS NOT NULL
    ) m
  )
SELECT
    a.customerid as CustomerID,
    a.frequency as frequency_cal,
    a.recency as recency_cal,
    a.T as T_cal,
    COALESCE(b.frequency_holdout, 0.0) as frequency_holdout,
    a.duration_holdout
FROM ( -- CALIBRATION PERIOD CALCULATIONS
    SELECT
        p.customerid,
        CAST(p.duration_holdout as float) as duration_holdout,
        CAST(DATEDIFF(MAX(p.transaction_at), p.first_at) as float) as recency,
        CAST(COUNT(DISTINCT p.transaction_at) - 1 as float) as frequency,
        CAST(DATEDIFF(DATE_SUB(p.current_dt, p.duration_holdout), p.first_at) as float) as T
    FROM CustomerHistory p
    WHERE p.transaction_at < DATE_SUB(p.current_dt, p.duration_holdout)  -- LIMIT THIS QUERY TO DATA IN THE CALIBRATION PERIOD
    GROUP BY p.customerid, p.first_at, p.current_dt, p.duration_holdout
  ) a
LEFT OUTER JOIN ( -- HOLDOUT PERIOD CALCULATIONS
  SELECT
    p.customerid,
    CAST(COUNT(DISTINCT p.transaction_at) as float) as frequency_holdout
  FROM CustomerHistory p
  WHERE 
    p.transaction_at >= DATE_SUB(p.current_dt, p.duration_holdout) AND  -- LIMIT THIS QUERY TO DATA IN THE HOLDOUT PERIOD
    p.transaction_at <= p.current_dt
  GROUP BY p.customerid
  ) b
  ON a.customerid=b.customerid
ORDER BY CustomerID
'''

metrics_cal_sql = spark.sql(sql)
display(metrics_cal_sql)
[Output table: CustomerID, frequency_cal, recency_cal, T_cal, frequency_holdout, duration_holdout per customer]

Showing the first 1000 rows.

And here is the equivalent Programmatic SQL API logic:

from pyspark.sql.functions import avg, date_sub, coalesce, lit, expr

# valid customer orders
x = orders.where(orders.CustomerID.isNotNull())

# calculate last date in dataset
y = (
  orders
    .groupBy()
    .agg(max(to_date(orders.InvoiceDate)).alias('current_dt'))
  )

# calculate first transaction date by customer
z = (
  orders
    .groupBy(orders.CustomerID)
    .agg(min(to_date(orders.InvoiceDate)).alias('first_at'))
  )

# combine customer history with date info (CUSTOMER HISTORY)
p = (x
    .crossJoin(y)
    .join(z, x.CustomerID==z.CustomerID, how='inner')
    .withColumn('duration_holdout', lit(int(dbutils.widgets.get('holdout days'))))
    .select(
      x.CustomerID.alias('customerid'),
      z.first_at, 
      to_date(x.InvoiceDate).alias('transaction_at'), 
      y.current_dt, 
      'duration_holdout'
      )
     .distinct()
    )

# calculate relevant metrics by customer
# note: date_sub requires a single integer value unless employed within an expr() call
a = (p
       .where(p.transaction_at < expr('date_sub(current_dt, duration_holdout)')) 
       .groupBy(p.customerid, p.current_dt, p.duration_holdout, p.first_at)
       .agg(
         (countDistinct(p.transaction_at)-1).cast(FloatType()).alias('frequency_cal'),
         datediff( max(p.transaction_at), p.first_at).cast(FloatType()).alias('recency_cal'),
         datediff( expr('date_sub(current_dt, duration_holdout)'), p.first_at).cast(FloatType()).alias('T_cal')
       )
    )

b = (p
      .where((p.transaction_at >= expr('date_sub(current_dt, duration_holdout)')) & (p.transaction_at <= p.current_dt) )
      .groupBy(p.customerid)
      .agg(
        countDistinct(p.transaction_at).cast(FloatType()).alias('frequency_holdout')
        )
   )

metrics_cal_api = (a
                 .join(b, a.customerid==b.customerid, how='left')
                 .select(
                   a.customerid.alias('CustomerID'),
                   a.frequency_cal,
                   a.recency_cal,
                   a.T_cal,
                   coalesce(b.frequency_holdout, lit(0.0)).alias('frequency_holdout'),
                   a.duration_holdout
                   )
                 .orderBy('CustomerID')
              )

display(metrics_cal_api)
[Output table: CustomerID, frequency_cal, recency_cal, T_cal, frequency_holdout, duration_holdout per customer]

Showing the first 1000 rows.

Using summary stats, we can again verify these different units of logic are returning the same results:

# summary data from lifetimes
metrics_cal_pd.describe()
Out[13]:
# summary data from SQL statement
metrics_cal_sql.toPandas().describe()
Out[14]:
# summary data from pyspark.sql API
metrics_cal_api.toPandas().describe()
Out[15]:

Our data prep is nearly done. The last thing we need to do is exclude customers for which we have no repeat purchases, i.e. those for which frequency or frequency_cal is 0. The Pareto/NBD and BG/NBD models we will use focus exclusively on performing calculations on customers with repeat transactions. A modified BG/NBD model, i.e. MBG/NBD, which allows for customers with no repeat transactions, is supported by the lifetimes library. However, to stick with the two most popular of the BTYD models in use today, we will limit our data to align with their requirements:

NOTE We are showing how both the pandas and Spark DataFrames are filtered simply to be consistent with side-by-side comparisons earlier in this section of the notebook. In a real-world implementation, you would simply choose to work with pandas or Spark DataFrames for data preparation.

# remove customers with no repeats (complete dataset)
filtered_pd = metrics_pd[metrics_pd['frequency'] > 0]
filtered = metrics_api.where(metrics_api.frequency > 0)

## remove customers with no repeats in calibration period
filtered_cal_pd = metrics_cal_pd[metrics_cal_pd['frequency_cal'] > 0]
filtered_cal = metrics_cal_api.where(metrics_cal_api.frequency_cal > 0)

Step 4: Train the Model

To ease into the training of a model, let's start with a simple exercise using a Pareto/NBD model, the original BTYD model. We'll use the calibration-holdout dataset constructed in the last section of this notebook, fitting the model to the calibration data and later evaluating it using the holdout data:

from lifetimes.fitters.pareto_nbd_fitter import ParetoNBDFitter
from lifetimes.fitters.beta_geo_fitter import BetaGeoFitter

# load spark dataframe to pandas dataframe
input_pd = filtered_cal.toPandas()

# fit a model
model = ParetoNBDFitter(penalizer_coef=0.0)
model.fit( input_pd['frequency_cal'], input_pd['recency_cal'], input_pd['T_cal'])
Out[17]: <lifetimes.ParetoNBDFitter: fitted with 2163 subjects, alpha: 96.96, beta: 3014.97, r: 1.99, s: 0.84>

With our model now fit, let's make some predictions for the holdout period. We'll grab the actuals for that same period to enable comparison in a subsequent step:

# get predicted frequency during holdout period
frequency_holdout_predicted = model.predict( input_pd['duration_holdout'], input_pd['frequency_cal'], input_pd['recency_cal'], input_pd['T_cal'])

# get actual frequency during holdout period
frequency_holdout_actual = input_pd['frequency_holdout']
/databricks/python/lib/python3.7/site-packages/numpy/core/fromnumeric.py:86: RuntimeWarning: invalid value encountered in reduce
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

With actual and predicted values in hand, we can calculate some standard evaluation metrics. Let's wrap those calculations in a function call to make evaluation easier in future steps:

import numpy as np

def score_model(actuals, predicted, metric='mse'):
  # make sure metric name is lower case
  metric = metric.lower()
  
  # Mean Squared Error and Root Mean Squared Error
  if metric=='mse' or metric=='rmse':
    val = np.sum(np.square(actuals-predicted))/actuals.shape[0]
    if metric=='rmse':
        val = np.sqrt(val)
  
  # Mean Absolute Error
  elif metric=='mae':
    val = np.sum(np.abs(actuals-predicted))/actuals.shape[0]
  
  else:
    val = None
  
  return val

# score the model
print('MSE: {0}'.format(score_model(frequency_holdout_actual, frequency_holdout_predicted, 'mse')))
MSE: 3.102822341084317
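The same helper can report the other supported metrics; for example:

# report additional metrics using the same helper
print('RMSE: {0}'.format(score_model(frequency_holdout_actual, frequency_holdout_predicted, 'rmse')))
print('MAE: {0}'.format(score_model(frequency_holdout_actual, frequency_holdout_predicted, 'mae')))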

While the internals of the Pareto/NBD model are quite complex, in a nutshell the model calculates a double integral of two curves, one which describes the frequency of customer purchases within a population and another which describes customer survivorship following a prior purchase event. All of the calculation logic is thankfully hidden behind a simple method call.

As simple as training a model may be, we have two models that we could use here: the Pareto/NBD model and the BG/NBD model. The BG/NBD model simplifies the math involved in calculating customer lifetime and is the model that popularized the BTYD approach. Both models work off the same customer features and employ the same constraints. (The primary difference between the two models is that the BG/NBD model maps the survivorship curve to a beta-geometric distribution instead of a Pareto distribution.) To achieve the best fit possible, it is worthwhile to compare the results of both models with our dataset.
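As a quick, manual version of that comparison (a sketch reusing the variables and helper defined above, with the regularization parameter again left at 0), we could fit a BG/NBD model on the same calibration data and compare its holdout MSE against the Pareto/NBD result:

# sketch: fit a BG/NBD model on the same calibration data for comparison
bgnbd_model = BetaGeoFitter(penalizer_coef=0.0)
bgnbd_model.fit(input_pd['frequency_cal'], input_pd['recency_cal'], input_pd['T_cal'])

# predict holdout-period purchases and score against actuals
bgnbd_predicted = bgnbd_model.predict(
  input_pd['duration_holdout'],
  input_pd['frequency_cal'],
  input_pd['recency_cal'],
  input_pd['T_cal']
  )

print('BG/NBD MSE: {0}'.format(score_model(input_pd['frequency_holdout'], bgnbd_predicted, 'mse')))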

Each model leverages an L2-norm regularization parameter which we've arbitrarily set to 0 in the previous training cycle. In addition to exploring which model works best, we should consider which value (between 0 and 1) works best for this parameter. This gives us a pretty broad search space to explore with some hyperparameter tuning.

To assist us with this, we will make use of hyperopt. Hyperopt allows us to parallelize the training and evaluation of models against a hyperparameter search space. This can be done leveraging the multiprocessor resources of a single machine or across the broader resources provided by a Spark cluster. With each model iteration, a loss function is calculated. Using various optimization algorithms, hyperopt navigates the search space to locate the best available combination of parameter settings to minimize the value returned by the loss function.

To make use of hyperopt, let's define our search space and re-write our model training and evaluation logic to provide a single function call which will return a loss function measure:

from hyperopt import hp, fmin, tpe, rand, SparkTrials, STATUS_OK, STATUS_FAIL, space_eval

# define search space
search_space = hp.choice('model_type',[
                  {'type':'Pareto/NBD', 'l2':hp.uniform('pareto_nbd_l2', 0.0, 1.0)},
                  {'type':'BG/NBD'    , 'l2':hp.uniform('bg_nbd_l2', 0.0, 1.0)}  
                  ]
                )

# define function for model evaluation
def evaluate_model(params):
  
  # access replicated input_pd dataframe
  data = inputs.value
  
  # retrieve incoming parameters
  model_type = params['type']
  l2_reg = params['l2']
  
  # instantiate and configure the model
  if model_type == 'BG/NBD':
    model = BetaGeoFitter(penalizer_coef=l2_reg)
  elif model_type == 'Pareto/NBD':
    model = ParetoNBDFitter(penalizer_coef=l2_reg)
  else:
    return {'loss': None, 'status': STATUS_FAIL}
  
  # fit the model
  model.fit(data['frequency_cal'], data['recency_cal'], data['T_cal'])
  
  # evaluate the model
  frequency_holdout_actual = data['frequency_holdout']
  frequency_holdout_predicted = model.predict(data['duration_holdout'], data['frequency_cal'], data['recency_cal'], data['T_cal'])
  mse = score_model(frequency_holdout_actual, frequency_holdout_predicted, 'mse')
  
  # return score and status
  return {'loss': mse, 'status': STATUS_OK}

Notice that the evaluate_model function retrieves its data from a variable named inputs. Inputs is defined in the next cell as a broadcast variable containing the input_pd DataFrame used earlier. As a broadcast variable, a complete stand-alone copy of the dataset used by the model is replicated to each worker in the Spark cluster. This limits the amount of data that must be sent from the cluster driver to the workers with each hyperopt iteration. For more information on this and other hyperopt best practices, please refer to this document.

With everything in place, let's perform our hyperparameter tuning over 100 iterations in order to identify the best model type and L2 settings for our dataset:

import mlflow

# replicate input_pd dataframe to workers in Spark cluster
inputs = sc.broadcast(input_pd)

# configure hyperopt settings to distribute to all executors on workers
spark_trials = SparkTrials(parallelism=2)

# select optimization algorithm
algo = tpe.suggest

# perform hyperparameter tuning (logging iterations to mlflow)
argmin = fmin(
  fn=evaluate_model,
  space=search_space,
  algo=algo,
  max_evals=100,
  trials=spark_trials
  )

# release the broadcast dataset
inputs.unpersist()
Hyperopt with SparkTrials will automatically track trials in MLflow. To view the MLflow experiment associated with the notebook, click the 'Runs' icon in the notebook context bar on the upper right. There, you can view all runs. To view logs from trials, please check the Spark executor logs.
[Trial progress output truncated]
Total Trials: 80: 80 succeeded, 0 failed, 0 cancelled. Best loss: 3.540445199381716

When used with the Databricks ML Runtime, the individual runs that make up the search space evaluation are automatically tracked with mlflow, the experiment tracking repository built into the platform. For more information on how to review the models generated by hyperopt using the Databricks mlflow interface, please check out this document.
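If you prefer to inspect the tracked runs programmatically rather than through the UI, a sketch like the following uses mlflow.search_runs to pull the logged trials for the notebook's active experiment into a pandas DataFrame (the metrics.loss column name assumes the trial loss is logged under a metric named loss):

# sketch: retrieve the hyperopt trial runs logged to the notebook's experiment
runs_df = mlflow.search_runs()

# inspect the best (lowest loss) trials; column name assumes the loss metric is logged as 'loss'
runs_df.sort_values('metrics.loss').head(5)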

The optimal hyperparameter settings observed during the hyperopt iterations are captured in the argmin variable. Using the space_eval function, we can obtain a friendly representation of which settings performed best:

# print optimum hyperparameter settings
print(space_eval(search_space, argmin))
{'l2': 0.9975590906220992, 'type': 'BG/NBD'}

Now that we know our best parameter settings, let's train the model with these to enable us to perform some more in-depth model evaluation:

NOTE Because of how search spaces are searched, different hyperopt runs may yield slightly different results.

# get hyperparameter settings
params = space_eval(search_space, argmin)
model_type = params['type']
l2_reg = params['l2']

# instantiate and configure model
if model_type == 'BG/NBD':
  model = BetaGeoFitter(penalizer_coef=l2_reg)
elif model_type == 'Pareto/NBD':
  model = ParetoNBDFitter(penalizer_coef=l2_reg)
else:
  raise Exception('Unrecognized model type')
  
# train the model
model.fit(input_pd['frequency_cal'], input_pd['recency_cal'], input_pd['T_cal'])
Out[23]: <lifetimes.BetaGeoFitter: fitted with 2163 subjects, a: 0.01, alpha: 18.28, b: 0.07, r: 0.45>

Step 5: Evaluate the Model

Using a method defined in the last section of this notebook, we can calculate the MSE for our newly trained model:

# score the model
frequency_holdout_actual = input_pd['frequency_holdout']
frequency_holdout_predicted = model.predict(input_pd['duration_holdout'], input_pd['frequency_cal'], input_pd['recency_cal'], input_pd['T_cal'])
mse = score_model(frequency_holdout_actual, frequency_holdout_predicted, 'mse')

print('MSE: {0}'.format(mse))
MSE: 3.540430294872508

While important for comparing models, the MSE metric is a bit more challenging to interpret in terms of the overall goodness of fit of any individual model. To provide more insight into how well our model fits our data, let's visualize the relationships between some actual and predicted values.

To get started, we can examine how purchase frequencies in the calibration period relate to actual (frequency_holdout) and predicted (model_predictions) frequencies in the holdout period:

from lifetimes.plotting import plot_calibration_purchases_vs_holdout_purchases

plot_calibration_purchases_vs_holdout_purchases(
  model, 
  input_pd, 
  n=90, 
  **{'figsize':(8,8)}
  )

display()

What we see here is that a higher number of purchases in the calibration period predicts a higher average number of purchases in the holdout period, but the actual values diverge sharply from model predictions when we consider customers with a large number of purchases (>60) in the calibration period. Thinking back to the charts in the data exploration section of this notebook, you might recall that there are very few customers with such a large number of purchases, so this divergence may be a result of a very limited number of instances at the higher end of the frequency range. More data may bring the predicted and actual values back together at this higher end of the curve. If this divergence persists, it may indicate a range of customer engagement frequency above which we cannot make reliable predictions.

Using the same method call, we can visualize time since last purchase relative to the average number of purchases in the holdout period. This visualization illustrates that as the time since the last purchase increases, the number of purchases in the holdout period decreases. In other words, those customers we haven't seen in a while aren't likely coming back anytime soon:

NOTE As before, we will hide the code in the following cells to focus on the visualizations. Use Show code to see the associated Python logic.

Show code

Plugging the age of the customer at the time of the last purchase into the chart shows that the timing of the last purchase in a customer's lifecycle doesn't seem to have a strong influence on the number of purchases in the holdout period until a customer becomes quite old. This would indicate that the customers that stick around a long while are likely to be more frequently engaged:

Show code
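The hidden cells above most likely reuse the same lifetimes plotting helper with a different kind argument (the library also accepts values such as 'time_since_last_purchase' and 'recency_cal'); a sketch, with the n argument chosen arbitrarily:

# sketch: time since last purchase vs. average purchases in the holdout period
# (the second hidden chart likely swaps in kind='recency_cal')
plot_calibration_purchases_vs_holdout_purchases(
  model, 
  input_pd, 
  kind='time_since_last_purchase',
  n=30, 
  **{'figsize':(8,8)}
  )

display()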

From a quick visual inspection, it's fair to say our model isn't perfect but there are some useful patterns that it captures. Using these patterns, we might calculate the probability a customer remains engaged:

# add a field with the probability a customer is currently "alive"
filtered_pd['prob_alive']=model.conditional_probability_alive(
    filtered_pd['frequency'], 
    filtered_pd['recency'], 
    filtered_pd['T']
    )

filtered_pd.head(10)
Out[28]:

The prediction of the customer's probability of being alive could be very interesting for the application of the model to our marketing processes. But before exploring model deployment, let's take a look at how this probability changes as customers re-engage by looking at the history of a single customer with modest activity in the dataset, CustomerID 12383:

from lifetimes.plotting import plot_history_alive
import matplotlib.pyplot as plt

# clear past visualization instructions
plt.clf()

# customer of interest
CustomerID = '12383'

# grab customer's metrics and transaction history
cmetrics_pd = input_pd[input_pd['CustomerID']==CustomerID]
trans_history = orders_pd.loc[orders_pd['CustomerID'] == CustomerID]

# calculate age at end of dataset
days_since_birth = 400

# plot history of being "alive"
plot_history_alive(
  model, 
  days_since_birth, 
  trans_history, 
  'InvoiceDate'
  )

display()

From this chart, we can see this customer made his or her first purchase in January 2011 followed by a repeat purchase later that month. There was about a 1-month lull in activity during which the probability of the customer being alive declined slightly, but with purchases in March, April and June of that year, the customer sent repeated signals that he or she was engaged. Since that last June purchase, the customer hasn't been seen in our transaction history, and our belief that the customer remains engaged has been dropping, though at a moderate pace given the signals previously sent.

How does the model arrive at these probabilities? The exact math is tricky but by plotting the probability of being alive as a heatmap relative to frequency and recency, we can understand the probabilities assigned to the intersections of these two values:

from lifetimes.plotting import plot_probability_alive_matrix

# set figure size
plt.subplots(figsize=(12, 8))

plot_probability_alive_matrix(model)

display()

In addition to predicting the probability that a customer is still alive, we can calculate the number of purchases expected from a customer over a given future time interval, such as the next 30 days:

Show code
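The hidden visualization is most likely a frequency/recency matrix of expected purchases; a sketch of how it might be produced with the lifetimes plotting helpers, assuming a 30-day horizon:

from lifetimes.plotting import plot_frequency_recency_matrix
import matplotlib.pyplot as plt

# set figure size
plt.subplots(figsize=(12, 8))

# sketch: expected purchases over the next 30 days by frequency and recency
plot_frequency_recency_matrix(model, T=30)

display()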

As before, we can calculate this value for each customer based on their current metrics:

filtered_pd['purchases_next30days']=(
  model.conditional_expected_number_of_purchases_up_to_time(
    30, 
    filtered_pd['frequency'], 
    filtered_pd['recency'], 
    filtered_pd['T']
    )
  )

filtered_pd.head(10)
Out[32]:

Step 6: Deploy the Model for Predictions

There are numerous ways we might make use of the trained BTYD model. We may wish to understand the probability a customer is still engaged. We may also wish to predict the number of purchases expected from the customer over some number of days. All we need to make these predictions is our trained model and values of frequency, recency and age (T) for the customer as demonstrated here:

frequency = 6
recency = 255
T = 300
t = 30

print('Probability of Alive: {0}'.format( model.conditional_probability_alive(frequency, recency, T) ))
print('Expected Purchases in next {0} days: {1}'.format(t, model.conditional_expected_number_of_purchases_up_to_time(t, frequency, recency, T) ))
Probability of Alive: 0.9949186328353091
Expected Purchases in next 30 days: 0.6048476679280559

The challenge now is to package our model into something we could re-use for this purpose. Earlier, we used mlflow in combination with hyperopt to capture model runs during the hyperparameter tuning exercise. As a platform, mlflow is designed to solve a wide range of challenges that come with model development and deployment, including the deployment of models as functions and microservice applications.

MLFlow tackles deployment challenges out of the box for a number of popular model types. However, lifetimes models are not one of these. To use mlflow as our deployment vehicle, we need to write a custom wrapper class which translates the standard mlflow API calls into logic which can be applied against our model.

To illustrate this, we've implemented a wrapper class for our lifetimes model which maps the mlflow predict() method to multiple prediction calls against our model. Typically, we'd map predict() to a single prediction but we've bumped up the complexity of the returned result to show one of many ways the wrapper may be employed to implement custom logic:

import mlflow
import mlflow.pyfunc

# create wrapper for lifetimes model
class _lifetimesModelWrapper(mlflow.pyfunc.PythonModel):
  
    def __init__(self, lifetimes_model):
        self.lifetimes_model = lifetimes_model

    def predict(self, context, dataframe):
      
      # access input series
      frequency = dataframe.iloc[:,0]
      recency = dataframe.iloc[:,1]
      T = dataframe.iloc[:,2]
      
      # calculate probability currently alive
      results = pd.DataFrame( 
                  self.lifetimes_model.conditional_probability_alive(frequency, recency, T),
                  columns=['alive']
                  )
      # calculate expected purchases for provided time period
      results['purch_15day'] = self.lifetimes_model.conditional_expected_number_of_purchases_up_to_time(15, frequency, recency, T)
      results['purch_30day'] = self.lifetimes_model.conditional_expected_number_of_purchases_up_to_time(30, frequency, recency, T)
      results['purch_45day'] = self.lifetimes_model.conditional_expected_number_of_purchases_up_to_time(45, frequency, recency, T)
      
      return results[['alive', 'purch_15day', 'purch_30day', 'purch_45day']]

We now need to register our model with mlflow. As we do this, we inform it of the wrapper that maps its expected API to the model's functionality. We also provide environment information to instruct it as to which libraries it needs to install and load for our model to work:

NOTE We would typically train and log our model as a single step, but in this notebook we've separated the two actions in order to focus here on custom model deployment. To examine more common patterns of mlflow implementation, please refer to this and other examples available online.

# add lifetimes to conda environment info
conda_env = mlflow.pyfunc.get_default_conda_env()
conda_env['dependencies'][1]['pip'] += ['lifetimes==0.10.1'] # version should match version installed at top of this notebook

# save model run to mlflow
with mlflow.start_run(run_name='deployment run') as run:
  mlflow.pyfunc.log_model(
    'model', 
    python_model=_lifetimesModelWrapper(model), 
    conda_env=conda_env
    )

Now that our model along with its dependency information and class wrapper have been recorded, let's use mlflow to convert the model into a function we can employ against a Spark DataFrame:

from pyspark.sql.types import ArrayType, FloatType

# define the schema of the values returned by the function
result_schema = ArrayType(FloatType())

# define function based on mlflow recorded model
probability_alive_udf = mlflow.pyfunc.spark_udf(
  spark, 
  'runs:/{0}/model'.format(run.info.run_id), 
  result_type=result_schema
  )

# register the function for use in SQL
_ = spark.udf.register('probability_alive', probability_alive_udf)

Assuming we have access to customer metrics for frequency, recency and age, we can now use our function to generate some predictions:

# create a temp view for SQL demonstration (next cell)
filtered.createOrReplaceTempView('customer_metrics')

# demonstrate function call on Spark DataFrame
display(
  filtered
    .withColumn(
      'predictions', 
      probability_alive_udf(filtered.frequency, filtered.recency, filtered.T)
      )
    .selectExpr(
      'customerid', 
      'predictions[0] as prob_alive', 
      'predictions[1] as purch_15day', 
      'predictions[2] as purch_30day', 
      'predictions[3] as purch_45day'
      )
  )
[Output table: customerid, prob_alive, purch_15day, purch_30day, purch_45day per customer]

Showing the first 1000 rows.

%sql -- predict probabilities a customer is alive and will return in 15, 30 & 45 days

SELECT
  x.CustomerID,
  x.prediction[0] as prob_alive,
  x.prediction[1] as purch_15day,
  x.prediction[2] as purch_30day,
  x.prediction[3] as purch_45day
FROM (
  SELECT
    CustomerID,
    probability_alive(frequency, recency, T) as prediction
  FROM customer_metrics
  ) x;
[Output table: CustomerID, prob_alive, purch_15day, purch_30day, purch_45day per customer]

Showing the first 1000 rows.

With our lifetimes model now registered as a function, we can incorporate customer lifetime probabilities into our ETL batch, streaming and interactive query workloads. Leveraging additional model deployment capabilities of mlflow, our model may also be deployed as a standalone web service using AzureML or AWS SageMaker.
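As a final illustration (a sketch, assuming the run recorded above is still in scope), the same logged model can also be loaded back as a generic Python function and applied directly to a pandas DataFrame of customer metrics:

import mlflow.pyfunc

# sketch: load the logged model back as a generic python function
loaded_model = mlflow.pyfunc.load_model('runs:/{0}/model'.format(run.info.run_id))

# score a small pandas dataframe of frequency, recency and age (T) values
sample_metrics = filtered_pd[['frequency', 'recency', 'T']].head(5)
loaded_model.predict(sample_metrics)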