How to Create a Web Application Firewall Using Machine Learning - Part II

This article is part of a series. Part I can be found here. The full code of this tutorial is available on GitHub.

In the previous article we walked through the process of creating a simple predictive model for detecting malicious HTTP requests. We learned how to:
  • load and explore our data
  • split the data into training, development and test sets
  • use scikit-learn's CountVectorizer and SGDClassifier to generate simple predictions
  • evaluate our model and visualise its performance with a precision-recall curve
In this part we will show how to improve the effectiveness of our model using the excellent XGBoost library. We will also introduce scikit-learn's pipelines and learn how to calculate the importance of features.

Single Evaluation Metric

In Part I we created a simple model with 41% precision and 38% recall. Its performance was also visualised in the following plot with a precision-recall curve:

Now we would like to try different approaches and see whether they improve the effectiveness of the model. To evaluate our models we will use the area under the precision-recall curve (in the scikit-learn library this summary of the curve is called average precision).
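To make the metric concrete, here is a minimal pure-Python sketch of the step-wise sum that scikit-learn documents for average precision: requests are ranked by score, each distinct score acts as a threshold, and AP adds up precision weighted by the gain in recall, AP = Σ (R_n - R_(n-1)) · P_n. The function name average_precision and the toy labels are our own; in practice you would call sklearn.metrics.average_precision_score instead.

```python
def average_precision(y_true, y_score):
    """Step-wise area under the precision-recall curve.

    Samples are ranked by score (highest first). Each distinct score acts as a
    threshold, and precision is weighted by how much recall grows there.
    """
    n_pos = sum(y_true)
    if n_pos == 0:
        raise ValueError("average precision is undefined without positive samples")
    ranked = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(ranked):
        threshold = ranked[i][0]
        # Samples with tied scores cross the threshold together.
        while i < len(ranked) and ranked[i][0] == threshold:
            if ranked[i][1]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


# Toy example: two malicious (1) and two benign (0) requests.
print(average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # ≈ 0.833
```

Because the sum is step-wise rather than trapezoidal, it does not interpolate between points of the curve; scikit-learn's documentation notes that linear interpolation can be overly optimistic.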

Why shouldn't we just use the values of precision and recall for evaluation? Optimising two metrics at the same time often causes problems and slows down our progress. Let's say we trained a new classifier that obtained 44% precision and 33% recall on our development set. These scores do not immediately tell us whether this model is better than our previous one with 41% precision and 38% recall. Choosing a single evaluation metric (which in this case combines precision and recall) lets us decide instantly which model performs best. As a result, we will be able to iterate on our ideas faster.
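To illustrate why a single number helps, here is the F1 score (the harmonic mean of precision and recall) for the two models described above. This is only an illustration of the idea: F1 is computed at one threshold, whereas the metric this article actually uses, average precision, summarises all thresholds.

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


previous_model = f1(0.41, 0.38)  # earlier model: 41% precision, 38% recall
new_model = f1(0.44, 0.33)       # new model: 44% precision, 33% recall
print(f"previous: {previous_model:.3f}, new: {new_model:.3f}")
# previous: 0.394, new: 0.377
```

Once both models are reduced to one number, the comparison is immediate: by F1, the earlier model is slightly better.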

A higher area under the precision-recall curve corresponds to better performance of the algorithm (which is easy to see from the plot). This metric is well suited to problems with highly imbalanced datasets, like ours. Let's check the value of average precision for our simple model. This score will serve as a baseline for our further experiments.
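A minimal sketch of that check with scikit-learn's average_precision_score. The labels and scores are hypothetical stand-ins for the development set and the baseline's output (for example SGDClassifier.decision_function(); average precision depends only on how the scores rank the requests). The second call hints at why the metric suits imbalanced data: a model that gives every request the same score gets an AP equal to the share of malicious requests, so the no-skill floor stays low.

```python
from sklearn.metrics import average_precision_score

# Hypothetical development-set labels (1 = malicious) and baseline scores.
y_dev = [0, 0, 0, 1, 0, 0, 0, 0, 1, 0]
baseline_scores = [0.1, 0.3, 0.2, 0.8, 0.7, 0.05, 0.4, 0.15, 0.35, 0.25]

baseline_ap = average_precision_score(y_dev, baseline_scores)

# Constant scores: a single threshold with recall 1 and precision equal
# to the fraction of malicious requests (here 2 out of 10).
no_skill_ap = average_precision_score(y_dev, [0.5] * len(y_dev))

print(f"baseline AP: {baseline_ap:.2f}, no-skill AP: {no_skill_ap:.2f}")
# baseline AP: 0.75, no-skill AP: 0.20
```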

XGBoost and Pipelines

OK, we can now focus on improving our model. We will test XGBoost, an algorithm that has recently gained huge popularity thanks to many wins in competitions organised on Kaggle. XGBoost is a very efficient implementation of gradient boosting (you can find out more about XGBoost and gradient boosting here).

This time we will also use scikit-learn's Pipeline class. Last time, to create a model we had to sequentially call methods such as fit(), transform() or predict() on both CountVectorizer and SGDClassifier. A Pipeline lets us run many steps of our process with a single call. We need to call fit() only once, and the Pipeline will internally fit all of our transformers and the final estimator one after another. With a pipeline we can also make predictions with a single call to predict() or predict_proba(). With only one transformer and one classifier the benefit of a pipeline is small, but it will become very useful when we add new steps to our process. Pipelines are also very useful for hyperparameter tuning, since we can tune the parameters of all our transformers and the estimator at the same time.

Let's take a look at our pipeline:
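A sketch of such a pipeline, assuming the character-count features from Part I can be expressed as CountVectorizer(analyzer="char"). The step names "vectorizer" and "classifier" and the hyperparameter values are our own choices. If the xgboost package is not installed, the sketch falls back to scikit-learn's GradientBoostingClassifier, which is a different gradient boosting implementation. XGBClassifier follows the scikit-learn estimator API, which is what allows it to be the final step of a Pipeline.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

try:
    from xgboost import XGBClassifier as BoostedTrees
except ImportError:  # fallback (a different boosting implementation) so the sketch runs
    from sklearn.ensemble import GradientBoostingClassifier as BoostedTrees

pipeline = Pipeline([
    # Count occurrences of each individual character in the URI.
    ("vectorizer", CountVectorizer(analyzer="char")),
    ("classifier", BoostedTrees(n_estimators=100, max_depth=3)),
])

# Parameters of every step are addressed as <step name>__<parameter>,
# e.g. keep the original letter case and use deeper trees.
pipeline.set_params(vectorizer__lowercase=False, classifier__max_depth=4)
print(pipeline.get_params()["classifier__max_depth"])  # 4
```

The set_params() call illustrates the tuning point made above: because each parameter is addressed through its step name, a single grid search can vary the vectorizer and the classifier together.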
Now we can simply call fit() on the training set and look at the predictions on the development set:
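A self-contained sketch of the fit/predict step. The URIs are hypothetical toy data in place of the real dataset, the step names are our own, and the classifier falls back to scikit-learn's GradientBoostingClassifier if xgboost is not installed. Because the model is trained on toy data, the printed probabilities say nothing about the article's results.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

try:
    from xgboost import XGBClassifier as BoostedTrees
except ImportError:  # fallback (a different boosting implementation) so the sketch runs
    from sklearn.ensemble import GradientBoostingClassifier as BoostedTrees

# Hypothetical toy URIs (1 = malicious) standing in for the real dataset.
benign = ([f"/products/item?id={i}" for i in range(15)]
          + [f"/blog/post-{i}.html" for i in range(15)])
malicious = (["/" + "../" * k + "etc/passwd" for k in range(1, 16)]
             + [f"/search?q={i}' OR 1=1;--" for i in range(15)])
X_train = benign + malicious
y_train = np.array([0] * len(benign) + [1] * len(malicious))
X_dev = ["/blog/post-99.html", "/../../../../etc/shadow", "/search?q=x' OR 2=2;--"]

pipeline = Pipeline([
    ("vectorizer", CountVectorizer(analyzer="char")),  # per-character counts
    ("classifier", BoostedTrees(n_estimators=50, max_depth=3)),
])

# A single fit() fits the vectorizer on X_train, transforms it,
# and then fits the classifier on the resulting count matrix.
pipeline.fit(X_train, y_train)

# predict() and predict_proba() pass the raw dev URIs through both steps.
y_dev_pred = pipeline.predict(X_dev)
y_dev_proba = pipeline.predict_proba(X_dev)[:, 1]  # probability of "malicious"
print(y_dev_pred, y_dev_proba.round(3))
```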
Let's also plot the precision-recall curve and calculate the metrics for our new XGBoost model:
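A sketch of the metric calculations with scikit-learn on hypothetical labels and probabilities. In the real code the probabilities would come from something like pipeline.predict_proba(X_dev)[:, 1], where X_dev is an assumed name. Thresholding at 0.5 mirrors the estimator's default decision rule.

```python
from sklearn.metrics import (average_precision_score, precision_recall_curve,
                             precision_score, recall_score)

# Hypothetical dev labels (1 = malicious) and predicted probabilities.
y_dev = [0, 1, 0, 1, 1, 0]
y_dev_proba = [0.2, 0.9, 0.6, 0.7, 0.3, 0.1]

# Precision and recall at the default 0.5 threshold.
y_dev_pred = [int(p > 0.5) for p in y_dev_proba]
precision = precision_score(y_dev, y_dev_pred)
recall = recall_score(y_dev, y_dev_pred)

# Points of the precision-recall curve and its step-wise area.
curve_precision, curve_recall, thresholds = precision_recall_curve(y_dev, y_dev_proba)
ap = average_precision_score(y_dev, y_dev_proba)

print(f"precision={precision:.2f} recall={recall:.2f} AP={ap:.2f}")
# precision=0.67 recall=0.67 AP=0.92
```

To draw the curve, the returned arrays can be passed to matplotlib, for example plt.plot(curve_recall, curve_precision).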

We can see that using XGBoost instead of SGDClassifier substantially improved the model's effectiveness in detecting malicious requests. The area under the precision-recall curve (average precision) increased from 34% to 73%. We can also achieve 97% precision and 37% recall if we stick to the estimator's default threshold (which means classifying a request as malicious if its predicted probability is above 50%). These results seem pretty decent for such simple features. Remember that we are passing the classifier only information about the number of occurrences of individual characters in the URI.
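The 97% precision and 37% recall hold only for the default cut-off. As a hedged sketch of choosing another operating point, the hypothetical helper threshold_for_precision() walks the precision-recall curve and returns the lowest threshold that still reaches a target precision; since recall can only fall as the threshold rises, that threshold gives the highest recall under the constraint. The toy numbers are made up.

```python
from sklearn.metrics import precision_recall_curve


def threshold_for_precision(y_true, y_proba, target_precision):
    """Return (threshold, precision, recall) with the highest recall whose
    precision is at least target_precision, or None if none qualifies."""
    precision, recall, thresholds = precision_recall_curve(y_true, y_proba)
    # precision/recall have one more entry than thresholds; the final pair
    # (precision=1, recall=0) has no threshold, so it is skipped.
    for p, r, t in zip(precision[:-1], recall[:-1], thresholds):
        # Thresholds increase, so the first match has the highest recall.
        if p >= target_precision:
            return t, p, r
    return None


y_dev = [0, 1, 0, 1, 1, 0]
y_dev_proba = [0.2, 0.9, 0.6, 0.7, 0.3, 0.1]
print(threshold_for_precision(y_dev, y_dev_proba, 0.9))
print(threshold_for_precision(y_dev, y_dev_proba, 0.75))
```

With the toy data, asking for 90% precision returns threshold 0.7 (precision 1.0, recall 0.67), while asking for 75% returns threshold 0.3 (precision 0.75, recall 1.0). Requests scoring at or above the returned threshold are treated as malicious.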

Feature Importance

Another benefit of XGBoost is that it can automatically calculate the importance of features. From the trained model we can extract scores that show how valuable each feature was during training. These scores can give us an intuitive understanding of how our model makes its decisions. Feature importance is also helpful during feature engineering: it tells us whether our new features are contributing to the model. It can also help us invent new features, for example by combining some of the features marked as important.

We can access feature importance scores through the feature_importances_ attribute of the XGBoost model. We would like to print the top 10 features, so we will define a simple function get_top_k_indices() that takes a list and returns the indices of the elements with the highest scores. The feature_importances_ attribute does not contain the names of our features, so we will need to take them from the vocabulary_ attribute of CountVectorizer. Let's run the code and see which features are considered the most important:
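A minimal version of this step. Our get_top_k_indices() is a sketch and may differ from the one in the GitHub repository: it sorts with numpy.argsort and keeps the k largest values, largest first. The URIs are hypothetical toy data, the step names are our own, the classifier falls back to scikit-learn's GradientBoostingClassifier if xgboost is unavailable, and the printed scores will not match the article's.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import Pipeline

try:
    from xgboost import XGBClassifier as BoostedTrees
except ImportError:  # fallback (a different boosting implementation) so the sketch runs
    from sklearn.ensemble import GradientBoostingClassifier as BoostedTrees


def get_top_k_indices(values, k):
    """Indices of the k largest values, largest first."""
    return np.argsort(values)[::-1][:k]


# Hypothetical toy URIs (1 = malicious).
benign = ([f"/products/item?id={i}" for i in range(15)]
          + [f"/blog/post-{i}.html" for i in range(15)])
malicious = (["/" + "../" * k + "etc/passwd" for k in range(1, 16)]
             + [f"/search?q={i}' OR 1=1;--" for i in range(15)])

pipeline = Pipeline([
    ("vectorizer", CountVectorizer(analyzer="char")),
    ("classifier", BoostedTrees(n_estimators=50, max_depth=3)),
])
pipeline.fit(benign + malicious, np.array([0] * len(benign) + [1] * len(malicious)))

vectorizer = pipeline.named_steps["vectorizer"]
importances = pipeline.named_steps["classifier"].feature_importances_

# vocabulary_ maps character -> column index; invert it to get index -> character.
index_to_feature = {index: feature for feature, index in vectorizer.vocabulary_.items()}

top_indices = get_top_k_indices(importances, 10)
for index in top_indices:
    print(repr(index_to_feature[int(index)]), round(float(importances[index]), 3))
```

In scikit-learn 1.0 and later, vectorizer.get_feature_names_out() returns the feature names already ordered by column index, which makes the dictionary inversion optional.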

The importance scores of our model look reasonable. The two most important features are the counts of the symbols . and /, which are characteristic of path traversal attacks (these often contain URL fragments like ../../../../../). Symbols such as ), =, *, & and ; can also occur many times, for example in SQL injection attacks. Since such features were useful for the model, we can speculate that features like the frequencies of .., ../ or SQL keywords such as SELECT or DELETE could make our model even better.
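The pattern features speculated about here could be prototyped as follows. This is a hypothetical sketch, not necessarily the approach of the next part: pattern_counts() counts non-overlapping, case-insensitive occurrences of a hand-picked list of strings.

```python
# Hypothetical extra features: multi-character patterns that
# single-character counts cannot capture.
PATTERNS = ["..", "../", "SELECT", "DELETE"]


def pattern_counts(uri):
    """Count non-overlapping, case-insensitive occurrences of each pattern."""
    upper = uri.upper()
    return {pattern: upper.count(pattern.upper()) for pattern in PATTERNS}


print(pattern_counts("/../../etc/passwd"))
# {'..': 2, '../': 2, 'SELECT': 0, 'DELETE': 0}
print(pattern_counts("/item?id=1 union select password from users"))
# {'..': 0, '../': 0, 'SELECT': 1, 'DELETE': 0}
```

A more general alternative is CountVectorizer(analyzer="char", ngram_range=(1, 3)), which turns every one-, two- and three-character sequence, including .. and ../, into a feature.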

In the next part of this tutorial we will try to expand our feature set to improve the effectiveness of our model.

Andrzej Prałat
Machine Learning Engineer
Grey Wizard - Data Science Team
