Batch Sessions
You can run VFL or HFL sessions in batches to iterate quickly over parameters. The integrate.ai client provides a batch method that launches a set of sessions based on the parameters provided in the model_config for the session.
Model config iterations allow you to create a batch by iterating over model configuration parameters.
For example:
model_config_params = {
    "optimizer.params.learning_rate": [0.1, 0.2, 0.3],
}
This code launches 3 sessions, one for each of the specified values of the optimizer.params.learning_rate parameter. The batch method runs no more than the configured number of sessions in parallel at any one time.
If you specify several parameters, the number of sessions equals the number of possible combinations of their values, that is, the product of the value list sizes.
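The combination count described above can be sketched with the standard library. This is an illustration only, not the integrate.ai client's implementation: the helper name expand_params is hypothetical, and iterating over optimizer.params.momentum with the values shown is an assumption made for the example.

```python
from itertools import product


def expand_params(model_config_params):
    """Return one {dotted_key: value} dict per combination of values.

    Hypothetical helper for illustration; not part of the integrate.ai client.
    """
    keys = list(model_config_params)
    value_lists = [model_config_params[key] for key in keys]
    # product() yields the Cartesian product of the value lists.
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


model_config_params = {
    "optimizer.params.learning_rate": [0.1, 0.2, 0.3],
    # Assumed second parameter, added only to show the multiplication.
    "optimizer.params.momentum": [0.0, 0.9],
}

combinations = expand_params(model_config_params)
print(len(combinations))  # 3 * 2 = 6
for overrides in combinations:
    print(overrides)
```

Under these assumptions a batch would contain six sessions, one per printed dictionary.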
Specify the VFL data configuration
To run batch sessions, specify the data_config for the VFL session as usual.
data_config = {
    "passive_client": {
        "label_client": False,
        "predictors": ["x1", "x3", "x5", "x7"],
        "target": None,
    },
    "active_client": {
        "label_client": True,
        "predictors": ["x0", "x2", "x4", "x6"],
        "target": "y",
    },
}
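Because every session in the batch shares this data_config, it can be worth checking it locally before launching anything. The sketch below is a hypothetical pre-flight check, not an integrate.ai API. It assumes, based only on the example above, that exactly one client is the label client and that only that client sets a target.

```python
def check_data_config(data_config):
    """Return the label client's name, or raise ValueError.

    Hypothetical check. The two rules enforced here are assumptions
    inferred from the example configuration, not documented requirements.
    """
    label_clients = [
        name for name, cfg in data_config.items() if cfg["label_client"]
    ]
    if len(label_clients) != 1:
        raise ValueError(f"expected exactly one label client, found {label_clients}")
    for name, cfg in data_config.items():
        has_target = cfg["target"] is not None
        if has_target != cfg["label_client"]:
            raise ValueError(f"{name}: only the label client should set a target")
    return label_clients[0]


data_config = {
    "passive_client": {
        "label_client": False,
        "predictors": ["x1", "x3", "x5", "x7"],
        "target": None,
    },
    "active_client": {
        "label_client": True,
        "predictors": ["x0", "x2", "x4", "x6"],
        "target": "y",
    },
}

print(check_data_config(data_config))  # active_client
```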
Specify the VFL batch model configuration
Specify the model_config as usual, and specify the parameters to iterate over in the model_config_params.
model_config = {
    "strategy": {"name": "VflGlm", "params": {"expand_duplicates": False}},
    "model": {
        "passive_client": {"params": {}},
        "active_client": {"params": {}},
    },
    "ml_task": {
        "type": "regression",
        "loss_function": "mse",
        "params": {},
    },
    "optimizer": {"name": "SGD", "params": {"learning_rate": 0.01, "momentum": 0.0}},  # specifies the optimizer and parameters
    "influence_score": {"enable": False, "params": {}},  # enables/disables influence score calculation
    "feature_importance_score": {"enable": False, "params": {}},  # enables/disables feature importance score calculation
    "seed": 23,  # for reproducibility
}

# Iterate over all possible combinations of model config parameters:
model_config_params = {
    "optimizer.params.learning_rate": [0.1, 0.2, 0.3],
}
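Each key in model_config_params is a dotted path into model_config. As a rough illustration of how one such override could map onto the nested dictionary, the sketch below applies a single value to a copy of a trimmed-down config. The function apply_override is hypothetical, and rejecting paths that do not already exist is a choice made for this sketch; the client's actual behavior may differ.

```python
import copy


def apply_override(config, dotted_key, value):
    """Return a deep copy of config with the key at dotted_key set to value.

    Hypothetical helper; raises KeyError if the path is not already present.
    """
    updated = copy.deepcopy(config)
    *parents, leaf = dotted_key.split(".")
    node = updated
    for part in parents:
        node = node[part]  # KeyError here means a parent key is missing
    if leaf not in node:
        raise KeyError(dotted_key)
    node[leaf] = value
    return updated


# Trimmed-down config containing only the section used in this example.
model_config = {
    "optimizer": {"name": "SGD", "params": {"learning_rate": 0.01, "momentum": 0.0}},
}

session_config = apply_override(model_config, "optimizer.params.learning_rate", 0.2)
print(session_config["optimizer"]["params"]["learning_rate"])  # 0.2
print(model_config["optimizer"]["params"]["learning_rate"])    # 0.01 (original unchanged)
```

Running a check like this on every key before starting a batch catches a mistyped path early, for example optimizer.param.learning_rate.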
Create and start session batches
# Assumes client, prl_session, iai_tb_aws, and the dataset names are already
# defined, as for a single VFL session.
batch_id, batch_sessions = client.run_vfl_batch(
    name="Testing VFL batch",
    description="batching VFL",
    prl_session_id=prl_session.id,
    vfl_mode="train",
    min_num_clients=2,
    num_rounds=2,
    package_name="iai_glm",
    data_config=data_config,
    model_config=model_config,
    # Batch params:
    model_config_params=model_config_params,
    # How many sessions to run in parallel
    capacity=4,
    # Add one task for each client, as usual for a VFL session
    tasks=[
        iai_tb_aws.vfl_train(
            train_dataset_name=active_train_dataset,
            test_dataset_name=active_test_dataset,
            batch_size=1024,
            client_name="active_client",
        ),
        iai_tb_aws.vfl_train(
            train_dataset_name=passive_train_dataset,
            test_dataset_name=passive_test_dataset,
            batch_size=1024,
            client_name="passive_client",
        ),
    ],
)
batch_id
Batch complete
Now you can view the VFL training metrics and start making predictions.
batch_sessions is a dict keyed by session objects that stores task group context, data config, and model config for each session.
# print(batch_sessions)
print("batch_id", batch_id)
for s in batch_sessions.keys():
    print(s.id, s.name, s.status)
    # s.metrics().plot()
You can also query the API for a previously executed batch. Specify the batch_id of the batch you want to review.
for s in client.get_batch_sessions(batch_id):
    print(s.id, s.name, s.status)
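For larger batches, a per-status count can be easier to read than one line per session. The sketch below is a minimal example that relies only on the status attribute used above. The session objects are stand-ins built with SimpleNamespace, and the status strings are hypothetical placeholders, not values confirmed by the integrate.ai API.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_statuses(sessions):
    """Count sessions by their status attribute."""
    return Counter(s.status for s in sessions)


# Stand-in session objects; the status values are hypothetical.
sessions = [
    SimpleNamespace(id="s1", name="Testing VFL batch 1", status="completed"),
    SimpleNamespace(id="s2", name="Testing VFL batch 2", status="completed"),
    SimpleNamespace(id="s3", name="Testing VFL batch 3", status="running"),
]

summary = summarize_statuses(sessions)
for status, count in summary.items():
    print(status, count)
```

With a real client, the same function could be given the result of client.get_batch_sessions(batch_id) in place of the stand-in list.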