Gradient Boosted Models for VFL¶
Gradient boosting is a machine learning algorithm for building predictive models that helps minimize the bias error of the model. The gradient boosting model for VFL provided by integrate.ai uses the scikit-learn implementation of HistGradientBoostingClassifier for classification tasks and HistGradientBoostingRegressor for regression tasks.
The VFL-GBM sample notebook (integrateai_taskrunner_vfl_gbm.ipynb) provides sample code for running the SDK, and should be used in parallel with this tutorial. This documentation provides supplementary and conceptual information.
Prerequisites¶
Open the integrateai_taskrunner_vfl_gbm.ipynb notebook to test the code as you walk through this tutorial.
Download the sample dataset to use with this tutorial. The sample files are:
- active_train.parquet - training data for the active party
- active_test.parquet - test data for the active party, used when joining a session
- passive_train.parquet - training data for the passive party
- passive_test.parquet - test data for the passive party, used when joining a session
Note: By default, the task runner creates a bucket for you to upload data into (for example, s3://{aws_taskrunner_profile}-{aws_taskrunner_name}.integrate.ai). Only the default S3 bucket and other buckets ending in *integrate.ai are supported. If you are not using the default bucket created when the task runner was provisioned, ensure that your data is hosted in an S3 bucket with a URL ending in *integrate.ai. Otherwise, the data will not be accessible to the task runner.
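The bucket-naming rule above can be checked before uploading. This is a hypothetical helper, not part of the SDK; it simply applies the "ends in integrate.ai" constraint described in the note.

```python
# Hypothetical helper: check whether an S3 URL points at a bucket the
# task runner can access (bucket name must end in "integrate.ai").
def is_supported_bucket(s3_url: str) -> bool:
    # Strip the scheme and any key path, keeping only the bucket name.
    bucket = s3_url.removeprefix("s3://").split("/", 1)[0]
    return bucket.endswith("integrate.ai")

print(is_supported_bucket("s3://myprofile-mytaskrunner.integrate.ai"))  # True
print(is_supported_bucket("s3://my-other-data-bucket"))                 # False
```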
Run a PRL session to align the datasets for VFL-GBM¶
Before you run a VFL session, you must first run a PRL session to align the datasets. The sample notebook provides two examples of running a PRL session with different match thresholds.
For more information, see Private Record Linkage (PRL) Sessions.
Review the sample Model Configuration for GBM¶
integrate.ai has a model class available for Gradient Boosted Models, called iai_gbm. This model is defined using a JSON configuration file during session creation.
The strategy for VFL-GBM is VflGbm. Note that this is different from the strategy name for an HFL GBM session, which is HistGradientBoosting.
model_config = {
    "strategy": {"name": "VflGbm"},
    "model": {
        "params": {
            "max_depth": 4,
            "learning_rate": 0.05,
            "random_state": 23,
            "max_bins": 255,
        }
    },
    "ml_task": {"type": "classification", "params": {}},
}
data_config = {
    "passive_client": {
        "label_client": False,
        "predictors": ["x1", "x3", "x5", "x7", "x9", "x11", "x13"],
        "target": None,
    },
    "active_client": {
        "label_client": True,
        "predictors": ["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14"],
        "target": "y",
    },
}
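In vertical federated learning, each party holds a disjoint set of feature columns and exactly one party (the active client) holds the label. The illustrative check below, which is not part of the SDK, verifies those properties for a data configuration shaped like the sample above.

```python
# Sample data configuration, mirroring the one in this tutorial.
data_config = {
    "passive_client": {
        "label_client": False,
        "predictors": ["x1", "x3", "x5", "x7", "x9", "x11", "x13"],
        "target": None,
    },
    "active_client": {
        "label_client": True,
        "predictors": ["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14"],
        "target": "y",
    },
}

def validate_vfl_data_config(config: dict) -> None:
    """Illustrative sanity check for a VFL data configuration."""
    # No feature column may appear in more than one party's slice.
    all_predictors = [p for c in config.values() for p in c["predictors"]]
    assert len(all_predictors) == len(set(all_predictors)), "overlapping predictors"
    # Exactly one party holds the label, and it declares a target column.
    label_clients = [name for name, c in config.items() if c["label_client"]]
    assert len(label_clients) == 1, "exactly one party must hold the label"
    assert config[label_clients[0]]["target"] is not None, "label party needs a target"

validate_vfl_data_config(data_config)
print("data_config OK")
```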
You can adjust the following parameters as needed:
- max_depth - Used to control the size of the trees.
- learning_rate - (shrinkage) Used as a multiplicative factor for the leaves' values. Set this to one (1) for no shrinkage.
- random_state - Pseudo-random number generator to control the subsampling in the binning process, and the train/validation data split if early stopping is enabled. Pass an int for reproducible output across multiple function calls, or set it to a RandomState instance or None. The default is None.
- max_bins - The number of bins used to bin the data. Using fewer bins acts as a form of regularization. It is generally recommended to use as many bins as possible.
For more information, see the scikit documentation for HistGradientBoostingClassifier.
Set the machine learning task type to either classification or regression. Specify any parameters associated with the task type in the params section.
The notebook also provides a sample data schema. For the purposes of testing VFL-GBM, use the sample schema as shown.
Create a VFL-GBM training session¶
Federated learning models created in integrate.ai are trained through sessions. You define the parameters required to train a federated model, including data and model configurations, in a session.
Create a session each time you want to train a new model.
The code sample demonstrates creating and starting a session with two training clients (two datasets) and five rounds. It returns a session ID that you can use to track and reference your session.
vfl_train_session = client.create_vfl_session(
    name="Testing notebook - VFL-GBM training",
    description="I am testing VFL GBM training session creation through a notebook",
    prl_session_id=prl_session.id,
    vfl_mode="train",
    min_num_clients=2,
    num_rounds=5,
    package_name="iai_gbm",
    data_config=data_config,
    model_config=model_config,
).start()

vfl_train_session.id
Create and start a task group to run the training session¶
The next step is to join the session with the sample data. This example has data for two datasets simulating two clients, as specified with the min_num_clients argument. The session begins training once the minimum number of clients have joined the session.
Each client must have a unique name that matches the name specified in the data_config. For example, active_client and passive_client.
# Create and start a task group with one task for each of the clients joining the session
vfl_task_group_context = (
    SessionTaskGroup(vfl_train_session)
    .add_task(iai_tb_aws.vfl_train(
        train_path=active_train_path,
        test_path=active_test_path,
        batch_size=1024,
        client_name="active_client",
        storage_path=aws_storage_path,
    ))
    .add_task(iai_tb_aws.vfl_train(
        train_path=passive_train_path,
        test_path=passive_test_path,
        batch_size=1024,
        client_name="passive_client",
        storage_path=aws_storage_path,
    ))
    .start()
)
Poll for Session Results¶
Sessions may take some time to run depending on the compute environment. In the sample notebook and this tutorial, we poll the server to determine the session status.
# Check the status of the tasks
for i in vfl_task_group_context.contexts.values():
    print(json.dumps(i.status(), indent=4))
vfl_task_group_context.monitor_task_logs()
vfl_task_group_context.wait(60*5, 2)
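The wait() call above follows a standard poll-with-timeout pattern: check status at a fixed interval until the session completes or the timeout elapses. The generic sketch below illustrates the pattern with a simulated status function standing in for the SDK call; none of these names are part of the integrate.ai API.

```python
# Generic poll-with-timeout sketch (not SDK code).
import time

def wait_for_completion(check_status, timeout_s: float, interval_s: float) -> bool:
    """Poll check_status every interval_s seconds until it reports completion
    or timeout_s seconds have elapsed. Returns True on completion."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if check_status() == "COMPLETED":
            return True
        time.sleep(interval_s)
    return False

# Simulated session that completes on the third poll.
states = iter(["RUNNING", "RUNNING", "COMPLETED"])
done = wait_for_completion(lambda: next(states), timeout_s=5, interval_s=0.01)
print(done)  # True
```

In the notebook, wait(60*5, 2) plays the same role: poll every 2 seconds for up to 5 minutes.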
View the VFL-GBM training metrics¶
After the session completes successfully, you can view the training metrics and start making predictions.
Retrieve the model metrics as_dict.
Plot the metrics.
metrics = vfl_train_session.metrics().as_dict()
metrics
fig = vfl_train_session.metrics().plot()
VFL-GBM Prediction¶
After you have completed a successful PRL and VFL train session, you can use those sessions to create a VFL prediction session.
# Create and start a VFL predict session
vfl_predict_session = client.create_vfl_session(
name="Testing notebook - VFL Predict",
description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
prl_session_id=prl_session.id,
training_session_id=vfl_train_session.id,
vfl_mode="predict",
data_config=data_config,
).start()
vfl_predict_session.id
For more information about VFL predict mode, see VFL Prediction Session Example.
Back to VFL Model Training