Gradient Boosted Models for VFL¶
Gradient boosting is a machine learning algorithm for building predictive models that helps minimize the bias error of the model. The gradient boosting model for VFL provided by integrate.ai uses the sklearn implementation of HistGradientBoostingClassifier for classification tasks and HistGradientBoostingRegressor for regression tasks.
Run a PRL session to align the datasets for VFL-GBM¶
Before you run a VFL session, you must first run a PRL session to align the datasets. For more information, see Private Record Linkage (PRL) Sessions.
Specify the VFL-GBM model and data configuration¶
integrate.ai provides a model class for Gradient Boosted Models, called iai_gbm. This model is defined using a JSON configuration file during session creation.
The strategy for VFL-GBM is VflGbm. Note that this is different from the strategy name for an HFL GBM session, which is HistGradientBoosting.
model_config = {
"strategy": {"name": "VflGbm"},
"model": {
"params": {
"max_depth": 4,
"learning_rate": 0.05,
"random_state": 23,
"max_bins": 255,
}
},
"ml_task": {"type": 'classification', "params": {}},
}
data_config = {
"passive_client": {
"label_client": False,
"predictors": ["x1", "x3", "x5", "x7", "x9", "x11", "x13"],
"target": None,
},
"active_client": {
"label_client": True,
"predictors": ["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14"],
"target": "y",
},
}
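The data_config simply partitions the feature columns between the two parties, with only the active client holding the label. A quick pandas sketch (with synthetic data and the column names from the config above) of how one combined table would split:

```python
import numpy as np
import pandas as pd

# Synthetic table with the 15 predictors and the target column.
rng = np.random.default_rng(0)
cols = [f"x{i}" for i in range(15)] + ["y"]
df = pd.DataFrame(rng.random((10, 16)), columns=cols)

# Passive client: odd-indexed predictors, no label.
passive = df[["x1", "x3", "x5", "x7", "x9", "x11", "x13"]]

# Active client: even-indexed predictors plus the target column.
active = df[["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14", "y"]]

print(passive.shape, active.shape)  # → (10, 7) (10, 9)
```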
You can adjust the following parameters as needed:
max_depth - Controls the size of the trees.
learning_rate - (shrinkage) Used as a multiplicative factor for the leaf values. Set this to one (1) for no shrinkage.
random_state - Pseudo-random number generator that controls the subsampling in the binning process, and the train/validation data split if early stopping is enabled. Pass an int for reproducible output across multiple function calls. May also be a RandomState instance or None. The default is None.
max_bins - The number of bins used to bin the data. Using fewer bins acts as a form of regularization. It is generally recommended to use as many bins as possible.
For more information, see the scikit documentation for HistGradientBoostingClassifier.
Set the machine learning task type to either classification or regression. Specify any parameters associated with the task type in the params section.
The notebook also provides a sample data schema. For the purposes of testing VFL-GBM, use the sample schema as shown.
Create a VFL-GBM training session¶
Federated learning models created in integrate.ai are trained through sessions. You define the parameters required to train a federated model, including data and model configurations, in a session.
Create a session each time you want to train a new model.
The code sample demonstrates creating and starting a session with two training clients (two datasets) and five rounds. It returns a session ID that you can use to track and reference your session.
vfl_train_session = client.create_vfl_session(
name="Testing notebook - VFL-GBM training",
description="I am testing VFL GBM training session creation through a notebook",
prl_session_id=prl_session.id,
vfl_mode="train",
min_num_clients=2,
num_rounds=5,
package_name="iai_ffnet",
data_config=data_config,
model_config=model_config,
).start()
vfl_train_session.id
Create and start a task group to run the training session¶
The next step is to join the session with the sample data. This example has data for two datasets simulating two clients, as specified with the min_num_clients
argument. The session begins training once the minimum number of clients have joined the session.
Each client must have a unique name that matches the name specified in the data_config
. For example, active_client
and passive_client
.
# Create and start a task group with one task for each of the clients joining the session
vfl_task_group_context = (
    SessionTaskGroup(vfl_train_session)
    .add_task(iai_tb_aws.vfl_train(
        train_dataset_name=active_train_path,
        test_dataset_name=active_test_path,
        batch_size=1024,
        client_name="active_client",
    ))
    .add_task(iai_tb_aws.vfl_train(
        train_dataset_name=passive_train_path,
        test_dataset_name=passive_test_path,
        batch_size=1024,
        client_name="passive_client",
    ))
    .start()
)
Wait for session results¶
Depending on the type of session and the size of the datasets, sessions may take some time to run. You can poll the server to determine the session status, or wait for the session status to change to “Completed” in the UI.
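The polling loop itself can be as simple as the sketch below. The get_status callable here is a hypothetical stand-in for whatever status call your SDK version exposes; only the retry structure is the point:

```python
import time

def wait_for_completion(get_status, poll_secs=30, timeout_secs=3600):
    """Poll a status callable until it reports a terminal state.

    get_status is a hypothetical stand-in for the SDK's status
    call; it should return a string such as "Completed" or "Failed".
    """
    deadline = time.monotonic() + timeout_secs
    while time.monotonic() < deadline:
        status = get_status()
        if status in ("Completed", "Failed"):
            return status
        time.sleep(poll_secs)
    raise TimeoutError("session did not finish in time")

# Example with a fake status source that completes on the third poll.
statuses = iter(["Queued", "Running", "Completed"])
print(wait_for_completion(lambda: next(statuses), poll_secs=0))  # → Completed
```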
View the VFL-GBM training metrics¶
After the session completes successfully, you can view the training metrics and start making predictions.
Retrieve the model metrics with as_dict(), then plot them.
metrics = vfl_train_session.metrics().as_dict()
metrics
fig = vfl_train_session.metrics().plot()
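If you prefer to plot the metrics yourself, the dict returned by as_dict() can be fed to matplotlib directly. The per-round structure below is an assumed shape for illustration, not the SDK's documented schema:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, suitable for scripts/CI
import matplotlib.pyplot as plt

# Hypothetical metrics shape: one loss value per training round.
metrics = {"loss": [0.69, 0.55, 0.47, 0.42, 0.39]}

fig, ax = plt.subplots()
ax.plot(range(1, len(metrics["loss"]) + 1), metrics["loss"], marker="o")
ax.set_xlabel("Round")
ax.set_ylabel("Loss")
ax.set_title("VFL-GBM training loss (illustrative values)")
fig.savefig("vfl_gbm_loss.png")
```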
VFL-GBM Prediction¶
After you have completed a successful PRL and VFL train session, you can use those sessions to create a VFL prediction session.
# Create and start a VFL predict session
vfl_predict_session = client.create_vfl_session(
name="Testing notebook - VFL Predict",
description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
prl_session_id=prl_session.id,
training_session_id=vfl_train_session.id,
vfl_mode="predict",
data_config=data_config,
).start()
vfl_predict_session.id
For more information about VFL predict mode, see VFL Prediction Session Example.
Back to VFL Model Training