Insurance charges prediction [LGBM]
Giskard is an open-source framework for testing all ML models, from LLMs to tabular models. Don't hesitate to give the project a star on GitHub ⭐️ if you find it useful!
In this notebook, you'll learn how to create comprehensive test suites for your model in a few lines of code, thanks to Giskard's open-source Python library.
Use-case:
Regression to predict the insurance charges based on medical and social data.
Model:
LGBMRegressor
Outline:
Detect vulnerabilities automatically with Giskard's scan
Automatically generate & curate a comprehensive test suite to test your model beyond accuracy-related metrics
Install dependencies
Make sure to install the giskard library.
[ ]:
%pip install giskard --upgrade
We also install the project-specific dependencies for this tutorial.
[ ]:
%pip install lightgbm
Troubleshooting
If you encounter a segmentation fault on macOS at any point during this tutorial, check: https://docs.giskard.ai/en/stable/community/contribution_guidelines/dev-environment.html#fatal-python-error-segmentation-fault-when-running-pytest-on-macos
Import libraries
[1]:
import warnings
from pathlib import Path
from urllib.request import urlretrieve
import pandas as pd
from absl import logging
from lightgbm import LGBMRegressor
from sklearn.compose import ColumnTransformer
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from giskard import Dataset, Model, scan, testing
Notebook-level settings
[2]:
logging.set_verbosity(logging.ERROR)
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
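The filter above treats message as a regular expression that must match the start of the warning text, case-insensitively. Only FutureWarnings whose message begins with "Passing" are silenced, and all other warnings still show. The standard-library sketch below checks that behavior. The warning texts in it are made up for illustration.

```python
import warnings


def surviving_warnings(messages: list[str]) -> list[str]:
    """Emit each message as a FutureWarning and return those not silenced by the notebook's filter."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning by default
        # Same filter as the notebook; it is inserted at the front, so it takes precedence.
        warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
        for msg in messages:
            warnings.warn(msg, FutureWarning)
    return [str(w.message) for w in caught]


print(surviving_warnings(["Passing a set as an indexer is deprecated", "Some other deprecation"]))
```

Because the pattern is anchored at the start of the message, a warning that only contains "Passing" later in its text is not filtered.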
Define constants
[3]:
# Constants.
NUMERICAL_COLS = ["bmi", "age", "children"]
CATEGORICAL_COLS = ["sex", "smoker", "region"]
# Paths.
DATA_URL = "https://giskard-library-test-datasets.s3.eu-north-1.amazonaws.com/insurance_prediction_dataset-us_health_insurance_dataset.csv.tar.gz"
DATA_PATH = Path.home() / ".giskard" / "insurance_prediction_dataset" / "us_health_insurance_dataset.csv.tar.gz"
Dataset preparation
Load data
[4]:
def fetch_demo_data(url: str, file: Path) -> None:
    """Download `url` to `file` over HTTPS unless the file is already cached locally."""
    # Create the cache directory (no-op if it already exists).
    file.parent.mkdir(parents=True, exist_ok=True)
    if not file.exists():
        print(f"Downloading data from {url}")
        urlretrieve(url, file)
    print("Data was loaded!")


def download_data(**kwargs) -> pd.DataFrame:
    """Download the dataset if needed and load it into a DataFrame."""
    fetch_demo_data(DATA_URL, DATA_PATH)
    _df = pd.read_csv(DATA_PATH, **kwargs)
    return _df
[ ]:
df = download_data()
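DATA_PATH points to a gzip-compressed tar archive. In recent pandas versions (tar support arrived in 1.5), pd.read_csv infers that compression from the .tar.gz suffix. The sketch below shows roughly what that step involves, using only the standard library rather than pandas' own implementation. It builds a small archive in memory and parses the single CSV file inside it. The column values are made up.

```python
import csv
import io
import tarfile


def read_single_csv_from_targz(data: bytes) -> list[dict[str, str]]:
    """Parse the only regular file of a .tar.gz archive (given as bytes) as CSV."""
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:gz") as archive:
        members = [m for m in archive.getmembers() if m.isfile()]
        if len(members) != 1:
            raise ValueError(f"Expected exactly one file in the archive, found {len(members)}")
        raw = archive.extractfile(members[0]).read().decode("utf-8")
    return list(csv.DictReader(io.StringIO(raw)))


def make_targz(name: str, text: str) -> bytes:
    """Build an in-memory .tar.gz containing one text file (demo helper)."""
    payload = text.encode("utf-8")
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
        info = tarfile.TarInfo(name=name)
        info.size = len(payload)
        archive.addfile(info, io.BytesIO(payload))
    return buffer.getvalue()


archive_bytes = make_targz("toy.csv", "age,smoker,charges\n40,no,1000.00\n25,yes,2500.50\n")
rows = read_single_csv_from_targz(archive_bytes)
print(rows)
```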
Train-test split
[6]:
X_train, X_test, y_train, y_test = train_test_split(df.drop(columns=["charges"]), df.charges, random_state=0)
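No test_size or train_size is passed, so train_test_split holds out 25% of the rows and rounds the test count up. Setting random_state=0 makes the shuffle reproducible. The stdlib sketch below shows that sizing rule together with a seeded shuffle split. It uses Python's random module instead of NumPy's generator, so it will not reproduce scikit-learn's exact row assignment.

```python
import math
import random


def split_sizes(n_samples: int, test_size: float = 0.25) -> tuple[int, int]:
    """Return (n_train, n_test), rounding the test count up for a float test_size."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def shuffle_split(rows: list, test_size: float = 0.25, seed: int = 0) -> tuple[list, list]:
    """Shuffle a copy of rows with a fixed seed, then cut it into train and test parts."""
    n_train, _ = split_sizes(len(rows), test_size)
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]


print(split_sizes(101))
```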
Wrap dataset with Giskard
To prepare for the vulnerability scan, make sure to wrap your dataset using Giskard's Dataset class. See the Giskard documentation for more details.
[ ]:
raw_data = pd.concat([X_test, y_test], axis=1)
giskard_dataset = Dataset(
df=raw_data,
# A pandas.DataFrame that contains the raw data (before all the pre-processing steps) and the actual ground truth variable (target).
target="charges", # Ground truth variable.
name="insurance dataset", # Optional.
    cat_columns=CATEGORICAL_COLS,
    # List of categorical columns. Optional, but strongly recommended when known; inferred automatically otherwise.
)
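Dataset expects the raw, unprocessed features together with the target column. The helper below, check_dataset_columns, is hypothetical and not part of Giskard. It sketches a consistency check worth running before wrapping: the target must be present, every declared categorical column must exist, and all columns must have the same length. It works on a plain mapping from column name to values, so it runs without pandas.

```python
def check_dataset_columns(
    columns: dict[str, list], target: str, cat_columns: list[str]
) -> list[str]:
    """Hypothetical pre-wrap check: return a list of problems (empty when the data looks consistent)."""
    problems = []
    if target not in columns:
        problems.append(f"target column {target!r} is missing")
    for name in cat_columns:
        if name not in columns:
            problems.append(f"categorical column {name!r} is missing")
    if len({len(values) for values in columns.values()}) > 1:
        problems.append("columns have different lengths")
    return problems


toy = {"age": [40, 25], "sex": ["female", "male"], "smoker": ["no", "yes"],
       "region": ["northeast", "southeast"], "charges": [1000.0, 2500.5]}
print(check_dataset_columns(toy, "charges", ["sex", "smoker", "region"]))
```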
Model building
Define preprocessing pipeline
[8]:
preprocessor = ColumnTransformer(
transformers=[
("scaler", StandardScaler(), NUMERICAL_COLS),
("one_hot_encoder", OneHotEncoder(handle_unknown="ignore", sparse_output=False), CATEGORICAL_COLS),
]
)
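The ColumnTransformer standardizes the numerical columns: it subtracts the mean and divides by the population standard deviation. It one-hot encodes the categorical columns, using categories sorted from the training data. With handle_unknown="ignore", a category never seen during fitting becomes an all-zero vector instead of raising an error. The pure-Python sketch below reproduces both transformations for intuition only. It is not scikit-learn's implementation.

```python
import statistics


def standardize(values: list[float]) -> list[float]:
    """Center on the mean and divide by the population standard deviation (ddof=0)."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    # Zero-variance column: divide by 1 instead of 0 (StandardScaler also leaves such features unscaled).
    scale = std if std > 0 else 1.0
    return [(v - mean) / scale for v in values]


def fit_one_hot(values: list[str]) -> list[str]:
    """Learn the sorted categories seen during fitting."""
    return sorted(set(values))


def one_hot(value: str, categories: list[str]) -> list[int]:
    """Encode one value; an unknown category maps to all zeros, as with handle_unknown='ignore'."""
    return [1 if value == category else 0 for category in categories]


smoker_categories = fit_one_hot(["yes", "no", "yes"])
print(standardize([1.0, 2.0, 3.0]), one_hot("maybe", smoker_categories))
```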
Build estimator
[ ]:
pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("regressor", LGBMRegressor(n_estimators=30))])
pipeline.fit(X_train, y_train)
y_train_pred = pipeline.predict(X_train)
y_test_pred = pipeline.predict(X_test)
train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)
print(f"Train R2-score: {train_r2:.2f}")
print(f"Test R2-score: {test_r2:.2f}")
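r2_score reports the coefficient of determination, R² = 1 − SS_res / SS_tot. It is 1.0 for perfect predictions, 0.0 for a model that always predicts the mean of the true values, and negative for anything worse. A large gap between the train and test scores printed above is a quick sign of overfitting. The stdlib sketch below implements the formula. Unlike scikit-learn, it simply rejects a constant target rather than special-casing it.

```python
def r2(y_true: list[float], y_pred: list[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) < 2:
        raise ValueError("need two equally long sequences with at least two values")
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean) ** 2 for t in y_true)  # total sum of squares
    if ss_tot == 0:
        raise ValueError("R² is undefined for a constant target in this sketch")
    return 1.0 - ss_res / ss_tot


print(r2([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]))
```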
Wrap model with Giskard
To prepare for the vulnerability scan, make sure to wrap your model using Giskard's Model class. You can choose to either wrap the prediction function (preferred option) or the model object. See the Giskard documentation for more details.
[ ]:
# Wrap the prediction function
def prediction_function(df):
    return pipeline.predict(df)
giskard_model = Model(
model=prediction_function,
# A prediction function that encapsulates all the data pre-processing steps and that could be executed with the dataset used by the scan.
model_type="regression", # Either regression, classification or text_generation.
name="insurance model", # Optional.
feature_names=X_train.columns, # Default: all columns of your dataset.
)
# Validate wrapped model.
wrapped_predict = giskard_model.predict(giskard_dataset)
wrapped_test_metric = r2_score(y_test, wrapped_predict.prediction)
print(f"Wrapped Test R2-score: {wrapped_test_metric:.2f}")
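Wrapping the prediction function, rather than the estimator, keeps every preprocessing step inside one callable that accepts raw feature rows. The sketch below illustrates that contract under simplifying assumptions. Plain dictionaries stand in for the DataFrame, and the hypothetical toy_rule stands in for the fitted pipeline. The wrapper keeps only the declared feature names, so an extra column such as the target never reaches the model. This mirrors the intent of feature_names but is not Giskard's implementation.

```python
from typing import Callable

Row = dict[str, object]


def make_prediction_function(
    feature_names: list[str], predict_row: Callable[[Row], float]
) -> Callable[[list[Row]], list[float]]:
    """Build a batch prediction function that only sees the declared features."""
    def prediction_function(rows: list[Row]) -> list[float]:
        # Keep only declared features; anything else (e.g. the target column) is dropped.
        return [predict_row({name: row[name] for name in feature_names}) for row in rows]
    return prediction_function


def toy_rule(row: Row) -> float:
    """Hypothetical stand-in for the fitted pipeline: a made-up linear rule."""
    if "charges" in row:
        raise ValueError("the target must not reach the model")
    return 250.0 * row["age"] + (20000.0 if row["smoker"] == "yes" else 0.0)


predict = make_prediction_function(["age", "smoker"], toy_rule)
print(predict([{"age": 40, "smoker": "no", "charges": 1.0}]))
```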
Detect vulnerabilities in your model
Scan your model for vulnerabilities with Giskard
Giskard's scan allows you to detect vulnerabilities in your model automatically. These include performance biases, lack of robustness, data leakage, stochasticity, underconfidence, ethical issues, and more. For detailed information about the scan feature, please refer to our scan documentation.
[ ]:
results = scan(giskard_model, giskard_dataset)
[12]:
display(results)
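One family of issues the scan reports is performance bias: data slices on which a metric such as MSE is noticeably worse than on the whole dataset. The idea is sketched below. The helper names and the 10% relative tolerance are illustrative assumptions, not Giskard's detection algorithm or its defaults.

```python
def mse(y_true: list[float], y_pred: list[float]) -> float:
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def flag_worse_slices(
    rows: list[dict[str, str]],
    y_true: list[float],
    y_pred: list[float],
    column: str,
    tolerance: float = 0.10,
) -> dict[str, float]:
    """Return {category: slice MSE} for slices whose MSE exceeds the global MSE by more than tolerance."""
    global_mse = mse(y_true, y_pred)
    flagged = {}
    for category in sorted({row[column] for row in rows}):
        idx = [i for i, row in enumerate(rows) if row[column] == category]
        slice_mse = mse([y_true[i] for i in idx], [y_pred[i] for i in idx])
        if slice_mse > global_mse * (1 + tolerance):
            flagged[category] = slice_mse
    return flagged


toy_rows = [{"region": "north"}, {"region": "north"}, {"region": "south"}, {"region": "south"}]
print(flag_worse_slices(toy_rows, [10.0, 10.0, 10.0, 10.0], [10.0, 10.0, 14.0, 6.0], "region"))
```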
Generate comprehensive test suites automatically for your model
Generate test suites from the scan
The objects produced by the scan can be used as fixtures to generate a test suite that integrates all detected vulnerabilities. Test suites allow you to evaluate and validate your model's performance, ensuring that it behaves as expected on a set of predefined test cases, and to identify any regressions or issues that might arise during development or updates.
[13]:
test_suite = results.generate_test_suite("Test suite")
test_suite.run()
2024-05-29 13:33:44,636 pid:63742 MainThread giskard.datasets.base INFO Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,638 pid:63742 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (96, 7) executed in 0:00:00.010454
Executed 'MSE on data slice “`region` == "northeast"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1741309d0>, 'threshold': 17058376.93295317}:
Test failed
Metric: 20989697.45
2024-05-29 13:33:44,653 pid:63742 MainThread giskard.datasets.base INFO Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,655 pid:63742 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (153, 7) executed in 0:00:00.008035
Executed 'MSE on data slice “`sex` == "female"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x173786ce0>, 'threshold': 17058376.93295317}:
Test failed
Metric: 18686445.91
2024-05-29 13:33:44,666 pid:63742 MainThread giskard.datasets.base INFO Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,669 pid:63742 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (264, 7) executed in 0:00:00.007542
Executed 'MSE on data slice “`smoker` == "no"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca3430>, 'threshold': 17058376.93295317}:
Test failed
Metric: 18440360.15
2024-05-29 13:33:44,677 pid:63742 MainThread giskard.datasets.base INFO Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,678 pid:63742 MainThread giskard.utils.logging_utils INFO Predicted dataset with shape (88, 7) executed in 0:00:00.004697
Executed 'MSE on data slice “`region` == "southeast"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca0ac0>, 'threshold': 17058376.93295317}:
Test failed
Metric: 17201661.08
2024-05-29 13:33:44,682 pid:63742 MainThread giskard.core.suite INFO Executed test suite 'Test suite'
2024-05-29 13:33:44,682 pid:63742 MainThread giskard.core.suite INFO result: failed
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO MSE on data slice “`region` == "northeast"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1741309d0>, 'threshold': 17058376.93295317}): {failed, metric=20989697.452960182}
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO MSE on data slice “`sex` == "female"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x173786ce0>, 'threshold': 17058376.93295317}): {failed, metric=18686445.912572958}
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO MSE on data slice “`smoker` == "no"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca3430>, 'threshold': 17058376.93295317}): {failed, metric=18440360.147996694}
2024-05-29 13:33:44,684 pid:63742 MainThread giskard.core.suite INFO MSE on data slice “`region` == "southeast"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca0ac0>, 'threshold': 17058376.93295317}): {failed, metric=17201661.078229036}
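Each generated test compares a slice metric against a threshold. For MSE, lower is better, and in the log above every slice whose MSE exceeds the threshold is marked as failed. The whole suite is reported as failed as a result. The sketch below replays that logic on the metric values and threshold from the log, rounded to two decimals. The run_suite helper is hypothetical and is not Giskard's suite API. Its handling of a metric exactly equal to the threshold is an assumption.

```python
THRESHOLD = 17058376.93  # threshold reported in the log (rounded)

# Slice MSE values reported in the log (rounded).
SLICE_MSE = {
    '`region` == "northeast"': 20989697.45,
    '`sex` == "female"': 18686445.91,
    '`smoker` == "no"': 18440360.15,
    '`region` == "southeast"': 17201661.08,
}


def run_suite(metrics: dict[str, float], threshold: float) -> tuple[bool, dict[str, bool]]:
    """Lower-is-better check: in this sketch a test passes if its metric does not exceed the threshold."""
    results = {name: value <= threshold for name, value in metrics.items()}
    return all(results.values()), results


suite_passed, test_results = run_suite(SLICE_MSE, THRESHOLD)
for name, passed in test_results.items():
    print(f"MSE on data slice {name}: {'passed' if passed else 'failed'}")
print("Suite result:", "passed" if suite_passed else "failed")
```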
Customize your suite by loading objects from the Giskard catalog
The Giskard open-source catalog lets you load:
Tests such as metamorphic, performance, prediction & data drift, statistical tests, etc
Slicing functions such as detectors of toxicity, hate, emotion, etc
Transformation functions such as generators of typos, paraphrase, style tune, etc
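In this vocabulary, a slicing function selects a subset of rows, and a transformation function perturbs rows. Metamorphic tests combine the two: apply a perturbation that should not matter, then check that predictions barely move. The plain-Python sketch below illustrates the idea. The function names, the stand-in model, and the 5% relative tolerance are all hypothetical, and the catalog's own functions have their own signatures.

```python
from typing import Callable

Row = dict[str, object]


def smokers_only(row: Row) -> bool:
    """Slicing function: keep only rows where smoker == "yes"."""
    return row["smoker"] == "yes"


def set_region(value: str) -> Callable[[Row], Row]:
    """Transformation function factory: copy the row with a different region."""
    def transform(row: Row) -> Row:
        return {**row, "region": value}
    return transform


def invariance_rate(
    rows: list[Row],
    predict: Callable[[Row], float],
    slice_fn: Callable[[Row], bool],
    transform_fn: Callable[[Row], Row],
    rel_tol: float = 0.05,
) -> float:
    """Share of sliced rows whose prediction changes by at most rel_tol (relative) after the transformation."""
    sliced = [row for row in rows if slice_fn(row)]
    if not sliced:
        raise ValueError("the slice is empty")
    stable = 0
    for row in sliced:
        before = predict(row)
        after = predict(transform_fn(row))
        if abs(after - before) <= rel_tol * abs(before):
            stable += 1
    return stable / len(sliced)


def toy_model(row: Row) -> float:
    """Hypothetical stand-in model with a made-up regional effect."""
    charge = 1000.0
    if row["smoker"] == "yes":
        charge += 20000.0
    if row["region"] == "southeast":
        charge += 5000.0
    return charge


toy_rows = [
    {"smoker": "yes", "region": "northeast"},
    {"smoker": "yes", "region": "southeast"},
    {"smoker": "no", "region": "southeast"},
]
print(invariance_rate(toy_rows, toy_model, smokers_only, set_region("northwest")))
```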
To create custom tests, refer to the custom tests page of the Giskard documentation.
For demo purposes, we will load a simple unit test (test_rmse) that checks if the test RMSE score is below the given threshold. For more examples of tests and functions, refer to the Giskard catalog.
[ ]:
test_suite.add_test(testing.test_rmse(model=giskard_model, dataset=giskard_dataset, threshold=10.0)).run()
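RMSE is the square root of the mean squared error, so it is expressed in the same unit as charges. The slice MSE values in the log above are in the tens of millions, so the test RMSE is likely in the thousands. A threshold of 10.0 is therefore expected to make this test fail. Choose a threshold on that scale if you want a reasonable model to pass. The minimal sketch below shows the check, using made-up numbers.

```python
import math


def rmse(y_true: list[float], y_pred: list[float]) -> float:
    """Root mean squared error, in the unit of the target."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def rmse_test_passes(y_true: list[float], y_pred: list[float], threshold: float) -> bool:
    """Sketch of a threshold test: pass when RMSE is below the threshold."""
    return rmse(y_true, y_pred) < threshold


y_true = [1000.0, 2000.0, 3000.0]
y_pred = [1100.0, 1900.0, 3100.0]
print(rmse(y_true, y_pred), rmse_test_passes(y_true, y_pred, 10.0))
```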