Predicting Customer Churn Rates Using Apache Spark and IBM Watson

Jesse Peterson
Feb 17, 2021
Churn Rate: The percentage rate at which customers stop subscribing to a service, or employees leave a job, over a given period (often a year).

Project Definition

This project uses a dataset of users from a fictional music streaming service called ‘Sparkify’ to predict the churn rate of users. Churn here covers two kinds of move: from paid subscriptions to free subscriptions or cancellation, and from free subscriptions to paid subscriptions (upgrades). The purpose of this project is to examine how we might use the Apache Spark analytics engine to predict which type of user is most likely to cancel or upgrade their subscription.

Understanding the churn rate of a set of users can help organizations create data-driven strategies, either to retain members of a customer base or to run campaigns that bring in new users. Knowing the factors that encourage users to stay on a platform, or entice them to pay for a subscription, lets businesses optimize the allocation of resources and minimize the time spent pursuing unprofitable opportunities.

The code for this project can be found in the associated GitHub repository.

The medium-sparkify-event-data.json dataset for this project was provided by the Udacity platform and can only be accessed by registered users.

Problem Statement

The dataset contains 543,705 records for users of the Sparkify platform, spread across 18 columns. I would like to use this dataset to predict the ‘churn rate’ of different users. A model will be developed to identify the optimal set of features for predicting whether users upgrade or downgrade their subscriptions.

Metrics

In the context of customer retention, the churn rate is the rate at which customers move from one group to another over a defined period of time, for example from paid to free. Looking at churn alongside other information about users, such as changes in their behavior while using the product or service, lets observers quantify the impact a specific variable has on churn rates. The success of the model will be judged by the following evaluation metrics:

F-1 Score: 2 * (Weighted Precision * Weighted Recall) / (Weighted Precision + Weighted Recall)
Weighted Precision Score: True Positives / (True Positives + False Positives). This is calculated for each class and then averaged across classes, with each class weighted by its share of the records.
Weighted Recall Score: True Positives / (True Positives + False Negatives). This is calculated and averaged in the same weighted way.
Accuracy Score: (True Positives + True Negatives) / (True Positives + True Negatives + False Positives + False Negatives)

False Positives: The count of users predicted to churn who did not actually churn.
False Negatives: The count of users predicted not to churn who actually churned.
True Positives: The count of users correctly predicted to churn.
True Negatives: The count of users correctly predicted not to churn.

The model with the highest scores and the most correct (true positive and true negative) predictions will be considered the best.
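As a rough illustration of the metrics above, the sketch below computes them in plain Python from lists of actual and predicted labels, rather than with Spark. The dictionary keys borrow the metric names that Spark's MulticlassClassificationEvaluator accepts (accuracy, weightedPrecision, weightedRecall, f1). The functions themselves are hypothetical helpers, not the project's code. One caveat applies: Spark's f1 metric is a weighted average of the per-class F1 scores, so it can differ slightly from the harmonic mean of weighted precision and recall used here.

```python
from collections import Counter

def confusion_counts(labels, predictions, positive=1):
    """Return (TP, TN, FP, FN) for a binary churn problem."""
    tp = tn = fp = fn = 0
    for actual, predicted in zip(labels, predictions):
        if predicted == positive:
            if actual == positive:
                tp += 1  # predicted churn, really churned
            else:
                fp += 1  # predicted churn, did not churn
        elif actual == positive:
            fn += 1      # predicted no churn, really churned
        else:
            tn += 1      # predicted no churn, did not churn
    return tp, tn, fp, fn

def evaluate(labels, predictions):
    """Accuracy, weighted precision/recall and F1 as defined in the Metrics section."""
    support = Counter(labels)  # number of records in each actual class
    total = len(labels)
    weighted_precision = weighted_recall = 0.0
    for cls, count in support.items():
        hits = sum(1 for a, p in zip(labels, predictions) if a == cls and p == cls)
        predicted_as_cls = sum(1 for p in predictions if p == cls)
        precision = hits / predicted_as_cls if predicted_as_cls else 0.0
        recall = hits / count
        # weight each class by its share of the records
        weighted_precision += precision * count / total
        weighted_recall += recall * count / total
    accuracy = sum(1 for a, p in zip(labels, predictions) if a == p) / total
    denom = weighted_precision + weighted_recall
    f1 = 2 * weighted_precision * weighted_recall / denom if denom else 0.0
    return {"accuracy": accuracy, "weightedPrecision": weighted_precision,
            "weightedRecall": weighted_recall, "f1": f1}
```

For example, `evaluate([1, 1, 0, 0, 0], [1, 0, 0, 0, 1])` scores 0.6 on all four metrics. The matching `confusion_counts` call returns `(1, 2, 1, 1)`.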

Data Analysis

Datasets

The following subset of columns was extracted from the original “medium-sparkify-event-data.json” dataset for this analysis:

|-- gender: string (nullable = true)
|-- length: double (nullable = true)
|-- level: string (nullable = true)
|-- location: string (nullable = true)
|-- method: string (nullable = true)
|-- page: string (nullable = true)
|-- sessionID: long (nullable = true)
|-- song: string (nullable = true)
|-- ts: long (nullable = true)
|-- userAgent: string (nullable = true)
|-- userId: string (nullable = true)

Below is a description of these fields:

gender: The gender (M or F) of the user

level: The subscription type (either Free or Paid) of the user

location: The city and state of the user's location

method: The HTTP request method (PUT or GET) used by the user

page: The page within the Sparkify platform being accessed by the user

sessionID: The unique ID of the user's session on the Sparkify platform

song: The name of the song accessed by the user

ts: The timestamp of the user's activity on the Sparkify platform

userAgent: The browser, operating system, and device the user is accessing the page with

userId: The user's unique identifier

Data Preprocessing

To identify the features that are most predictive of churn, it is first necessary to define the target event, that is, the ‘churn’ being predicted. In this case, I chose to use the page variable to define the churn event. I encoded a value of ‘1’ when an existing user either downgrades a subscription or confirms the cancellation of a subscription, and a value of ‘0’ otherwise. These values were added to a new column named ‘churn’:

Churn Categorizations
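Below is a minimal sketch of the churn flag, written in plain Python over event records represented as dictionaries. The page names in CHURN_PAGES are an assumption about how downgrades and cancellation confirmations appear in the page column. In PySpark the same rule could be written as `df.withColumn("churn", F.when(F.col("page").isin(...), 1).otherwise(0))`.

```python
# Assumed page values marking a churn event (not listed in the article).
CHURN_PAGES = {"Submit Downgrade", "Cancellation Confirmation"}

def add_churn_flag(events, churn_pages=CHURN_PAGES):
    """Return copies of the event records with churn = 1 on churn events, else 0."""
    # {**event, ...} builds a new dict, so the input records are left unchanged
    return [{**event, "churn": 1 if event.get("page") in churn_pages else 0}
            for event in events]
```

With this rule, a "NextSong" event gets churn 0, while a "Cancellation Confirmation" event gets churn 1.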

The churn column then contains the following distribution of churned (1) to unchurned (0) records:

+-----+------+
|churn| count|
+-----+------+
| 0|527789|
| 1| 216|
+-----+------+

Looking at these numbers, it is clear that the original dataset is highly imbalanced. More than 99.9% of the records belong to the non-churned group, which would bias any predictions toward that label. To balance the ratio of churned to unchurned records, I used oversampling. Records from the minority class (the 216 churn records) were duplicated until their count matched the 527,789 unchurned records, for a grand total of 1,055,578 records.
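The oversampling step can be sketched as follows. This plain-Python illustration assumes each record is a dictionary with a churn key. It keeps every original record and, for each smaller class, draws the shortfall at random with replacement. In Spark, `DataFrame.sample(withReplacement=True, fraction=...)` is one way to draw such duplicates.

```python
import random

def random_oversample(rows, label_key="churn", seed=42):
    """Duplicate minority-class rows (sampling with replacement) until every
    class has as many rows as the largest class."""
    rng = random.Random(seed)
    by_label = {}
    for row in rows:
        by_label.setdefault(row[label_key], []).append(row)
    target = max(len(group) for group in by_label.values())
    balanced = []
    for group in by_label.values():
        balanced.extend(group)  # keep every original row once
        # draw the missing rows at random, with replacement, from this class
        balanced.extend(rng.choices(group, k=target - len(group)))
    rng.shuffle(balanced)
    return balanced
```

With 10 non-churn rows and 2 churn rows, the result has 20 rows, 10 of each class. This mirrors the article's 2 × 527,789 = 1,055,578.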

Next, I profiled the distributions of the level, gender, length, location, method, song, ts, and userAgent features that were candidates for the predictive model.

Looking at the level variable, 79% of users are currently paying for the service while 20.8% are currently using it for free:

Ratio of Paid to Free Users

There is also an imbalance of male to female users, with around 300,000 male users and 2,200 female users:

Ratio of Male to Female Users

Among page actions, the overwhelming majority are users navigating to a song. Smaller numbers of actions are rating songs, going to the home page, adding a song to a playlist, and adding friends:

User Actions by Page

Looking at how users interact with the platform, activity is highest on weekdays, in a roughly normal distribution across the week, and much lower on weekends:

User Activity by Day

Usage of the platform throughout the day shows that usage peaks between 4 PM and 8 PM, then declines until the next morning around 8 AM:

User Activity by Hour
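The weekday and hour-of-day profiles above can be derived from the ts column. The sketch below assumes ts holds Unix epoch timestamps in milliseconds and buckets them in UTC. The article does not state which timezone its charts use, so the hour counts may shift if another timezone was applied.

```python
from collections import Counter
from datetime import datetime, timezone

def activity_by_day_and_hour(timestamps_ms):
    """Count events per weekday name and per hour of day (UTC, assumed)."""
    by_day, by_hour = Counter(), Counter()
    for ts in timestamps_ms:
        moment = datetime.fromtimestamp(ts / 1000, tz=timezone.utc)
        by_day[moment.strftime("%A")] += 1  # e.g. "Monday"
        by_hour[moment.hour] += 1           # 0..23
    return by_day, by_hour
```

For example, 1538352000000 is midnight UTC on Monday, 1 October 2018. It is counted under "Monday" and hour 0.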

Looking at the correlations between variables, no single variable appears to be highly correlated with another. The one exception is the relationship between session ID and session length, although this relationship has no inherent meaning:

Correlations Between User Features
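Pairwise correlations like the ones summarized above are commonly Pearson coefficients. A small plain-Python version is below. In PySpark, `df.stat.corr("col1", "col2")` returns the Pearson correlation of two numeric columns. Correlations computed on indexed categorical columns depend on the arbitrary order in which the categories were numbered.

```python
import math
from itertools import combinations

def pearson(xs, ys):
    """Pearson correlation of two equal-length numeric sequences.
    Raises ZeroDivisionError for a constant column (correlation undefined)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

def correlation_pairs(columns):
    """Pearson correlation for every pair of named numeric columns."""
    return {(a, b): pearson(columns[a], columns[b])
            for a, b in combinations(sorted(columns), 2)}
```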

I separately extracted the state from the location variable.
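One way to extract the state is to take the text after the last comma. This assumes locations are formatted as "City, ST", with multi-state metro areas written like "NY-NJ-PA". The example strings and the choice to keep only the first state of a multi-state area are illustrative assumptions. In PySpark, `regexp_extract` could apply a similar pattern to the location column.

```python
def extract_state(location):
    """Return the (first) state code from a 'City, ST' style location string."""
    if not location or "," not in location:
        return None  # missing or unexpected format
    state_part = location.rsplit(",", 1)[1].strip()  # text after the last comma
    return state_part.split("-")[0]  # assumption: keep the first state of e.g. "NY-NJ-PA"
```

For example, "Bakersfield, CA" yields "CA". "New York-Newark-Jersey City, NY-NJ-PA" yields "NY".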

Model Implementation and Refinement

Once the dataset was balanced and the exploratory data analysis was done, I created indexes from the churn, gender, level, location, and userAgent fields. This converted their string categories into numeric values so they could be combined into vector columns and fed into a machine learning model.

For the final part of the data preparation, I added columns for the following:

  • the number of times the Save Settings page was accessed
  • the number of songs a user listened to
  • counts of thumbs-up and thumbs-down ratings
  • the number of times a user added a song to a playlist
  • the average number of songs per session

The final columns to be used in the model are: gender_index, level_index, userAgent_index, ts_day, saved_settings, num_songs, thumbs_up, thumbs_down, playlist_added, and songs_per_session.
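The per-user count columns could be built by aggregating the event log by userId. The sketch below is plain Python. The mapping from page names (such as "NextSong" or "Thumbs Up") to feature columns is an assumption, since the article does not list the exact page values. In PySpark, the equivalent would be a `groupBy("userId")` aggregation that combines conditional sums with `countDistinct` on the session column.

```python
from collections import defaultdict

# Assumed Sparkify page names for each count feature (not listed in the article).
PAGE_TO_FEATURE = {
    "Save Settings": "saved_settings",
    "NextSong": "num_songs",
    "Thumbs Up": "thumbs_up",
    "Thumbs Down": "thumbs_down",
    "Add to Playlist": "playlist_added",
}

def user_features(events):
    """Aggregate event records into one dict of count features per userId."""
    counts = defaultdict(lambda: dict.fromkeys(PAGE_TO_FEATURE.values(), 0))
    sessions = defaultdict(set)
    for event in events:
        user = event["userId"]
        sessions[user].add(event["sessionID"])
        feature = PAGE_TO_FEATURE.get(event["page"])
        if feature is not None:
            counts[user][feature] += 1
    features = {}
    for user, user_sessions in sessions.items():
        row = dict(counts[user])  # users with no counted pages get all zeros
        # average songs per session, over every session the user had
        row["songs_per_session"] = row["num_songs"] / len(user_sessions)
        features[user] = row
    return features
```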

With the data prepared, I took three steps. A StringIndexer converted the ‘churn’ column into a column of label indices. A VectorAssembler combined the feature columns into a single vector column. Finally, a Normalizer rescaled each resulting feature vector to unit norm.
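The three preprocessing stages can be made concrete with a plain-Python imitation of what each does to a small example. This is not Spark's implementation:

  • string_index mimics StringIndexer's default ordering (the most frequent value gets index 0.0, with ties broken alphabetically).
  • assemble mimics VectorAssembler.
  • l2_normalize mimics Normalizer with its default p = 2.

Normalizer rescales each row independently. Standardizing each column to zero mean and unit variance would instead be the job of StandardScaler.

```python
import math
from collections import Counter

def string_index(values):
    """Map strings to float indices, most frequent first, ties alphabetical."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    mapping = {v: float(i) for i, v in enumerate(ordered)}
    return [mapping[v] for v in values], mapping

def assemble(*columns):
    """Zip several numeric columns into one feature vector per row."""
    return [list(row) for row in zip(*columns)]

def l2_normalize(vector):
    """Scale one vector to unit Euclidean length; all-zero vectors are left as-is."""
    norm = math.sqrt(sum(x * x for x in vector))
    return vector if norm == 0 else [x / norm for x in vector]
```

For example, `string_index(["M", "F", "M"])` maps "M" to 0.0 and "F" to 1.0, and `l2_normalize([3.0, 4.0])` gives `[0.6, 0.8]`.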

The dataset was then split into an 80% training partition and a 20% testing partition. A multiclass evaluator was also created to score the models' predictions on the test data.
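A simple shuffle-and-slice version of the 80/20 split is sketched below; the function name and seed are illustrative. Spark's `randomSplit([0.8, 0.2], seed)` instead assigns rows probabilistically, so its partition sizes are only approximately 80/20. Oversampling duplicates churn records, so a random split can place copies of the same record in both partitions. Splitting before oversampling, and oversampling only the training partition, avoids this.

```python
import random

def train_test_split(rows, train_fraction=0.8, seed=42):
    """Shuffle a copy of the rows and cut it into training and test partitions."""
    rng = random.Random(seed)
    shuffled = list(rows)  # leave the caller's list untouched
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]
```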

Four modeling techniques were then implemented to predict churn: a Logistic Regression model, a Random Forest Classifier, a Gradient Boost Tree Classifier, and a Naive Bayes Classifier:

  • The Logistic Regression model combines the input features in a weighted sum and passes the result through the logistic (sigmoid) function. This gives the probability, on a scale from 0 to 1, that a user churns. Users whose probability exceeds a threshold (0.5 by default) are predicted to churn.
  • The Random Forest Classifier trains many decision trees, each on a random bootstrap sample of the records and a random subset of the features. It then combines the trees' predictions to decide whether a user churns.
  • The Gradient Boost Tree Classifier also uses decision trees, but builds them sequentially. Each new, shallow tree is fitted to the errors (the gradient of the loss) left by the trees before it, and the final prediction is a weighted sum of all the trees' outputs.
  • The Naive Bayes Classifier applies Bayes' theorem. It combines the prior frequency of each label with the likelihood of the observed feature values given that label, assuming the features are independent of one another given the label.
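To illustrate the logistic regression description above, the sketch below turns a feature vector into a churn probability and a 0/1 prediction. The weights and intercept are made-up values for demonstration; a trained model would learn them from the training data.

```python
import math

def sigmoid(z):
    """Logistic function: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def predict_churn(features, weights, intercept, threshold=0.5):
    """Return (probability of churn, predicted label) for one feature vector."""
    z = intercept + sum(w * x for w, x in zip(weights, features))
    probability = sigmoid(z)
    return probability, int(probability > threshold)

# Hypothetical weights: z = -1 + 1*1 + 2*1 = 2, so p is about 0.88 and the label is 1.
probability, label = predict_churn([1.0, 1.0], weights=[1.0, 2.0], intercept=-1.0)
```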

Model Evaluation

Two functions were created to evaluate the results. An evaluation function outputs the F1 score, weighted precision, weighted recall, and accuracy. A confusion-matrix function outputs the true positive, true negative, false positive, and false negative counts, along with precision and recall.

The results of these are listed below:

Descriptive Statistics Comparing Model Performance
Confusion Matrix from the Gradient Boost Tree

Justification

After training the models and then scoring them on the test data, the Gradient Boost Tree (GBT) model clearly performed best. GBT scored highest on every evaluation metric and is among the most powerful tree-based ‘ensemble’ techniques for classification problems. GBT correctly identified 47,332 cases of users churning and correctly rejected 71,021 cases of users not churning.

The GBT Classifier accurately predicted around 80% of user churns, and it could be fine-tuned using k-fold cross-validation to reach even higher precision. This was a binary classification problem with likely non-linear relationships between the features and the label. An optimized GBT approach is a strong candidate for such problems: tree ensembles capture non-linear relationships and are relatively robust to outliers in the input features. Tuning settings such as tree depth and the number of iterations helps control the risk of overfitting.

Reflection

Starting from a subset of the original Sparkify data, I balanced the labels, performed exploratory data analysis, and then created, tuned, and evaluated the models. The result was a basic approach to predicting user churn with 80% accuracy. Fine-tuning this model would be necessary to find the parameters that increase its overall accuracy. Alternatively, a dimensionality-reduction technique such as Principal Component Analysis could be used to compress the features into a smaller set of uncorrelated components before modeling.

The choice between a ‘model tuning’ approach like the one used in this project and a ‘feature engineering’ approach depends on the available data and the problem at hand. As the ‘No Free Lunch’ theorem suggests, no single algorithm or combination of algorithms performs best on every problem.

Model Improvement

This model could be improved by using different sampling techniques to balance the original dataset. Options include random oversampling or undersampling, synthetic oversampling or undersampling, or a combination of the two approaches. More about sampling strategies can be read here.

During exploratory data analysis, some missing values were found in certain columns, and the records with missing values were removed from the dataset. This may have lowered the accuracy of the models, because the removed records may have had a meaningful relationship with churn. Instead of removing these records, I could have imputed the missing values. For example, each missing value could be filled with a value drawn from the observed values of the same column, sampling with or without replacement and optionally stratified by another variable. This might have improved the accuracy of the model.

Alternatively, Principal Component Analysis could have been used to transform the input features. PCA uses the eigenvalues and eigenvectors of the features' covariance matrix to find the directions (principal components) that capture the most variance in the inputs. The leading components could then have been fed into the original models. However, PCA does not look at the churn label, so the highest-variance components are not guaranteed to be the most predictive.

Larger features could also have been broken up into smaller ones. For example, regular expressions could split the ‘location’ column into separate city and state values, and the ‘userAgent’ column could be broken into separate columns for the user's browser and operating system.
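Here is a rough sketch of splitting userAgent into browser and operating system with a regular expression and a few substring checks. The token list, and the rule of taking the first item inside the parentheses as the operating system, are simplifying assumptions. A dedicated user-agent parsing library would handle many more cases.

```python
import re

# Checked in order: Edge UAs also mention Chrome, and Chrome UAs also mention Safari.
BROWSER_TOKENS = [("Edge", "Edge"), ("Firefox", "Firefox"), ("Chrome", "Chrome"),
                  ("Safari", "Safari"), ("Trident", "Internet Explorer"),
                  ("MSIE", "Internet Explorer")]

def split_user_agent(user_agent):
    """Very rough split of a User-Agent string into (browser, operating_system)."""
    platform = re.search(r"\(([^)]*)\)", user_agent)  # first parenthesized group
    os_name = platform.group(1).split(";")[0].strip() if platform else "Unknown"
    for token, browser in BROWSER_TOKENS:
        if token in user_agent:
            return browser, os_name
    return "Other", os_name
```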

Finally, the Gradient Boost Tree model's hyperparameters could have been tuned with a parameter-grid search over the model's pipeline, using cross-validation. This would select the combination of settings that minimizes the loss, or maximizes the evaluation score, on held-out data.
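The tuning idea can be sketched in plain Python as an exhaustive search over a parameter grid, scored with k-fold cross-validation. score_fn is a hypothetical stand-in for "train on the training indices, evaluate on the validation indices". maxDepth and maxIter are real GBTClassifier parameter names, used here only as grid keys. In PySpark, ParamGridBuilder and CrossValidator from pyspark.ml.tuning provide this workflow.

```python
import itertools

def k_fold_indices(n_rows, k):
    """Yield (train_indices, validation_indices) for k contiguous folds."""
    # spread any remainder over the first folds so sizes differ by at most one
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        validation = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_rows))
        yield train, validation
        start += size

def grid_search_cv(n_rows, param_grid, score_fn, k=3):
    """Return the parameter combination with the best mean validation score."""
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        scores = [score_fn(params, tr, va) for tr, va in k_fold_indices(n_rows, k)]
        mean_score = sum(scores) / len(scores)
        if mean_score > best_score:  # strict ">" keeps the first of tied combinations
            best_params, best_score = params, mean_score
    return best_params, best_score

# Hypothetical scorer: pretends a tree depth of 5 is ideal and ignores the folds.
def toy_score(params, train, validation):
    return -abs(params["maxDepth"] - 5)

best, score = grid_search_cv(30, {"maxDepth": [3, 5, 7], "maxIter": [10, 20]}, toy_score)
```

With this toy scorer, the search returns `{"maxDepth": 5, "maxIter": 10}`, which is the first grid point to reach the best score.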
