WebSHAP
A JavaScript library that explains any machine learning model in your browser!
What is WebSHAP?
WebSHAP is a JavaScript library that adapts Kernel SHAP to the Web environment. You can use it to explain any machine learning model available on the Web directly in your browser. Given a model's prediction on a data point, WebSHAP computes an importance score for each input feature. WebSHAP leverages modern Web technologies such as WebGL to accelerate computation. With a moderate model size and number of input features, WebSHAP can generate explanations in real time. ✨
Getting Started
Installation
WebSHAP supports both browser and Node.js environments. To install WebSHAP, you can use npm:
npm install webshap
Explain Machine Learning Models
WebSHAP uses the Kernel SHAP algorithm to interpret machine learning (ML) models. This algorithm uses a game theoretic approach to approximate the importance of each input feature. You can learn more about Kernel SHAP from the original paper or this nice tutorial.
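For reference, Kernel SHAP fits an additive surrogate model over feature "presence" indicators and recovers the importance scores as its coefficients (notation follows the original paper):

$$
g(z') = \phi_0 + \sum_{i=1}^{M} \phi_i z'_i, \qquad
\pi_{x}(z') = \frac{M-1}{\binom{M}{|z'|}\, |z'|\, (M-|z'|)}
$$

Here $M$ is the number of features, $z' \in \{0,1\}^M$ marks which features are present, $|z'|$ is the number of present features, and the coefficients $\phi_i$ (the SHAP values) are obtained by weighted least squares with weights $\pi_x(z')$.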
To run WebSHAP on your model, you need to prepare the following three arguments.
|Name|Description|Type|Details|
|:---|:---|:---|:---|
|ML Model|A function that transforms input data into predicted probabilities|`(x: number[][]) => Promise<number[]>`|This function wraps your ML model inference code. WebSHAP is model-agnostic, so any model can be used (e.g., random forest, CNNs, transformers).|
|Data Point|The input data for one prediction|`number[][]`|WebSHAP generates local explanations by computing the feature importance for individual predictions.|
|Background Data|A 2D array that represents feature "missingness"|`number[][]`|WebSHAP approximates the contribution of a feature by comparing it to its missing value (also known as the base value). Using all zeros is the simplest option, but using the median or a subset of your data can improve accuracy.|
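For a concrete sense of what these three arguments can look like, here is a minimal sketch. The model below is a hypothetical two-feature logistic scorer; all names are placeholders, not part of the WebSHAP API.

```typescript
// Hypothetical model wrapper: WebSHAP only needs an async function that
// maps a batch of inputs (number[][]) to one predicted probability per row.
const myModel = async (x: number[][]): Promise<number[]> => {
  // Replace this toy logistic scorer with your real inference code
  // (e.g., an ONNX Runtime or TensorFlow.js call).
  return x.map(row => 1 / (1 + Math.exp(-(0.8 * row[0] - 0.5 * row[1]))));
};

// Data point to explain: one row with two features.
const x = [[0.3, 0.6]];

// Background data representing "missing" features; all zeros is the
// simplest choice, but medians or a sample of your data can work better.
const backgroundData = [[0, 0]];
```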
Then, you can generate explanations with WebSHAP in two steps:
```typescript
// Import the class KernelSHAP from the webshap module
import { KernelSHAP } from 'webshap';

// Create an explainer object with the model wrapper and background data
const explainer = new KernelSHAP(
  (x: number[][]) => myModel(x), // ML model function wrapper
  backgroundData,                // Background data
  0.2022                         // Random seed
);

// Explain one prediction
let shapValues = await explainer.explainOneInstance(x);

// By default, WebSHAP automatically chooses the number of feature
// permutations. You can also pass it as an argument here.
const nSamples = 512;
shapValues = await explainer.explainOneInstance(x, nSamples);

// Finally, `shapValues` contains the importance score for each feature in `x`
console.log(shapValues);
```
See the WebSHAP Documentation for more details.
Application Examples
Demo 1: Explaining XGBoost
🔎 WebSHAP explaining an XGBoost-based loan approval model 💰
We present Loan Explainer as an example of applying WebSHAP to explain a financial ML model in browsers. For a live demo of Loan Explainer, visit this webpage.
This example showcases a bank using an XGBoost classifier on the LendingClub dataset to predict if a loan applicant will be able to repay the loan on time. With this model, the bank can make automatic loan approval decisions. It's important to understand how these high-stakes decisions are being made, and that's where WebSHAP comes in. It provides private, ubiquitous, and interactive ML explanations.
This demo runs entirely on the client side, making it accessible from desktops, tablets, and phones. The model inference is powered by ONNX Runtime. The UI is implemented using Svelte. With Loan Explainer, users can experiment with different feature inputs and instantly see the model's predictions, along with clear explanations for those predictions.
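As an illustration, the model function passed to WebSHAP could wrap an ONNX Runtime Web session roughly as follows. The model path, tensor names, and feature count here are assumptions for the sketch, not the demo's actual code.

```typescript
import * as ort from 'onnxruntime-web';
import { KernelSHAP } from 'webshap';

// Assumed model file; adjust to your exported XGBoost/ONNX model.
const session = await ort.InferenceSession.create('./loan-model.onnx');

const loanModel = async (x: number[][]): Promise<number[]> => {
  const numFeatures = x[0].length;
  // Flatten the batch into a float32 tensor of shape [rows, features].
  const flat = Float32Array.from(x.flat());
  const inputTensor = new ort.Tensor('float32', flat, [x.length, numFeatures]);
  // 'input' and 'probabilities' are assumed tensor names for this sketch.
  const output = await session.run({ input: inputTensor });
  return Array.from(output['probabilities'].data as Float32Array);
};

// All-zero background for an assumed 10-feature loan model.
const backgroundData = [new Array(10).fill(0)];
const explainer = new KernelSHAP(loanModel, backgroundData, 0.2022);
```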
Demo 2: Explaining Convolutional Neural Networks
🔎 WebSHAP explaining a convolutional neural network for image classification 🌠
We apply WebSHAP to explain convolutional neural networks (CNNs) in browsers. The live demo of this explainer is available on this webpage.
In this example, we first train a TinyVGG model to classify images into four categories: 🐞 Ladybug, ☕️ Espresso, 🍊 Orange, and 🚙 Sports Car. TinyVGG is a type of convolutional neural network. For more details about the model architecture, check out CNN Explainer. TinyVGG is implemented using TensorFlow.js.
To explain the predictions of TinyVGG, we first apply image segmentation (SLIC) to divide the input image into multiple segments. Then, we compute SHAP scores on each segment for each class. The background data here are white pixels. We compute SHAP values for segments instead of raw pixels for computational efficiency. For example, in the figure above, there are only 16 input features (16 segments) for WebSHAP, but there would have been $64 \times 64 \times 3 = 12288$ input features if we used raw pixels. Finally, we visualize the SHAP scores of each segment as an overlay with a diverging color scale on top of the original input image.
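Conceptually, the model function WebSHAP sees takes a binary vector over the 16 segments, rebuilds an image in which "missing" segments are filled with the white background, and runs the CNN on it. A simplified sketch follows; the segment map, image layout, and classifier call are assumptions, not the demo's actual code.

```typescript
// Assumed inputs from the SLIC segmentation and the TinyVGG model:
declare const image: number[];          // flattened 64x64x3 pixel values in [0, 1]
declare const segmentMap: number[];     // segment id (0-15) for each of the 64x64 pixels
declare const targetClassIndex: number; // class whose probability we explain
declare function classifyImage(pixels: number[]): Promise<number[]>; // TinyVGG forward pass

// Model wrapper for WebSHAP: each row of z marks which segments are present.
const segmentModel = async (z: number[][]): Promise<number[]> => {
  const scores: number[] = [];
  for (const row of z) {
    // Keep pixels whose segment is "present" (1); replace pixels of
    // "missing" segments with the white background value (1.0).
    const masked = image.map((value, p) =>
      row[segmentMap[Math.floor(p / 3)]] === 1 ? value : 1.0
    );
    const probs = await classifyImage(masked);
    scores.push(probs[targetClassIndex]);
  }
  return scores;
};
```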
Everything in this example (TinyVGG, image segmenter, WebSHAP) runs in the user's browser. In addition, WebSHAP enables interactive explanation: users can click a button to use a random input image or upload their own images. Both model inference and SHAP computation are real-time.
Demo 3: Explaining Transformer-based Text Classifiers
🔎 WebSHAP explaining a transformer model for text classification 🔤
We use WebSHAP to explain the predictions of a Transformer text classifier in browsers. The live demo for this explainer is accessible on this webpage.
We train an XtremeDistil model to predict if an input text is toxic. XtremeDistil is a distilled version of the pre-trained transformer-based language model BERT. We train this model on the Toxic Comments dataset. Then, we quantize the trained model and export it to use int8 weights with ONNX. We use TensorFlow.js for tokenization and ONNX Runtime for model inference.
To explain the model's predictions, we compute SHAP scores for each input token. For background data, we use BERT's attention mechanism to mask tokens: we represent a "missing" token by setting its attention map to 0, which tells the model to ignore that token. Finally, we visualize the SHAP scores as each token's background color with a diverging color scale.
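In code, this amounts to a model function whose inputs are binary token indicators: tokens marked 0 get their attention-mask entry set to 0 before the transformer runs. A simplified sketch follows; the inference helper and its signature are assumptions for illustration.

```typescript
// Assumed, simplified interfaces for the tokenized input and the
// quantized XtremeDistil model running under ONNX Runtime:
declare const inputIds: number[]; // token ids of the sentence being explained
declare function runXtremeDistil(
  ids: number[],
  attentionMask: number[]
): Promise<number>; // returns the toxicity probability

// Model wrapper for WebSHAP: each row of z marks which tokens are present.
const tokenModel = async (z: number[][]): Promise<number[]> => {
  const outputs: number[] = [];
  for (const row of z) {
    // A "missing" token keeps its id but gets attention mask 0,
    // so the transformer ignores it.
    const attentionMask = row.map(present => (present === 1 ? 1 : 0));
    outputs.push(await runXtremeDistil(inputIds, attentionMask));
  }
  return outputs;
};
```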
All components in this example (XtremeDistil, tokenizer, WebSHAP) run on the client side. WebSHAP provides private, ubiquitous, and interactive explanations. Users can edit the input text and see new predictions and explanations. The model inference is real-time, and the SHAP computation takes about 5 seconds for 50 tokens.
Developing WebSHAP
Clone or download this repository:
git clone [email protected]:poloclub/webshap.git
Install the dependencies:
npm install
Use Vitest for unit testing:
npm run test
Developing the Application Examples
Clone or download this repository:
git clone [email protected]:poloclub/webshap.git
Navigate to the example folder:
cd ./examples/demo
Install the dependencies:
npm install
Then run Loan Explainer:
npm run dev
Navigate to localhost:3000. You should see three Explainers running in your browser :)
Credits
WebSHAP was created by Jay Wang and Polo Chau.
Citation
To learn more about WebSHAP, please read our research paper (published at TheWebConf'23). If you find WebSHAP useful for your research, please consider citing our paper. And if you're building any exciting projects with WebSHAP, we'd love to hear about them!
@inproceedings{wangWebSHAPExplainingAny2023,
title = {{{WebSHAP}}: {{Towards Explaining Any Machine Learning Models Anywhere}}},
shorttitle = {{{WebSHAP}}},
booktitle = {Companion {{Proceedings}} of the {{Web Conference}} 2023},
author = {Wang, Zijie J. and Chau, Duen Horng},
year = {2023},
langid = {english}
}
License
The software is available under the MIT License.
Contact
If you have any questions, feel free to open an issue or contact Jay Wang.