Sparkify — Predicting customer churn for a popular music streaming service
This post provides an overview of my Udacity Data Science Nanodegree’s capstone project. I chose to work on a project focused on predicting customer churn for a popular music streaming service.
Sparkify is a music streaming service where users can either use a free version or pay for a subscription. As customers interact with the service, they can choose to move from the free version to a paid subscription, move from a paid subscription to the free tier, or cancel the service altogether.
We were provided with two datasets :
- a ‘mini’ dataset with music choices from 225 unique service customers that we’ve used to explore and get an understanding of the dataset’s characteristics and features (286K rows)
- a ‘large’ dataset with music choices from 22K unique service customers (26M rows), which we had to analyze and model with Spark and the Hadoop big data framework due to its sheer size
We chose the Databricks ecosystem (PySpark, Spark SQL, Spark ML, MLflow) to analyze the dataset, create features and perform predictions.
Problem Statement and Approach
We are asked to predict Sparkify’s customer churn. To do so, we used the following approach:
1. In a first phase we focused on the Sparkify mini dataset:
- Performed EDA on the mini dataset answering business questions as we go along
- Identified and defined customer churn
- Transformed the dataset with the appropriate features and label so that a supervised machine learning framework could be used
2. In a second phase, we turned our attention to the large Sparkify dataset to perform churn predictions and used what we had learned from the mini dataset:
- Performed EDA on the large dataset
- Preprocessed the dataset, performed feature engineering and labeled customers with a churn flag
- Explored the processed dataset and performed survival analysis
- Modeled and predicted customer churn using Spark’s ML libraries
- Evaluated the models and assessed the results
In simple terms, a churned customer is a registered Sparkify customer who either cancelled their subscription altogether or downgraded from a paying subscription tier to the free tier.
Churn prediction is a supervised binary classification problem. The classes are reasonably well balanced (60/40 ratio, as shown later in the Analyzing Customer Churn section), so the ROC curve (the true positive rate plotted against the false positive rate at a variety of thresholds) and the area under it are appropriate evaluation metrics; they are most informative when both classes have roughly similar numbers of observations.
AUC stands for area under the (ROC) curve. The score lies between 0.0 and 1.0: 0.5 corresponds to a no-skill classifier that ranks customers at random, and 1.0 to perfect skill.
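To make the metric concrete, here is a minimal pure-Python sketch of AUC computed as the share of (churned, non-churned) pairs that the model ranks in the right order. The function name `roc_auc` and the toy scores are illustrative assumptions; in the project itself the score comes from Spark's evaluator, not from this code.

```python
def roc_auc(labels, scores):
    """Area under the ROC curve via the pairwise-ranking formulation.

    labels: 0/1 ground-truth values (1 = churned).
    scores: model scores, higher = more likely to churn.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one example of each class")

    # Count positive/negative pairs ranked correctly; ties count as half.
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


labels = [1, 1, 0, 0, 1, 0]  # toy data, not taken from the Sparkify dataset
print(roc_auc(labels, [0.9, 0.8, 0.2, 0.1, 0.7, 0.3]))  # 1.0 (perfect ranking)
print(roc_auc(labels, [0.5] * 6))                        # 0.5 (constant score, no skill)
```

The double loop is quadratic and only suitable for small examples; it is meant to show what the number measures rather than how to compute it at scale.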
Development Environment — Databricks
Databricks is a managed cloud platform built on Apache Spark and available on Microsoft Azure, Amazon AWS and GCP. With fully managed Spark clusters, it is used to process large data workloads and supports data engineering, data exploration, visualization and machine learning. It runs a distributed system behind the scenes, meaning the workload is automatically split across multiple processors and scales up and down on demand.
Some key features of Databricks:
- Databricks Workspace: interactive workspace that enables data scientists and data engineers to collaborate on notebooks and dashboards
- Databricks Runtime: includes Apache Spark and supports Scala, PySpark, Python and Spark SQL, among other languages
- Fully managed service: resources like storage, virtual network, cluster and compute are easy to start and stop
- Databricks File System (DBFS): abstraction layer on top of cloud object storage (blob, S3 …)
We used Databricks Community Edition to analyze the Sparkify dataset and model customer churn.
Exploring the Dataset
The Sparkify dataset features
The mini Sparkify dataset contains music choices from 225 customers over a 60-day period. It has 286,500 rows.
The dataset can be thought of as a sequence of user service events and songs listened to, and is composed of the following features:
artist: string - artist name
auth: string - authentication method
firstName: string - user first name
gender: string - user gender
lastName: string - user last name
length: double - length of the song listened
level: string - sparkify user service level (paid or free)
location: string - location of the user
method: string - http service method
page: user service interaction event
registration: long - timestamp of user service registration
sessionId: service session id
song: song name played by the user
status: http status
ts: long - timestamp of user service event
userAgent: web browser used
userId: string - unique userid
The dataset’s features can be broken down into four broad categories :
- User information: userid, first name, last name, gender, location
- Sparkify service information: level, registration
- User service event/interactions information: method, iteminsession, page, status, ts, useragent, sessionid
- Songs and artist information: artist, song, length
Below is a sample of a user’s interactions with the service, ordered by timestamp:
Exploratory Data Analysis
To get more familiar with the dataset, we explored the music streaming data and answered a few interesting questions about users’ behavior and their use of Sparkify’s service (refer to the following notebook here):
- What is the distribution of users by gender?
- Users are fairly well balanced by gender: 121 males vs 104 females
- How many users were paying for the service, based on the first and last event recorded for each user in the dataset?
- Out of 225 users, 48 (22%) had a paid subscription at the earliest date recorded in the dataset vs 177 (78%) users who did not (free subscription)
- Out of 225 users, 145 (65%) had a paid subscription at the latest date recorded in the dataset vs 80 (35%) users who did not (free subscription)
- Which users listened to the most songs?
- user 39 listened to ~6k songs over two months, an average of 100 songs per day!
- users 92 and 140 listened to ~4,500 songs over two months
- Where are most service users located?
- Los Angeles & New York are where most Sparkify users are located
- How many unique artists and songs does the dataset have?
- The dataset holds approx. 17K artists and 58K songs
- What are the most popular songs and artists?
The most popular songs are:
1. You're The One by Dwight Yoakam
2. Undo by Bjork
3. Revelry by Kings Of Leon
The most popular artists are:
1. Kings Of Leon
3. Florence + The Machine
Identifying and defining customer churn
The dataset is not labeled so we need to define and identify customer churn.
The relevant features we will use to identify whether a customer has churned are page and level.
- level: possible values are 'free', 'paid'
- page: possible values are 'Cancel', 'Submit Downgrade', 'Thumbs Down', 'Home', 'Downgrade', 'Roll Advert', 'Logout', 'Save Settings', 'Cancellation Confirmation', 'About', 'Settings', 'Add to Playlist', 'Add Friend', 'NextSong', 'Thumbs Up', 'Help', 'Upgrade', 'Error', 'Submit Upgrade'
By investigating the user event patterns, we can see that the Submit Downgrade and Cancellation Confirmation events seem to be good proxies for identifying whether and when a customer has churned.
Let's look at what happens before and after a Submit Downgrade event occurs for a given user (user id 100015):
- the user has a paying subscription
- the user goes to the downgrade page
- the user submits downgrade request
- the user now has a free subscription
Let's now look at what happens before and after a Cancellation Confirmation event occurs for a given user (user id 100015):
- the user has a paying subscription
- the user goes to the Cancel page
- the user submits Cancellation Confirmation
- the user has now ended the Sparkify subscription
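The two event patterns above suggest a simple labeling rule: a user is flagged as churned if they ever trigger a Submit Downgrade or Cancellation Confirmation event. Below is a minimal pure-Python sketch of that rule; the function name `label_churn` and the sample events (with made-up user ids) are assumptions for illustration, and the project applies the same idea with Spark rather than plain dictionaries.

```python
CHURN_PAGES = {"Submit Downgrade", "Cancellation Confirmation"}


def label_churn(events):
    """Return {userId: 1 or 0}, where 1 means the user hit a churn page at least once."""
    churn = {}
    for event in events:
        user = event["userId"]
        hit = 1 if event["page"] in CHURN_PAGES else 0
        churn[user] = max(churn.get(user, 0), hit)
    return churn


# Hypothetical events, ordered by timestamp within each user.
events = [
    {"userId": "u1", "page": "NextSong", "level": "paid"},
    {"userId": "u1", "page": "Downgrade", "level": "paid"},
    {"userId": "u1", "page": "Submit Downgrade", "level": "paid"},
    {"userId": "u1", "page": "NextSong", "level": "free"},
    {"userId": "u2", "page": "Cancel", "level": "paid"},
    {"userId": "u2", "page": "Cancellation Confirmation", "level": "paid"},
    {"userId": "u3", "page": "NextSong", "level": "free"},
    {"userId": "u3", "page": "Thumbs Up", "level": "free"},
]

print(label_churn(events))  # {'u1': 1, 'u2': 1, 'u3': 0}
```

Note that visiting the Downgrade or Cancel page alone does not count as churn in this sketch; only the confirming event does.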
Preprocessing and Feature engineering
With churn defined, the next step was to label and transform the dataset and engineer features so that each row corresponds to a unique user. The churn label flags whether a specific user has churned, based on the definition of churn above (refer to the following html notebook here).
The steps we followed to preprocess and transform the dataset were:
- aggregated the data for each user and dropped features unrelated to churn
- engineered features: count of events and songs for each user, tenure in days and time from registration to first use
- labeled each user (row) with a binary churn feature
- cleaned the dataset by removing rows with null values
- saved the clean dataset in a Spark table
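The aggregation step can be sketched in plain Python as follows. This is an illustration only, under stated assumptions: `ts` and `registration` are taken to be epoch timestamps in milliseconds, tenure is taken as the time between a user's first and last event, and the function name `build_user_features` is hypothetical. The project performs the equivalent aggregation with Spark.

```python
from collections import defaultdict

MS_PER_DAY = 24 * 60 * 60 * 1000  # assumption: timestamps are epoch milliseconds


def build_user_features(events):
    """Collapse an event log into one feature row per user."""
    per_user = defaultdict(list)
    for event in events:
        per_user[event["userId"]].append(event)

    features = {}
    for user, rows in per_user.items():
        timestamps = [r["ts"] for r in rows]
        songs = [r["song"] for r in rows if r["page"] == "NextSong"]
        registration = rows[0]["registration"]
        features[user] = {
            "song_count": len(songs),
            "song_nunique": len(set(songs)),
            # Days between registration and the first recorded event.
            "time_to_first_use_days": (min(timestamps) - registration) / MS_PER_DAY,
            # Assumption: tenure = days between first and last recorded event.
            "service_tenure_days": (max(timestamps) - min(timestamps)) / MS_PER_DAY,
        }
    return features


# Hypothetical events for a single user registered at time 0.
sample = [
    {"userId": "u1", "registration": 0, "ts": 1 * MS_PER_DAY, "page": "NextSong", "song": "a"},
    {"userId": "u1", "registration": 0, "ts": 2 * MS_PER_DAY, "page": "NextSong", "song": "b"},
    {"userId": "u1", "registration": 0, "ts": 3 * MS_PER_DAY, "page": "NextSong", "song": "a"},
    {"userId": "u1", "registration": 0, "ts": 3 * MS_PER_DAY, "page": "Home", "song": None},
]

print(build_user_features(sample)["u1"])
```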
The Sparkify transformed labeled dataset features are described below:
userId - unique user id
gender_first - user gender
song_count - count of songs played by the user
song_nunique - count of unique songs played by the user
sessionId_count - count of session ids
length_sum - total song playing time
time_to_first_use_days - number of days between registration and first use
service_tenure_days - number of days the user used the service
About - number of times the About page was visited or event took place for the user
Add_Friend - number of times the Add_Friend page was visited or event took place for the user
Add_to_Playlist - number of times the Add_to_Playlist page was visited or event took place for the user
Cancel - number of times the Cancel page was visited or event took place for the user
Downgrade - number of times the Downgrade page was visited or event took place for the user
Error - number of times the Error page was visited or event took place for the user
Help - number of times the Help page was visited or event took place for the user
Home - number of times the Home page was visited or event took place for the user
Logout - number of times the Logout page was visited or event took place for the user
NextSong - number of times the NextSong page was visited or event took place for the user
Roll_Advert - number of times the Roll_Advert page was visited or event took place for the user
Save_Settings - number of times the Save_Settings page was visited or event took place for the user
Settings - number of times the Settings page was visited or event took place for the user
Submit_Upgrade - number of times the Submit_Upgrade page was visited or event took place for the user
Thumbs_Down - number of times the Thumbs_Down page was visited or event took place for the user
Thumbs_Up - number of times the Thumbs_Up page was visited or event took place for the user
Upgrade - number of times the Upgrade page was visited or event took place for the user
churn - whether the user has churned or not
A sample of the transformed labeled dataset is shown below:
Analyzing Customer Churn
With the transformed labeled dataset, we were then able to look into the customer behaviors, churn and answer key business questions.
- What proportion of Sparkify customers have churned?
Overall, 40% of Sparkify users in the dataset have churned. We therefore won't need to deal with an imbalanced dataset while modelling.
- What is the average user tenure by gender?
- The average tenure in the dataset is 40 days and is very similar for male and female Sparkify users
- The tenure distribution is left skewed, but we can see a peak very early, with over 1,200 users who seem to have quit on their first day of usage
- What does the progression of churn over time look like (survival analysis)?
- 40% of customers have churned over a 2-month period
- Male and female survival plots and churn profiles are very similar
- Sparkify seems to lose customers at a faster clip within the first couple of days, and then again after the ~55-day mark, where the curve is much steeper
- It takes approx. 55 days to reach 30% user churn and then only 5 more days to see a further 10% of users churn
Model and predict customer churn
In the following notebook, we use the labeled dataset to fit models and evaluate churn predictions for four Spark ML algorithms. This is a binary classification problem where the churned and non-churned classes in the transformed dataset are well balanced, as discussed in the previous section (refer to the following notebook here).
- Since this is a binary classification, the evaluation metric we used was ROC-AUC (areaUnderROC, see https://spark.apache.org/docs/2.2.0/mllib-evaluation-metrics.html)
- Built Spark pipelines to transform features using Spark’s VectorAssembler, which combines a list of columns into a single feature vector
- Split the processed dataset into train and test sets
- Evaluated four ML models with default parameters: LogisticRegression, RandomForestClassifier, LinearSVC and GBTClassifier
- Assessed feature importance for the best model
- Tuned the hyperparameters of the RandomForestClassifier and assessed the results
Data Processing Pipelines
We used Spark pipelines to process and shape the data so that chosen machine learning models could be properly trained.
StringIndexer and OneHotEncoder objects were used to convert categorical variables into a set of numeric variables that only take on the values 0 and 1.
- StringIndexer converts a column of string values to a column of label indexes. For example, it might convert the values “red”, “blue”, and “green” to 0, 1, and 2.
- OneHotEncoder maps a column of category indices to a column of binary vectors, with at most one “1” in each row that indicates the category index for that row.
- One-hot encoding in Spark is a two-step process. We first used the StringIndexer, followed by the OneHotEncoder.
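The two-step encoding described above can be illustrated without Spark. The sketch below mimics the behaviour in plain Python and is not Spark's implementation: the function names are hypothetical, categories are indexed by descending frequency, and the last category is dropped from the one-hot vector (Spark's OneHotEncoder default, `dropLast=True`), which is why a row can contain no "1" at all.

```python
from collections import Counter


def fit_string_index(values):
    """Map each category to an index, most frequent first (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {value: index for index, value in enumerate(ordered)}


def one_hot(index, num_categories, drop_last=True):
    """Binary vector for a category index; the last category maps to all zeros if dropped."""
    size = num_categories - 1 if drop_last else num_categories
    vector = [0] * size
    if index < size:
        vector[index] = 1
    return vector


genders = ["M", "F", "M", "M", "F"]   # toy column, similar to gender_first
mapping = fit_string_index(genders)   # step 1: strings -> indices
encoded = [one_hot(mapping[g], len(mapping)) for g in genders]  # step 2: indices -> vectors

print(mapping)  # {'M': 0, 'F': 1}
print(encoded)  # [[1], [0], [1], [1], [0]]
```

With two categories and the last one dropped, a single 0/1 column is enough to represent gender, which avoids a redundant feature in linear models.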
Model Evaluation and Validation
We chose to train and evaluate four Spark classifiers on the labeled dataset. Default Spark MLlib hyperparameters were used in a first phase to assess which model was the most promising:
LogisticRegression is a special case of Generalized Linear models that predicts the probability of the outcomes. In spark.ml logistic regression can be used to predict a binary outcome by using binomial logistic regression, or it can be used to predict a multiclass outcome by using multinomial logistic regression.
Random forests train a set of decision trees separately, so the training can be done in parallel. The algorithm injects randomness into the training process so that each decision tree is a bit different. Combining the predictions from each tree reduces the variance of the predictions, improving the performance on test data.
A support vector machine constructs a hyperplane or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification, regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has the largest distance to the nearest training-data points of any class (so-called functional margin), since in general the larger the margin the lower the generalization error of the classifier. LinearSVC in Spark ML supports binary classification with linear SVM.
GBTs iteratively train decision trees in order to minimize a loss function. Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions.
spark.mllib supports GBTs for binary classification and for regression, using both continuous and categorical features.
All classifiers reported an ROC/AUC score above 0.90, with the Gradient Boosted Tree model recording the best performance at 0.97:
Training model : lr_classifier
Pipeline : lr_classifier areaUnderROC score 0.9402484122969708
Training model : rf_classifier
Pipeline : rf_classifier areaUnderROC score 0.9615282227093049
Training model : gtb_classifier
Pipeline : gtb_classifier areaUnderROC score 0.972535781861781
Training model : lsvc_classifier
Pipeline : lsvc_classifier areaUnderROC score 0.9475858952022346
For the Gradient Boosted Tree model, the most important features for predicting user churn were:
- The Cancel event count
- The Submit Upgrade event count
- The Downgrade event count
- The Roll Advert event count
These make sense as they are a direct reflection of the user’s sentiment about the Sparkify service. Surprisingly, service tenure seems to have little importance in the binary classification.
Our approach and results seem adequate for the problem. Since this is a binary classification problem, a simple benchmark is to select the class that has the most observations and use that class as the result for all predictions, as a point of reference for evaluation purposes.
If we predict the class value that is most common in the dataset, we would obtain 60% accuracy, since we know that 60% of our customers have not churned (refer to the Analyzing Customer Churn section). Such a constant prediction cannot rank customers, so its ROC/AUC is 0.5. All our models significantly outperform this baseline.
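The majority-class benchmark takes only a few lines to compute. The sketch below uses a made-up label list with the same 60/40 split as the dataset; the function name `majority_baseline_accuracy` is an assumption for illustration.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Return the most common label and the accuracy of always predicting it."""
    most_common_label, count = Counter(labels).most_common(1)[0]
    return most_common_label, count / len(labels)


# Toy labels with the 60/40 split described above (0 = stayed, 1 = churned).
labels = [0] * 60 + [1] * 40
label, accuracy = majority_baseline_accuracy(labels)
print(label, accuracy)  # 0 0.6
```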
Due to the limited size of the cluster (compute, one node) available in the free Databricks Community Edition, we found that fine-tuning the GBT model took too much time.
We instead fine-tuned the Random Forest, whose training can be parallelized and is much faster. There are numerous hyperparameters; we focused on the following two, which are the most important and whose tuning can often improve performance:
numTrees: Number of trees in the forest.
Increasing the number of trees will decrease the variance in predictions, improving the model’s test-time accuracy. Training time increases roughly linearly in the number of trees.
maxDepth: Maximum depth of each tree in the forest.
Increasing the depth makes the model more expressive and powerful. However, deep trees take longer to train and are also more prone to overfitting. In general, it is acceptable to train deeper trees when using random forests than when using a single decision tree. One tree is more likely to overfit than a random forest (because of the variance reduction from averaging multiple trees in the forest).
Using 3-fold cross-validation over a grid of 9 hyperparameter combinations, we were able to improve our random forest model’s performance:
- Untuned RF model ROC/AUC results: 0.9592
- Tuned RF model ROC/AUC results: 0.9714
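from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# rf and evaluator are assumed to be defined earlier in the notebook; 3 x 3 = 9 combinations
paramGrid = (ParamGridBuilder()
             .addGrid(rf.maxDepth, [2, 6, 10])
             .addGrid(rf.numTrees, [5, 20, 50])
             .build())
# Create 3-fold CrossValidator
cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)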
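To see why tuning is costly on a single-node cluster, it helps to count the model fits implied by this grid. The plain-Python sketch below enumerates the same maxDepth and numTrees values; it does not call Spark, and the variable names are illustrative.

```python
from itertools import product

max_depths = [2, 6, 10]
num_trees = [5, 20, 50]
num_folds = 3

# Every (maxDepth, numTrees) pair is one candidate model.
grid = [{"maxDepth": d, "numTrees": n} for d, n in product(max_depths, num_trees)]

# Each candidate is trained once per fold during cross-validation.
cv_fits = len(grid) * num_folds
print(len(grid), cv_fits)  # 9 27
```

On top of these 27 fits, the cross-validator trains the best combination once more on the full training set, so adding a third hyperparameter with three values would roughly triple the total training time.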
Improvement & Next Steps
To improve this Sparkify project, three items should be considered:
Refactor the code and use Object Oriented Programming so that the code is modular, easier to understand and can easily be used to process other datasets with similar characteristics.
In particular we could use two main classes to manage most of the project requirements:
- a DataManager class: class that would be in charge of loading, preprocessing and transforming the data
- a ModelManager class: class that would be in charge of managing and training the models, assessing the results and selecting the best model based on the chosen evaluation metrics.
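As a rough idea of what this refactoring could look like, here is a minimal skeleton of the two classes. Everything in it is hypothetical: the method names, the constructor arguments and the toy scores are placeholders, and a real version would wrap the Spark loading, pipeline and evaluation code.

```python
class DataManager:
    """Hypothetical wrapper for loading and preprocessing event data."""

    def __init__(self, loader):
        self.loader = loader  # any zero-argument callable returning a list of event dicts

    def load(self):
        return self.loader()

    def preprocess(self, events):
        # Assumption: events without a user id (missing or empty) are dropped.
        return [e for e in events if e.get("userId")]


class ModelManager:
    """Hypothetical wrapper that scores candidate models and keeps the best one."""

    def __init__(self, evaluate):
        self.evaluate = evaluate  # callable: model -> metric, higher is better
        self.results = {}

    def assess(self, models):
        self.results = {name: self.evaluate(model) for name, model in models.items()}
        return self.results

    def best(self):
        return max(self.results, key=self.results.get)


# Usage with stand-ins: the "models" here are just placeholder scores.
data = DataManager(lambda: [{"userId": "u1"}, {"userId": ""}])
clean = data.preprocess(data.load())

manager = ModelManager(evaluate=lambda score: score)
manager.assess({"model_a": 0.91, "model_b": 0.95})
print(len(clean), manager.best())  # 1 model_b
```

Passing the loader and the evaluation function in as arguments keeps both classes independent of any particular dataset or metric, which is the point of the refactoring.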
Try Other Approaches and Models
We could try to use a survival analysis, also known as “time to event analysis” and compare results to the current classification results obtained.
The following two popular approaches could be investigated:
Kaplan-Meier (KM) Survival Analysis
The KM Survival Curve plots time on the x-axis and the estimated survival probability on the y-axis. KM Survival Analysis requires only two inputs to predict the survival curve: Event (churned/non-churned) and Time to Event.
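The Kaplan-Meier estimate can be computed directly from those two inputs. Below is a minimal pure-Python sketch with made-up tenures: at each time where churn is observed, the running survival probability is multiplied by the share of at-risk users who did not churn. The function name `kaplan_meier` is an assumption, and a dedicated survival analysis library would normally be used instead.

```python
def kaplan_meier(durations, churned):
    """Return [(time, survival_probability)] at each time where churn was observed.

    durations: tenure in days for each user.
    churned:   1 if the user churned at that time, 0 if still active (censored).
    """
    event_times = sorted({t for t, c in zip(durations, churned) if c == 1})
    survival = 1.0
    curve = []
    for t in event_times:
        # Users still under observation just before time t.
        at_risk = sum(1 for d in durations if d >= t)
        # Users who churned exactly at time t.
        events = sum(1 for d, c in zip(durations, churned) if d == t and c == 1)
        survival *= 1.0 - events / at_risk
        curve.append((t, survival))
    return curve


# Toy data: two users churn on day 1, one on day 55, three are still active.
durations = [1, 1, 30, 55, 60, 60]
churned = [1, 1, 0, 1, 0, 0]

for day, probability in kaplan_meier(durations, churned):
    print(day, round(probability, 3))
```

Users who are still active are not ignored: they count as "at risk" up to their last observed day, which is what distinguishes this estimate from a simple churn percentage.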
Cox Proportional Hazard (CPH) Model
The CPH model determines the effect that a unit change in a covariate will have on an observation’s survival probability. CPH is a semi-parametric model.
Spark provides the following survival regression model: https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html
It fits a parametric survival regression model, the accelerated failure time (AFT) model (see Accelerated failure time model on Wikipedia), based on the Weibull distribution of the survival time.
A trained model is great and a deployed model is even better. While not available in the free Community Edition, Databricks provides the tools to deploy and serve models.
MLflow Model Serving allows you to host machine learning models from Model Registry as REST endpoints that are updated automatically based on the availability of model versions and their stages.
When you enable model serving for a given registered model, Databricks automatically creates a unique cluster for the model and deploys all non-archived versions of the model on that cluster. Databricks restarts the cluster if an error occurs and terminates the cluster when you disable model serving for the model. Model serving automatically syncs with Model Registry and deploys any new registered model versions. Deployed model versions can be queried with a standard REST API request.
We used the mini dataset to get a good understanding of the Sparkify service and its users. This helped us quickly define and identify customer churn and work out how to transform the raw data.
We were then able to apply the necessary transformations to the much larger dataset and label customers so that we could apply a supervised binary classification approach to predicting customer churn. We obtained very good results, with an ROC/AUC score of ~0.97 for our best classifier.
Using the Databricks ecosystem and notebooks along with the Koalas API (https://docs.databricks.com/languages/koalas.html) was a great way to ease into big data and the Spark ecosystem.
In addition, we used MLflow to track, save and inventory the models and experiment results obtained along the way.
MLflow is an open source platform to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry.
References and Sources