Sparkify — Predicting customer churn for a popular music streaming service

This post provides an overview of my Udacity Data Science Nanodegree’s capstone project. I chose to work on a project focused on predicting customer churn for a popular music streaming service.

The Sparkify Music Streaming mobile app

Project Definition

We were provided with two datasets:

  • a ‘mini’ dataset with music choices from 225 unique service customers, which we used to explore and understand the dataset’s characteristics and features (286K rows)
  • a ‘large’ dataset with music choices from 22K unique service customers, for which we had to use Spark and the big data Hadoop framework to analyze and model due to the sheer size of the dataset (26M rows)

We chose to use the Databricks ecosystem (PySpark, Spark SQL, Spark ML, MLflow) to analyze the dataset, create features and perform prediction.

Problem Statement and Approach

1. In a first phase we focused on the Sparkify mini dataset:

  • Performed EDA on the mini dataset answering business questions as we go along
  • Identified and defined customer churn
  • Transformed the dataset with the appropriate features and label so that a supervised machine learning framework could be used

2. In a second phase, we turned our attention to the large Sparkify dataset to perform churn predictions, using what we had learned from the mini dataset:

  • Performed EDA on the large dataset
  • Preprocessed the dataset, performed feature engineering and labeled customers with a churn flag
  • Explored the processed dataset and performed survival analysis
  • Modeled and predicted customer churn using Spark’s ML libraries
  • Evaluated models and assessed results



The ROC curve is created by plotting the true positive rate (TPR) against the false positive rate (FPR) at various threshold settings.

AUC stands for area under the (ROC) curve. It ranges from 0.0 to 1.0. A score of 0.5 corresponds to a no-skill (random) classifier and 1.0 to a classifier with perfect skill.
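To make these two definitions concrete, here is a minimal plain-Python sketch. It is not the Spark `BinaryClassificationEvaluator` used in the project. It sweeps the decision threshold over the scores to collect (FPR, TPR) points, then integrates them with the trapezoidal rule. The function names and toy data are hypothetical.

```python
from typing import List, Tuple


def roc_points(labels: List[int], scores: List[float]) -> List[Tuple[float, float]]:
    """Return (FPR, TPR) points obtained by lowering the decision threshold
    from the highest score to the lowest."""
    pos = sum(labels)
    neg = len(labels) - pos
    # Sort by descending score; each distinct score acts as one threshold.
    pairs = sorted(zip(scores, labels), key=lambda p: -p[0])
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        threshold = pairs[i][0]
        # Every example tied at this threshold becomes "predicted positive".
        while i < len(pairs) and pairs[i][0] == threshold:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / neg, tp / pos))
    return points


def auc(labels: List[int], scores: List[float]) -> float:
    """Area under the ROC curve using the trapezoidal rule."""
    pts = roc_points(labels, scores)
    area = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area
```

For example, `auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])` gives 0.75. A perfect ranking gives 1.0. Giving every example the same score yields a single diagonal segment and an AUC of 0.5, which is why 0.5 is the no-skill reference.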

Development Environment — Databricks

Databricks is a cloud implementation of Apache Spark (on Microsoft Azure, Amazon AWS or GCP). With fully managed Spark clusters, it is used to process large data workloads and supports data engineering, data exploration, data visualization and machine learning. It runs a distributed system behind the scenes: the workload is automatically split across multiple processors, and the cluster scales up and down on demand.


Some key features of Databricks:

  • Databricks Workspace — interactive workspace that enables data scientists and data engineers to collaborate and work closely together on notebooks and dashboards
  • Databricks Runtime — includes Apache Spark and supports Scala, Python (PySpark) and Spark SQL, among other languages
  • Fully managed service — resources such as storage, virtual networks, clusters and compute are easy to start and stop
  • Databricks File System (DBFS) — abstraction layer on top of cloud object storage (blob, S3 …)

We used Databricks Community Edition to analyze the Sparkify dataset and model customer churn.

Exploring the Dataset

Udacity Sparkify customer dataset events

The mini Sparkify dataset contains music choices from 225 customers over a 60-day period, for a total of 286,500 rows.

The dataset is composed of the following features and can be thought of as a sequence of user service events and songs listened to:

artist: string - artist name
auth: string - authentication status
firstName: string - user first name
gender: string - user gender
itemInSession: long - sequence number of the event within the session
lastName: string - user last name
length: double - length of the song played
level: string - Sparkify user service level (paid or free)
location: string - location of the user
method: string - HTTP request method
page: string - user service interaction event
registration: long - timestamp of user service registration
sessionId: long - service session id
song: string - name of the song played by the user
status: long - HTTP status code
ts: long - timestamp of user service event
userAgent: string - user agent (web browser) used
userId: string - unique user id

The dataset’s features can be broken down into four broad categories:

  • User information: userId, firstName, lastName, gender, location
  • Sparkify service information: level, registration
  • User service event/interaction information: method, itemInSession, page, status, ts, userAgent, sessionId
  • Song and artist information: artist, song, length

Below is a sample of a user’s interactions with the service, ordered by timestamp:

Snapshot of the Sparkify dataset

Exploratory Data Analysis

To get more familiar with the dataset, we explored the music streaming data and answered a few interesting questions about users’ behavior and their use of Sparkify’s service (refer to the accompanying notebook):

  • What is the distribution of users by gender?
- Users are fairly well balanced by gender: 121 males vs 104 females
  • How many users were paying for the service at their first and last recorded events?
- Out of 225 users, 48 (21%) had a paid subscription at their earliest recorded event vs 177 (79%) who did not (free subscription)
- Out of 225 users, 145 (64%) had a paid subscription at their latest recorded event vs 80 (36%) who did not (free subscription)
  • Which users listened to the most songs?
- user 39 listened to ~6k songs over a two-month period, an average of 100 songs per day
- users 92 and 140 listened to ~4,500 songs over two months
  • Where are the most service users located ?
- Los Angeles and New York are where most Sparkify users are located
  • How many unique artists and songs does the dataset have ?
- The dataset holds approx. 17K artists and 58k songs
  • What are the most popular songs and artists?
The most popular songs are:
1. You're The One by Dwight Yoakam
2. Undo by Bjork
3. Revelry by Kings Of Leon
The most popular artists are:
1. Kings Of Leon
2. Coldplay
3. Florence + The Machine

Identifying and defining customer churn

The relevant features we will use to identify whether a customer has churned are page and level.

- level: possible values are 'free', 'paid'
- page: possible values are 'Cancel', 'Submit Downgrade', 'Thumbs Down', 'Home', 'Downgrade', 'Roll Advert', 'Logout', 'Save Settings', 'Cancellation Confirmation', 'About', 'Settings', 'Add to Playlist', 'Add Friend', 'NextSong', 'Thumbs Up', 'Help', 'Upgrade', 'Error', 'Submit Upgrade'

By investigating the user event patterns, we can see that the Submit Downgrade and Cancellation Confirmation events seem to be good proxies for identifying whether and when a customer has churned.

Let’s look at what happens before and after a Submit Downgrade event occurs for a given user (user id 100015):

  1. the user has a paying subscription
  2. the user goes to the Downgrade page
  3. the user submits a downgrade request
  4. the user now has a free subscription

Let’s now look at what happens before and after a Cancellation Confirmation event occurs for a given user (user id 100015):

  1. the user has a paying subscription
  2. the user goes to the Cancel page
  3. the user confirms the cancellation (Cancellation Confirmation event)
  4. the user has now ended the Sparkify subscription
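The two event sequences above suggest a simple labeling rule. Here is a minimal plain-Python sketch of it. It assumes each event is a dict with the raw dataset’s `userId`, `ts` and `page` fields. The function name and the `include_downgrade` switch are hypothetical, and the project itself performed this step in PySpark. The rule records whether a user churned and also when, via the timestamp of their first churn event.

```python
from typing import Dict, Iterable, Tuple

CANCEL_PAGES = {"Cancellation Confirmation"}
DOWNGRADE_PAGES = {"Submit Downgrade"}


def label_churn(events: Iterable[Dict], include_downgrade: bool = False
                ) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (labels, churn_ts).

    labels maps each userId to 1 if the user churned, else 0.
    churn_ts maps each churned userId to the timestamp of its first churn event.
    A 'Submit Downgrade' only counts as churn when include_downgrade is True.
    """
    churn_pages = CANCEL_PAGES | (DOWNGRADE_PAGES if include_downgrade else set())
    labels: Dict[str, int] = {}
    churn_ts: Dict[str, int] = {}
    # Process events in time order so churn_ts keeps the *first* churn event.
    for event in sorted(events, key=lambda e: e["ts"]):
        user = event["userId"]
        labels.setdefault(user, 0)
        if event["page"] in churn_pages and user not in churn_ts:
            labels[user] = 1
            churn_ts[user] = event["ts"]
    return labels, churn_ts
```

Treating downgrades as churn is a design choice. Whether a move from paid to free counts as churn depends on the business question being asked.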

Preprocessing and Feature engineering


  1. Aggregated the data for each user and dropped features uncorrelated with churn
  2. Engineered features: count of events and songs for each user, tenure in days and time from registration to first use
  3. Labeled each user (row) with a binary churn feature
  4. Cleaned the dataset by removing rows with null values
  5. Saved the clean dataset in a Spark table
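Here is a plain-Python sketch of the per-user aggregation in step 2, producing columns named like the feature list that follows. It rests on four assumptions:

- `ts` and `registration` are epoch milliseconds.
- `sessionId_count` counts distinct sessions.
- Tenure is measured from a user’s first to last event.
- Events with an empty `userId` are dropped.

The function name is hypothetical. In Spark the same result would come from a `groupBy("userId")` with aggregations and a pivot on `page`.

```python
from collections import Counter, defaultdict
from typing import Dict, Iterable, List

MS_PER_DAY = 24 * 60 * 60 * 1000  # assumes ts/registration are epoch milliseconds


def aggregate_users(events: Iterable[Dict]) -> Dict[str, Dict]:
    """Collapse raw events into one feature row per userId."""
    by_user: Dict[str, List[Dict]] = defaultdict(list)
    for e in events:
        if e.get("userId"):  # drop events with a missing or empty userId
            by_user[e["userId"]].append(e)

    features: Dict[str, Dict] = {}
    for user, evs in by_user.items():
        songs = [e for e in evs if e.get("song") is not None]
        first_ts = min(e["ts"] for e in evs)
        last_ts = max(e["ts"] for e in evs)
        row = {
            "song_count": len(songs),
            "song_nunique": len({e["song"] for e in songs}),
            "sessionId_count": len({e["sessionId"] for e in evs}),
            "length_sum": sum(e["length"] for e in songs),
            "time_to_first_use_days": (first_ts - evs[0]["registration"]) / MS_PER_DAY,
            "service_tenure_days": (last_ts - first_ts) / MS_PER_DAY,
        }
        # One count per page, with spaces replaced by underscores (e.g. Thumbs_Up).
        row.update(Counter(e["page"].replace(" ", "_") for e in evs))
        features[user] = row
    return features
```

Pages a user never visited are simply absent from their row. A pivoted Spark DataFrame would show them as nulls to be filled with 0.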

Dataset features

userId - unique user id
gender_first - user gender
song_count - count of songs played by the user
song_nunique - count of unique songs played by the user
sessionId_count - count of session ids
length_sum - total song playing time
time_to_first_use_days - number of days between registration and first use
service_tenure_days - number of days the user used the service
About - number of times the About page was visited or event took place for the user
Add_Friend - number of times the Add_Friend page was visited or event took place for the user
Add_to_Playlist - number of times the Add_to_Playlist page was visited or event took place for the user
Cancel - number of times the Cancel page was visited or event took place for the user
Downgrade - number of times the Downgrade page was visited or event took place for the user
Error - number of times the Error page was visited or event took place for the user
Help - number of times the Help page was visited or event took place for the user
Home - number of times the Home page was visited or event took place for the user
Logout - number of times the Logout page was visited or event took place for the user
NextSong - number of times the NextSong page was visited or event took place for the user
Roll_Advert - number of times the Roll_Advert page was visited or event took place for the user
Save_Settings - number of times the Save_Settings page was visited or event took place for the user
Settings - number of times the Settings page was visited or event took place for the user
Submit_Upgrade - number of times the Submit_Upgrade page was visited or event took place for the user
Thumbs_Down - number of times the Thumbs_Down page was visited or event took place for the user
Thumbs_Up - number of times the Thumbs_Up page was visited or event took place for the user
Upgrade - number of times the Upgrade page was visited or event took place for the user
churn - whether the user has churned or not

A sample of the transformed labeled dataset is shown below:

Snapshot of the processed Sparkify dataset with feature engineered and customer churn label

Analyzing Customer Churn

  • What proportion of Sparkify customers have churned?
Overall, 40% of Sparkify users have churned in the dataset. We therefore won't need to deal with a severely imbalanced dataset while modelling.
Proportion of customers that have churned in 60 days
  • What is the average user tenure by gender?
- The average tenure in the dataset is 40 days and is very similar for male and female Sparkify users
- The tenure distribution is left skewed, but there is a peak very early, with over 1,200 users who seem to have quit on the first day of usage
Distribution of churned customers by service tenure in days
  • What does the progression of churn over time look like (survival analysis)?
- 40% of customers have churned over a 2 months period
- Male and Female survival plots and churn profiles are very similar
- Sparkify seems to lose customers at a faster clip within the first couple of days, and again after the ~55-day mark, where the curve becomes much steeper
- It takes approx. 55 days to reach 30% of user churn and then 5 days to see 10% more users churn
Customer Churn Survival Plot

Model and predict customer churn


  1. Built Spark pipelines to transform features using Spark’s VectorAssembler, which combines a list of columns into a single feature vector
  2. Split the processed dataset into training and test sets
  3. Evaluated four ML models with default parameters: LogisticRegression, RandomForestClassifier, LinearSVC and GBTClassifier
  4. Assessed feature importance for the best model
  5. Tuned the hyperparameters of the RandomForestClassifier and assessed the results

Data Processing Pipelines

StringIndexer and OneHotEncoder objects were used to convert categorical variables into a set of numeric variables that only take on the values 0 and 1.

  • StringIndexer converts a column of string values to a column of label indexes. For example, it might convert the values “red”, “blue”, and “green” to 0, 1, and 2.
  • OneHotEncoder maps a column of category indices to a column of binary vectors, with at most one “1” in each row that indicates the category index for that row.
  • One-hot encoding in Spark is a two-step process. We first used the StringIndexer, followed by the OneHotEncoder.
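The two-step process can be sketched in plain Python. This mimics our understanding of Spark’s defaults:

- StringIndexer with `stringOrderType="frequencyDesc"`: the most frequent label gets index 0, and ties are broken alphabetically.
- OneHotEncoder with `dropLast=True`: the last category becomes the all-zeros vector, hence "at most one 1".

This is a hypothetical illustration, not Spark’s code. Spark also emits sparse vectors rather than Python lists.

```python
from collections import Counter
from typing import Dict, List


def fit_string_indexer(values: List[str]) -> Dict[str, int]:
    """Map each label to an index: most frequent first, ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {label: idx for idx, label in enumerate(ordered)}


def one_hot(index: int, num_categories: int, drop_last: bool = True) -> List[int]:
    """Encode a category index as a 0/1 vector.

    With drop_last=True the vector has num_categories - 1 slots and the last
    category is encoded as all zeros, so each row has at most one 1.
    """
    size = num_categories - 1 if drop_last else num_categories
    vec = [0] * size
    if index < size:
        vec[index] = 1
    return vec
```

For example, indexing `["red", "blue", "red", "green", "blue", "red"]` gives red → 0, blue → 1, green → 2. With `drop_last=True`, green is then encoded as `[0, 0]`. For a two-valued column such as gender, the result is a single 0/1 column.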


Model Evaluation and Validation

  • LogisticRegression

LogisticRegression is a special case of generalized linear models that predicts the probability of the outcomes. Logistic regression can be used to predict a binary outcome using binomial logistic regression, or a multiclass outcome using multinomial logistic regression.
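For the binary churn case, the model turns a weighted sum of the features into a probability with the sigmoid function, then applies a threshold. Spark’s default threshold is 0.5. Here is a minimal plain-Python sketch. The weights and intercept would come from training, so any values used with it are hypothetical.

```python
import math
from typing import Sequence


def predict_proba(weights: Sequence[float], intercept: float, x: Sequence[float]) -> float:
    """Binomial logistic regression: P(y = 1 | x) = sigmoid(w . x + b)."""
    z = intercept + sum(w * xi for w, xi in zip(weights, x))
    # Numerically stable sigmoid: avoids math.exp overflow for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def predict(weights: Sequence[float], intercept: float, x: Sequence[float],
            threshold: float = 0.5) -> int:
    """Label 1 (churn) when the predicted probability exceeds the threshold."""
    return int(predict_proba(weights, intercept, x) > threshold)
```

The multinomial case replaces the sigmoid with a softmax over one linear score per class.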

  • RandomForestClassifier

Random forests train a set of decision trees separately, so the training can be done in parallel. The algorithm injects randomness into the training process so that each decision tree is a bit different. Combining the predictions from each tree reduces the variance of the predictions, improving the performance on test data.

  • LinearSVC

A support vector machine constructs a hyperplane or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification, regression, or other tasks. Intuitively, a good separation is achieved by the hyperplane that has the largest distance to the nearest training-data points of any class (so-called functional margin), since in general the larger the margin the lower the generalization error of the classifier. LinearSVC in Spark ML supports binary classification with linear SVM.
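The margin intuition is usually made concrete through the hinge loss. Points on the correct side of the hyperplane, and outside the margin, cost nothing. Points inside the margin or misclassified are penalized linearly. An L2 penalty on the weights w widens the margin, whose width is 2/||w|| for the canonical hyperplane. Here is a hedged plain-Python sketch of that objective with hypothetical function names. Spark’s LinearSVC optimizes a regularized hinge loss of this kind, but its exact scaling may differ.

```python
from typing import Sequence


def decision_function(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """Signed score; its sign gives the side of the hyperplane w.x + b = 0."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def hinge_loss(w: Sequence[float], b: float, X: Sequence[Sequence[float]],
               y: Sequence[int], reg: float = 0.0) -> float:
    """Mean hinge loss max(0, 1 - t * f(x)), with labels 0/1 mapped to t = -1/+1,
    plus an optional L2 penalty (reg / 2) * ||w||^2."""
    total = 0.0
    for x, label in zip(X, y):
        t = 1 if label == 1 else -1
        total += max(0.0, 1.0 - t * decision_function(w, b, x))
    return total / len(X) + 0.5 * reg * sum(wi * wi for wi in w)
```

Prediction is then `1 if decision_function(w, b, x) > 0 else 0`.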

  • GBTClassifier

GBTs iteratively train decision trees in order to minimize a loss function. Like decision trees, GBTs handle categorical features, extend to the multiclass classification setting, do not require feature scaling, and are able to capture non-linearities and feature interactions. Spark’s MLlib supports GBTs for binary classification and for regression, using both continuous and categorical features.



  Training model : lr_classifier 
Pipeline : lr_classifier areaUnderROC score 0.9402484122969708
Training model : rf_classifier
Pipeline : rf_classifier areaUnderROC score 0.9615282227093049
Training model : gtb_classifier
Pipeline : gtb_classifier areaUnderROC score 0.972535781861781
Training model : lsvc_classifier
Pipeline : lsvc_classifier areaUnderROC score 0.9475858952022346

Feature importance

GBT classifier feature importance for Sparkify’s customer churn prediction

For the gradient-boosted tree (GBT) model, the features most important for predicting user churn were:

  1. The Cancel event count
  2. The Submit Upgrade event count
  3. The Downgrade event count
  4. The Roll Advert event count

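These seem to make sense, as they are a direct reflection of the user’s sentiment about the Sparkify service. Note, however, that the Cancel page directly precedes the Cancellation Confirmation event used to identify churn, so its top ranking partly reflects the churn label itself. Surprisingly, service tenure seems to have little importance in the binary classification.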


If we always predicted the most common class in the dataset, we would obtain 60% accuracy, since 60% of our customers have not churned (refer to the Analyzing Customer Churn section). Such a constant predictor has an AUC of only 0.5, so all our models significantly outperform this baseline.
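The baseline arithmetic fits in a few lines. Here is a plain-Python sketch with a hypothetical helper. It uses toy labels that reproduce the 40% / 60% split reported above, not the actual dataset.

```python
from collections import Counter
from typing import List, Tuple


def majority_baseline(labels: List[int]) -> Tuple[int, float]:
    """Return the majority class and the accuracy of always predicting it."""
    counts = Counter(labels)
    majority, n = counts.most_common(1)[0]
    return majority, n / len(labels)


# Toy labels with the same churn rate as the dataset: 40 churned, 60 retained.
labels = [1] * 40 + [0] * 60
print(majority_baseline(labels))  # (0, 0.6)
```

Because this predictor gives every user the same score, it cannot rank churners above non-churners, which is why its AUC is 0.5.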


Although the GBT classifier scored best, we instead fine-tuned the Random Forest, whose training can be parallelized and is much faster. There are numerous hyperparameters; we focused on tuning the following ones:

The first two parameters we mention are the most important, and tuning them can often improve performance:

numTrees: Number of trees in the forest.
Increasing the number of trees will decrease the variance in predictions, improving the model’s test-time accuracy. Training time increases roughly linearly in the number of trees.

maxDepth: Maximum depth of each tree in the forest.
Increasing the depth makes the model more expressive and powerful. However, deep trees take longer to train and are also more prone to overfitting. In general, it is acceptable to train deeper trees when using random forests than when using a single decision tree. One tree is more likely to overfit than a random forest (because of the variance reduction from averaging multiple trees in the forest).

Using 3-fold cross-validation over a 9-combination hyperparameter grid (3 maxDepth values × 3 numTrees values), we were able to improve our random forest model’s performance:

  • Untuned RF model ROC/AUC results: 0.9592
  • Tuned RF model ROC/AUC results: 0.9714
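from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# rf (RandomForestClassifier) and evaluator (areaUnderROC) are assumed to be defined earlier
# 3 maxDepth values x 3 numTrees values = 9 hyperparameter combinations
paramGrid = (ParamGridBuilder()
             .addGrid(rf.maxDepth, [2, 6, 10])
             .addGrid(rf.numTrees, [5, 20, 50])
             .build())
# Create 3-fold CrossValidator
cv = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)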
Tuned Random Forest: hyperparameters recorded in MLflow
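Here is a plain-Python sketch of the bookkeeping CrossValidator performs. The helpers `grid` and `cross_validate` are hypothetical. A caller-supplied `fit_and_score` callback stands in for training a random forest and scoring it with areaUnderROC.

The sketch does the following:

- Enumerates the 3 × 3 = 9 parameter combinations.
- Randomly assigns rows to 3 folds.
- For each combination, trains on two folds, scores on the third, and averages over the folds: 27 fits in total.

As we understand it, Spark’s CrossValidator also exposes these averages as `avgMetrics` and refits the best combination on the full training data. The sketch omits that refit.

```python
import itertools
import random
from statistics import mean
from typing import Callable, Dict, List, Sequence, Tuple


def grid(param_values: Dict[str, Sequence]) -> List[Dict]:
    """Cartesian product of parameter values (similar in spirit to ParamGridBuilder)."""
    names = list(param_values)
    return [dict(zip(names, combo)) for combo in itertools.product(*param_values.values())]


def cross_validate(rows: Sequence, param_maps: List[Dict],
                   fit_and_score: Callable[[Sequence, Sequence, Dict], float],
                   num_folds: int = 3, seed: int = 42) -> Tuple[Dict, List[float]]:
    """Return (best_params, avg_metrics); higher metric is better, as with AUC."""
    rng = random.Random(seed)
    fold_of = [rng.randrange(num_folds) for _ in rows]  # random fold assignment
    avg_metrics = []
    for params in param_maps:
        scores = []
        for k in range(num_folds):
            train = [r for r, f in zip(rows, fold_of) if f != k]
            valid = [r for r, f in zip(rows, fold_of) if f == k]
            scores.append(fit_and_score(train, valid, params))
        avg_metrics.append(mean(scores))
    best = max(range(len(param_maps)), key=lambda i: avg_metrics[i])
    return param_maps[best], avg_metrics
```

With `grid({"maxDepth": [2, 6, 10], "numTrees": [5, 20, 50]})`, `fit_and_score` is called 9 × 3 = 27 times. This is why tuning the faster-to-train random forest was the pragmatic choice.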

Improvement & Next Steps

Code Refactoring

The code could be refactored around two main classes that manage most of the project requirements:

  • a DataManager class: class that would be in charge of loading, preprocessing and transforming the data
  • a ModelManager class: class that would be in charge of managing and training the models, assessing the results and selecting the best model based on the chosen evaluation metrics.
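Here is a hypothetical, framework-agnostic skeleton of the second class, as a hedged illustration. Trainers and the evaluation function are passed in as callables. In the project they would wrap the Spark ML pipelines and the areaUnderROC evaluator. None of these names come from an existing library, and a DataManager could follow the same pattern for loading and preprocessing steps.

```python
from typing import Any, Callable, Dict, Optional


class ModelManager:
    """Trains registered candidate models, scores each one, and picks the best."""

    def __init__(self, evaluate: Callable[[Any, Any], float], higher_is_better: bool = True):
        self.evaluate = evaluate
        self.higher_is_better = higher_is_better
        self.trainers: Dict[str, Callable[[Any], Any]] = {}
        self.models: Dict[str, Any] = {}
        self.scores: Dict[str, float] = {}

    def register(self, name: str, trainer: Callable[[Any], Any]) -> None:
        """Add a candidate; trainer(train_data) must return a fitted model."""
        self.trainers[name] = trainer

    def train_all(self, train_data: Any, test_data: Any) -> Dict[str, float]:
        """Fit every candidate and score it on the test data."""
        for name, trainer in self.trainers.items():
            model = trainer(train_data)
            self.models[name] = model
            self.scores[name] = self.evaluate(model, test_data)
        return dict(self.scores)

    def best(self) -> Optional[str]:
        """Name of the best-scoring model, or None if nothing has been trained."""
        if not self.scores:
            return None
        pick = max if self.higher_is_better else min
        return pick(self.scores, key=self.scores.get)
```

Keeping the metric and its direction in one place makes the model selection step explicit and easy to log to MLflow.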

Try Other Approaches and Models

The following two popular approaches could be investigated:

Kaplan-Meier (KM) Survival Analysis

The KM Survival Curve plots time on the x-axis and the estimated survival probability on the y-axis. KM Survival Analysis requires only two inputs to predict the survival curve: Event (churned/non-churned) and Time to Event.
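In this project, Event would be the churn flag and Time to Event the service tenure in days. Customers who had not churned by the end of the 60-day window are right-censored: we only know they survived at least that long.

Below is a minimal plain-Python sketch of the Kaplan-Meier estimator. The function name and toy data are hypothetical. Libraries such as lifelines provide a production implementation through their KaplanMeierFitter class.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def kaplan_meier(durations: Sequence[float], churned: Sequence[int]) -> List[Tuple[float, float]]:
    """Kaplan-Meier estimate S(t) = product over churn times t_i <= t of (1 - d_i / n_i).

    d_i is the number of churn events at t_i and n_i the number of users still
    at risk just before t_i. Users with churned == 0 are right-censored.
    Returns (time, survival) steps, starting at (0.0, 1.0).
    """
    events = Counter(t for t, e in zip(durations, churned) if e)
    exits = Counter(durations)  # users leaving the risk set at each time
    at_risk = len(durations)
    survival = 1.0
    curve = [(0.0, 1.0)]
    for t in sorted(exits):
        d = events.get(t, 0)
        if d:
            survival *= 1.0 - d / at_risk
            curve.append((t, survival))
        at_risk -= exits[t]  # both churned and censored users stop being at risk
    return curve
```

For durations `[1, 2, 2, 3, 4]` with churn flags `[1, 1, 0, 1, 0]`, the curve steps down to about 0.8, 0.6 and 0.3 at days 1, 2 and 3. The censored users still count in the denominator until they drop out.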

Cox Proportional Hazard (CPH) Model

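The CPH model estimates the effect that a unit change in a covariate has on an observation’s hazard (its instantaneous risk of the event), and therefore on its survival probability. CPH is a semi-parametric model: the baseline hazard is left unspecified, while the covariate effects are modeled parametrically.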
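In standard notation, the model and the "unit change" interpretation can be written as follows. Here $h_0(t)$ is the unspecified baseline hazard, $S_0(t)$ the baseline survival function, and $\beta_j$ the coefficient of covariate $x_j$.

```latex
h(t \mid \mathbf{x}) = h_0(t)\,\exp\!\left(\beta_1 x_1 + \dots + \beta_p x_p\right)

\frac{h(t \mid x_j + 1,\ \text{other covariates fixed})}{h(t \mid x_j,\ \text{other covariates fixed})} = e^{\beta_j}

S(t \mid \mathbf{x}) = S_0(t)^{\exp(\boldsymbol{\beta}^\top \mathbf{x})}
```

For example, a hypothetical coefficient $\beta_j = 0.1$ on the Thumbs_Down count would give $e^{0.1} \approx 1.105$. Each additional thumbs down would raise the churn hazard by about 10.5%, at every point in time.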

Spark provides the following survival regression model:

AFTSurvivalRegression, which fits a parametric survival regression model called the accelerated failure time (AFT) model (see the Wikipedia article on the accelerated failure time model), based on the Weibull distribution of the survival time


Model Deployment

MLflow Model Serving allows you to host machine learning models from Model Registry as REST endpoints that are updated automatically based on the availability of model versions and their stages.

When you enable model serving for a given registered model, Databricks automatically creates a unique cluster for the model and deploys all non-archived versions of the model on that cluster. Databricks restarts the cluster if an error occurs and terminates the cluster when you disable model serving for the model. Model serving automatically syncs with Model Registry and deploys any new registered model versions. Deployed model versions can be queried with a standard REST API request.


Databricks Model Deployment Pipeline Using MLflow



After exploring the mini dataset, we applied the necessary transformations to the much larger dataset and labeled customers so that we could use a supervised binary classification approach to predict customer churn. We obtained very good results, with an ROC AUC score of ~0.97 for our best classifier.


In addition, we used MLflow to track, save and inventory the models and experiment results obtained along the way.

MLflow is an open source platform to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry.


I am a product analyst and enjoy bridging business and data science to deliver business insights, enhance product features and improve decision making.