
Insurance charges prediction [LGBM]

Giskard is an open-source framework for testing all ML models, from LLMs to tabular models. Don’t hesitate to give the project a star on GitHub ⭐️ if you find it useful!

In this notebook, you’ll learn how to create comprehensive test suites for your model in a few lines of code, thanks to Giskard’s open-source Python library.

Use-case:

  • Regression to predict the insurance charges based on medical and social data.

  • Model: LGBMRegressor

  • Dataset: US health insurance dataset

Outline:

  • Detect vulnerabilities automatically with Giskard’s scan

  • Automatically generate & curate a comprehensive test suite to test your model beyond accuracy-related metrics

Install dependencies

Make sure to install the latest version of giskard:

[ ]:
%pip install giskard --upgrade

We also install the project-specific dependencies for this tutorial.

[ ]:
%pip install lightgbm

Troubleshooting

If you encounter a segmentation fault on macOS at any point during this tutorial, check: https://docs.giskard.ai/en/stable/community/contribution_guidelines/dev-environment.html#fatal-python-error-segmentation-fault-when-running-pytest-on-macos

Import libraries

[1]:
import warnings
from pathlib import Path
from urllib.request import urlretrieve

import pandas as pd
from absl import logging
from lightgbm import LGBMRegressor
from sklearn.compose import ColumnTransformer
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

from giskard import Dataset, Model, scan, testing

Notebook-level settings

[2]:
logging.set_verbosity(logging.ERROR)
warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)
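The filter above relies on the fact that `warnings.filterwarnings` treats `message` as a regular expression that must match the start of the warning text (case-insensitively), so only `FutureWarning`s whose message begins with "Passing" are hidden. The sketch below shows that behavior with the standard library only; the helper name `emit_sample_warnings` and both warning texts are made up for illustration.

```python
import warnings


def emit_sample_warnings() -> list[str]:
    """Emit two illustrative FutureWarnings and return the messages that survive the filter."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning by default, then add the notebook's filter on top of it.
        warnings.simplefilter("always")
        warnings.filterwarnings("ignore", message=r"Passing", category=FutureWarning)

        # Hypothetical messages, used only to exercise the filter.
        warnings.warn("Passing a set as an indexer is deprecated", FutureWarning)
        warnings.warn("Some other future change", FutureWarning)

    # catch_warnings restores the previous filters when the block exits.
    return [str(w.message) for w in caught]


print(emit_sample_warnings())
```

Because the regex is anchored at the start of the message, a warning that merely contains "Passing" later in its text would still be shown.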

Define constants

[3]:
# Constants.
NUMERICAL_COLS = ["bmi", "age", "children"]
CATEGORICAL_COLS = ["sex", "smoker", "region"]

# Paths.
DATA_URL = "https://giskard-library-test-datasets.s3.eu-north-1.amazonaws.com/insurance_prediction_dataset-us_health_insurance_dataset.csv.tar.gz"
DATA_PATH = Path.home() / ".giskard" / "insurance_prediction_dataset" / "us_health_insurance_dataset.csv.tar.gz"

Dataset preparation

Load data

[4]:
def fetch_demo_data(url: str, file: Path) -> None:
    """Download `url` to `file` unless a cached copy already exists."""
    # Create the cache directory if needed (no-op when it already exists).
    file.parent.mkdir(parents=True, exist_ok=True)

    if not file.exists():
        print(f"Downloading data from {url}")
        urlretrieve(url, file)

    print("Data was loaded!")


def download_data(**kwargs) -> pd.DataFrame:
    """Download the dataset using URL."""
    fetch_demo_data(DATA_URL, DATA_PATH)
    _df = pd.read_csv(DATA_PATH, **kwargs)
    return _df
[ ]:
df = download_data()
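`download_data` hands the `.csv.tar.gz` path straight to `pd.read_csv`, which infers tar+gzip compression from the file suffix (tar support arrived in pandas 1.5, to the best of our knowledge). To make that step less opaque, here is a standard-library sketch of the same idea, assuming the archive holds exactly one CSV file. The helpers `read_single_csv_from_tar_gz` and `build_demo_archive` are hypothetical, and the demo archive contains made-up rows in the dataset's column layout, not real data.

```python
import csv
import io
import tarfile
import tempfile
from pathlib import Path


def read_single_csv_from_tar_gz(path: Path) -> list[dict[str, str]]:
    """Return the rows of the only regular file inside a .tar.gz archive as dicts."""
    with tarfile.open(path, mode="r:gz") as archive:
        members = [member for member in archive.getmembers() if member.isfile()]
        if len(members) != 1:
            raise ValueError(f"expected exactly one file in {path}, found {len(members)}")
        raw = archive.extractfile(members[0])
        if raw is None:
            raise ValueError(f"could not read {members[0].name}")
        # Decode the binary member on the fly and parse it as CSV with a header row.
        text = io.TextIOWrapper(raw, encoding="utf-8", newline="")
        return list(csv.DictReader(text))


def build_demo_archive(folder: Path) -> Path:
    """Write a tiny .csv.tar.gz with made-up rows in the insurance dataset's column layout."""
    csv_bytes = (
        "age,sex,bmi,children,smoker,region,charges\n"
        "30,female,25.0,0,no,northeast,5000.0\n"
        "45,male,31.5,2,yes,southwest,2500.5\n"
    ).encode("utf-8")
    archive_path = folder / "demo.csv.tar.gz"
    with tarfile.open(archive_path, mode="w:gz") as archive:
        info = tarfile.TarInfo(name="demo.csv")
        info.size = len(csv_bytes)
        archive.addfile(info, io.BytesIO(csv_bytes))
    return archive_path


with tempfile.TemporaryDirectory() as tmp:
    rows = read_single_csv_from_tar_gz(build_demo_archive(Path(tmp)))
print(len(rows), rows[0]["region"], rows[1]["charges"])
```

Unlike `pd.read_csv`, this sketch keeps every value as a string; type inference is one of the things pandas adds on top.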

Train-test split

[6]:
X_train, X_test, y_train, y_test = train_test_split(df.drop(columns=["charges"]), df.charges, random_state=0)
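Because neither `test_size` nor `train_size` is passed, `train_test_split` falls back to holding out 25% of the rows, rounding the test count up, and `random_state=0` makes the shuffle reproducible. The sketch below mimics that behavior on row indices with the standard library; `shuffle_split` is a hypothetical helper, and its shuffle will not reproduce scikit-learn's exact row assignment, only the split sizes and the reproducibility.

```python
import math
import random


def shuffle_split(n_samples: int, test_size: float = 0.25, seed: int = 0) -> tuple[list[int], list[int]]:
    """Split row indices into (train, test) after a reproducible shuffle."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be strictly between 0 and 1")
    # Round the test count up, as scikit-learn does for a float test_size.
    n_test = math.ceil(test_size * n_samples)
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # a fixed seed plays the role of random_state=0
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = shuffle_split(10)
print(len(train_idx), len(test_idx))
```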

Wrap dataset with Giskard

To prepare for the vulnerability scan, make sure to wrap your dataset using Giskard’s Dataset class. See the Giskard documentation on the Dataset class for more details.

[ ]:
raw_data = pd.concat([X_test, y_test], axis=1)
giskard_dataset = Dataset(
    df=raw_data,
    # A pandas.DataFrame that contains the raw data (before all the pre-processing steps) and the actual ground truth variable (target).
    target="charges",  # Ground truth variable.
    name="insurance dataset",  # Optional.
    cat_columns=CATEGORICAL_COLS,
    # List of categorical columns. Optional, but strongly recommended when known; inferred automatically otherwise.
)

Model building

Define preprocessing pipeline

[8]:
preprocessor = ColumnTransformer(
    transformers=[
        ("scaler", StandardScaler(), NUMERICAL_COLS),
        ("one_hot_encoder", OneHotEncoder(handle_unknown="ignore", sparse_output=False), CATEGORICAL_COLS),
    ]
)
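The `ColumnTransformer` above standardizes the numerical columns and one-hot encodes the categorical ones. As a rough sketch of what each transformer learns, the plain-Python helpers below (all hypothetical names) reproduce the core arithmetic: `StandardScaler` subtracts the training mean and divides by the population standard deviation, and `OneHotEncoder(handle_unknown="ignore")` maps a category unseen during fitting to an all-zero vector. The input values are illustrative, not taken from the dataset.

```python
import math


def fit_scaler(values: list[float]) -> tuple[float, float]:
    """Learn mean and population standard deviation, as StandardScaler does on fit."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    # Zero-variance columns are left unscaled (scale of 1.0) instead of dividing by zero.
    return mean, (std if std > 0 else 1.0)


def scale(values: list[float], mean: float, std: float) -> list[float]:
    """Apply the statistics learned on the training data to new values."""
    return [(v - mean) / std for v in values]


def fit_one_hot(values: list[str]) -> list[str]:
    """Learn the sorted list of categories seen during training."""
    return sorted(set(values))


def one_hot(value: str, categories: list[str]) -> list[int]:
    """Encode one value; unseen categories become an all-zero vector (handle_unknown="ignore")."""
    return [1 if value == category else 0 for category in categories]


# Illustrative training values, not taken from the dataset.
bmi_mean, bmi_std = fit_scaler([20.0, 25.0, 30.0])
print(scale([25.0, 30.0], bmi_mean, bmi_std))

regions = fit_one_hot(["southeast", "northeast", "southwest", "northwest", "southeast"])
print(regions)
print(one_hot("southeast", regions))
print(one_hot("central", regions))  # hypothetical region never seen during fitting
```

Fitting on the training split and reusing the learned statistics on the test split is exactly what wrapping the transformers in a `Pipeline` guarantees.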

Build estimator

[ ]:
pipeline = Pipeline(steps=[("preprocessor", preprocessor), ("regressor", LGBMRegressor(n_estimators=30))])

pipeline.fit(X_train, y_train)

y_train_pred = pipeline.predict(X_train)
y_test_pred = pipeline.predict(X_test)

train_r2 = r2_score(y_train, y_train_pred)
test_r2 = r2_score(y_test, y_test_pred)

print(f"Train R2-score: {train_r2:.2f}")
print(f"Test R2-score: {test_r2:.2f}")
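The cell above reports the R2 score on both splits. R2 compares the model's squared errors with those of a baseline that always predicts the mean of the true values: `1 - SS_res / SS_tot`. The hypothetical helper `r2` below is a minimal plain-Python version of that formula. One deliberate difference: it raises on constant targets, whereas `sklearn.metrics.r2_score` returns a finite fallback value in that case.

```python
def r2(y_true: list[float], y_pred: list[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    if len(y_true) != len(y_pred) or len(y_true) < 2:
        raise ValueError("need two or more paired observations")
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # model's squared errors
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # squared errors of the mean baseline
    if ss_tot == 0:
        raise ValueError("R2 is undefined when y_true is constant")
    return 1 - ss_res / ss_tot


print(r2([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))
```

A score of 1.0 means perfect predictions and 0.0 means no better than predicting the mean, which is why a large gap between train and test R2 is worth a closer look.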

Wrap model with Giskard

To prepare for the vulnerability scan, make sure to wrap your model using Giskard’s Model class. You can choose to either wrap the prediction function (preferred option) or the model object. See the Giskard documentation on the Model class for more details.

[ ]:
# Wrap the prediction function
def prediction_function(df):
    return pipeline.predict(df)


giskard_model = Model(
    model=prediction_function,
    # A prediction function that encapsulates all the data pre-processing steps and that could be executed with the dataset used by the scan.
    model_type="regression",  # Either regression, classification or text_generation.
    name="insurance model",  # Optional.
    feature_names=X_train.columns,  # Default: all columns of your dataset.
)

# Validate wrapped model.
wrapped_predict = giskard_model.predict(giskard_dataset)
wrapped_test_metric = r2_score(y_test, wrapped_predict.prediction)

print(f"Wrapped Test R2-score: {wrapped_test_metric:.2f}")

Detect vulnerabilities in your model

Scan your model for vulnerabilities with Giskard

Giskard’s scan allows you to detect vulnerabilities in your model automatically. These include performance biases, lack of robustness, data leakage, stochasticity, underconfidence, ethical issues, and more. For detailed information about the scan feature, please refer to our scan documentation.

[ ]:
results = scan(giskard_model, giskard_dataset)
[12]:
display(results)

Generate comprehensive test suites automatically for your model

Generate test suites from the scan

The objects produced by the scan can be used as fixtures to generate a test suite that integrates all detected vulnerabilities. Test suites let you evaluate and validate your model’s performance, check that it behaves as expected on a set of predefined test cases, and catch regressions or issues that might arise during development or updates.

[13]:
test_suite = results.generate_test_suite("Test suite")
test_suite.run()
2024-05-29 13:33:44,636 pid:63742 MainThread giskard.datasets.base INFO     Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,638 pid:63742 MainThread giskard.utils.logging_utils INFO     Predicted dataset with shape (96, 7) executed in 0:00:00.010454
Executed 'MSE on data slice “`region` == "northeast"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1741309d0>, 'threshold': 17058376.93295317}:
               Test failed
               Metric: 20989697.45


2024-05-29 13:33:44,653 pid:63742 MainThread giskard.datasets.base INFO     Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,655 pid:63742 MainThread giskard.utils.logging_utils INFO     Predicted dataset with shape (153, 7) executed in 0:00:00.008035
Executed 'MSE on data slice “`sex` == "female"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x173786ce0>, 'threshold': 17058376.93295317}:
               Test failed
               Metric: 18686445.91


2024-05-29 13:33:44,666 pid:63742 MainThread giskard.datasets.base INFO     Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,669 pid:63742 MainThread giskard.utils.logging_utils INFO     Predicted dataset with shape (264, 7) executed in 0:00:00.007542
Executed 'MSE on data slice “`smoker` == "no"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca3430>, 'threshold': 17058376.93295317}:
               Test failed
               Metric: 18440360.15


2024-05-29 13:33:44,677 pid:63742 MainThread giskard.datasets.base INFO     Casting dataframe columns from {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'} to {'age': 'int64', 'sex': 'object', 'bmi': 'float64', 'children': 'int64', 'smoker': 'object', 'region': 'object'}
2024-05-29 13:33:44,678 pid:63742 MainThread giskard.utils.logging_utils INFO     Predicted dataset with shape (88, 7) executed in 0:00:00.004697
Executed 'MSE on data slice “`region` == "southeast"”' with arguments {'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca0ac0>, 'threshold': 17058376.93295317}:
               Test failed
               Metric: 17201661.08


2024-05-29 13:33:44,682 pid:63742 MainThread giskard.core.suite INFO     Executed test suite 'Test suite'
2024-05-29 13:33:44,682 pid:63742 MainThread giskard.core.suite INFO     result: failed
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO     MSE on data slice “`region` == "northeast"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x1741309d0>, 'threshold': 17058376.93295317}): {failed, metric=20989697.452960182}
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO     MSE on data slice “`sex` == "female"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x173786ce0>, 'threshold': 17058376.93295317}): {failed, metric=18686445.912572958}
2024-05-29 13:33:44,683 pid:63742 MainThread giskard.core.suite INFO     MSE on data slice “`smoker` == "no"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca3430>, 'threshold': 17058376.93295317}): {failed, metric=18440360.147996694}
2024-05-29 13:33:44,684 pid:63742 MainThread giskard.core.suite INFO     MSE on data slice “`region` == "southeast"” ({'model': <giskard.models.function.PredictionFunctionModel object at 0x1737879d0>, 'dataset': <giskard.datasets.base.Dataset object at 0x173787340>, 'slicing_function': <giskard.slicing.slice.QueryBasedSliceFunction object at 0x177ca0ac0>, 'threshold': 17058376.93295317}): {failed, metric=17201661.078229036}
[13]:
Test suite failed.
Test MSE on data slice “`region` == "northeast"”
Measured Metric = 20989697.45296: Failed
model insurance model
dataset insurance dataset
slicing_function `region` == "northeast"
threshold 17058376.93295317
Test MSE on data slice “`sex` == "female"”
Measured Metric = 18686445.91257: Failed
model insurance model
dataset insurance dataset
slicing_function `sex` == "female"
threshold 17058376.93295317
Test MSE on data slice “`smoker` == "no"”
Measured Metric = 18440360.148: Failed
model insurance model
dataset insurance dataset
slicing_function `smoker` == "no"
threshold 17058376.93295317
Test MSE on data slice “`region` == "southeast"”
Measured Metric = 17201661.07823: Failed
model insurance model
dataset insurance dataset
slicing_function `region` == "southeast"
threshold 17058376.93295317
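Each generated test computes the MSE on one data slice and, as the output shows, fails when that metric is above the threshold. The sketch below reproduces only that comparison in plain Python; it is not Giskard's implementation and does not cover how the scan chooses the threshold. The helpers `slice_mse` and `passes` are hypothetical, treating a tie with the threshold as a pass is an assumption, and the rows and predictions are made up.

```python
from typing import Any, Callable

Row = dict[str, Any]


def slice_mse(rows: list[Row], predictions: list[float], target: str,
              in_slice: Callable[[Row], bool]) -> float:
    """Mean squared error restricted to the rows selected by `in_slice`."""
    errors = [(row[target] - pred) ** 2 for row, pred in zip(rows, predictions) if in_slice(row)]
    if not errors:
        raise ValueError("the slice selects no rows")
    return sum(errors) / len(errors)


def passes(metric: float, threshold: float) -> bool:
    """Lower MSE is better; counting a tie as a pass is an assumption of this sketch."""
    return metric <= threshold


# Made-up rows and predictions, only to exercise the functions.
rows = [
    {"region": "northeast", "charges": 10.0},
    {"region": "northeast", "charges": 20.0},
    {"region": "southeast", "charges": 30.0},
]
predictions = [12.0, 16.0, 30.0]

northeast_mse = slice_mse(rows, predictions, "charges", lambda r: r["region"] == "northeast")
print(northeast_mse, passes(northeast_mse, threshold=5.0))
```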

Customize your suite by loading objects from the Giskard catalog

The Giskard open-source catalog lets you load:

  • Tests such as metamorphic, performance, prediction & data drift, statistical tests, etc.

  • Slicing functions such as detectors of toxicity, hate, emotion, etc.

  • Transformation functions such as generators of typos, paraphrase, style tune, etc.

To create custom tests, refer to the custom tests page of the Giskard documentation.

For demo purposes, we will load a simple unit test (test_rmse) that checks if the test RMSE score is below the given threshold. For more examples of tests and functions, refer to the Giskard catalog.

[ ]:
test_suite.add_test(testing.test_rmse(model=giskard_model, dataset=giskard_dataset, threshold=10.0)).run()
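RMSE is the square root of MSE, so it is expressed in the same unit as `charges`, which makes thresholds easier to reason about. The sketch below defines a hypothetical `rmse` helper and converts the slice MSE values reported by the generated suite earlier in this notebook; each of those slices has an RMSE in the thousands, far above the demo threshold of 10.0 passed to `test_rmse`. Only the conversion is shown here, not Giskard's test logic.

```python
import math


def rmse(y_true: list[float], y_pred: list[float]) -> float:
    """Root mean squared error: the square root of the MSE."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need one or more paired observations")
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


# Slice MSE values copied from the generated test suite output in this notebook.
reported_slice_mse = {
    "region == northeast": 20989697.45,
    "sex == female": 18686445.91,
    "smoker == no": 18440360.15,
    "region == southeast": 17201661.08,
}
for slice_name, mse in reported_slice_mse.items():
    print(f"{slice_name}: RMSE ~ {math.sqrt(mse):,.0f}")
```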