VFL SplitNN Model Training

VFL Training Session Example

Use the integrateai_taskrunner_AWS.ipynb notebook to follow along and test the examples shown below by filling in your own variables as required. You can download sample notebooks here.

The notebook demonstrates a sequential workflow for the PRL session and the VFL train and predict sessions.

To create a VFL training session, specify the prl_session_id of the PRL session you just ran to link the datasets together, and set vfl_mode to train.

  1. Ensure that you have run a PRL session to obtain the aligned dataset. The PRL session ID is required for the VFL training session.

  2. Create a model_config and a data_config for the VFL session.

model_config = {
    "strategy": {"name": "SplitNN", "params": {}},
    "model": {
        "feature_models": {
            "passive_client": {"params": {"input_size": 7, "hidden_layer_sizes": [6], "output_size": 5}},
            "active_client": {"params": {"input_size": 8, "hidden_layer_sizes": [6], "output_size": 5}},
        },
        "label_model": {"params": {"hidden_layer_sizes": [5], "output_size": 2}},
    },
    "ml_task": {
        "type": "classification",
        "params": {
            "loss_weights": None,
        },
    },
    "optimizer": {"name": "SGD", "params": {"learning_rate": 0.2, "momentum": 0.0}},
    "seed": 23,  # for reproducibility
}

data_config = {
    "passive_client": {
        "label_client": False,
        "predictors": ["x1", "x3", "x5", "x7", "x9", "x11", "x13"],
        "target": None,
    },
    "active_client": {
        "label_client": True,
        "predictors": ["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14"],
        "target": "y",
    },
}

Parameters:

  • strategy: Specify the name and parameters. For VFL, the strategy is SplitNN.

  • model: Specify the feature_models and label_model.

    • feature_models refers to the part of the model that transforms the raw input features into intermediate encoded columns (usually hosted by both parties).

    • label_model refers to the part of the model that connects the intermediate encoded columns to the target variable (usually hosted by the active party).

  • ml_task: Specify the type of machine learning task, and any associated parameters. Options are classification or regression.

  • optimizer: Specify any optimizer supported by PyTorch.

  • seed: Specify an integer to seed random number generation for reproducible results.
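In the example configs above, each client's input_size matches the number of predictors that client supplies (7 for passive_client, 8 for active_client), and exactly one client holds the label. A minimal standalone sketch (the check_configs helper is illustrative, not part of the SDK) that verifies this consistency before creating a session:

```python
# Configs reproduced from above, abbreviated to the fields being checked.
model_config = {
    "model": {
        "feature_models": {
            "passive_client": {"params": {"input_size": 7, "hidden_layer_sizes": [6], "output_size": 5}},
            "active_client": {"params": {"input_size": 8, "hidden_layer_sizes": [6], "output_size": 5}},
        },
    },
}

data_config = {
    "passive_client": {"label_client": False, "predictors": ["x1", "x3", "x5", "x7", "x9", "x11", "x13"], "target": None},
    "active_client": {"label_client": True, "predictors": ["x0", "x2", "x4", "x6", "x8", "x10", "x12", "x14"], "target": "y"},
}

def check_configs(model_config, data_config):
    """Verify each client's input_size matches its predictor count,
    and that exactly one client is the label client. Returns its name."""
    for client, cfg in data_config.items():
        input_size = model_config["model"]["feature_models"][client]["params"]["input_size"]
        if input_size != len(cfg["predictors"]):
            raise ValueError(f"{client}: input_size {input_size} != {len(cfg['predictors'])} predictors")
    label_clients = [c for c, cfg in data_config.items() if cfg["label_client"]]
    if len(label_clients) != 1:
        raise ValueError("exactly one client must be the label client")
    return label_clients[0]

print(check_configs(model_config, data_config))  # active_client
```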

Create and start a VFL training session

Specify the PRL session ID of a successful PRL session that used the same active and passive client names.

Ensure that the vfl_mode is set to train.

vfl_train_session = client.create_vfl_session(
    name="Testing notebook - VFL Train",
    description="I am testing VFL Train session creation with a task runner through a notebook",
    prl_session_id=prl_session.id,
    vfl_mode='train',
    min_num_clients=2,
    num_rounds=2,
    package_name="iai_ffnet",
    data_config=data_config,
    model_config=model_config
).start()

vfl_train_session.id

Set up the task builder and task group for VFL training

Create a task in the task group for each client. The number of client tasks in the task group must match the number of clients specified in the data_config used to create the session.

vfl_task_group_context = (
    SessionTaskGroup(vfl_train_session)
    .add_task(iai_tb_aws.vfl_train(
        train_dataset_name="active_train",
        test_dataset_name="active_test",
        batch_size=16384 * 32,
        eval_batch_size=5000000,
        job_timeout_seconds=28800,
        client_name="active_client",
        storage_path=storage_path,
    ))
    .add_task(iai_tb_aws.vfl_train(
        train_dataset_name="passive_train",
        test_dataset_name="passive_test",
        batch_size=16384 * 32,
        eval_batch_size=5000000,
        job_timeout_seconds=28800,
        client_name="passive_client",
        storage_path=storage_path,
    ))
    .start()
)

The following parameters are required for each client task:

  • train_dataset_name OR train_path

  • test_dataset_name OR test_path

  • batch_size: specify a value for batching the train dataset during training. The default value of 1024 is meant for use only with small datasets (~100MB).

  • eval_batch_size: specify a value for batching the test dataset during training with large datasets.

  • job_timeout_seconds: specify a value that corresponds to the amount of time the training session is estimated to take. For large datasets, the default value of 7200 must be increased.

  • client_name: must match the client_name specified in the PRL session used to determine the overlap.

  • storage_path: the storage location for the trained model.
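When sizing batch_size and job_timeout_seconds for a large dataset, a quick back-of-envelope estimate of batches per epoch can help. A plain-Python sketch using the batch_size from the task group above (the row count is a hypothetical example, not taken from this session):

```python
import math

train_rows = 50_000_000   # hypothetical row count for a large aligned dataset
batch_size = 16384 * 32   # the value used in the task group above

batches_per_epoch = math.ceil(train_rows / batch_size)
print(batch_size)         # 524288
print(batches_per_epoch)  # 96
```

Fewer, larger batches generally mean less per-batch overhead per epoch, but each batch needs more memory; scale job_timeout_seconds with the expected number of batches and rounds.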

Monitor submitted VFL training jobs

Each task in the task group kicks off a job in AWS Batch. You can monitor the jobs through the notebook, as shown.

import json

# Check the status of the tasks
for i in vfl_task_group_context.contexts.values():
    print(json.dumps(i.status(), indent=4))

vfl_task_group_context.monitor_task_logs()

# Wait for the tasks to complete (success = True)
vfl_task_group_context.wait(60 * 5, 2)

View the VFL training metrics

Once the session completes successfully, you can view the training metrics.

metrics = vfl_train_session.metrics().as_dict()
metrics

# Example results

{'session_id': '498beb7e6a',
 'federated_metrics': [{'loss': 0.6927943530912943},
  {'loss': 0.6925891094472265},
  {'loss': 0.6921983339753467},
  {'loss': 0.6920029462394067},
  {'loss': 0.6915351291650617}],
 'client_metrics': [{'user@integrate.ai:79704ac8c1a7416aa381288cbab16e6a': {'test_roc_auc': 0.5286237121001411,
    'test_num_examples': 3245,
    'test_loss': 0.6927943530912943,
    'test_accuracy': 0.5010785824345146}},
  {'user@integrate.ai:79704ac8c1a7416aa381288cbab16e6a': {'test_num_examples': 3245,
    'test_accuracy': 0.537442218798151,
    'test_roc_auc': 0.5730010669487545,
    'test_loss': 0.6925891094472265}},
  {'user@integrate.ai:79704ac8c1a7416aa381288cbab16e6a': {'test_accuracy': 0.550693374422188,
    'test_roc_auc': 0.6073282812853845,
    'test_loss': 0.6921983339753467,
    'test_num_examples': 3245}},
  {'user@integrate.ai:79704ac8c1a7416aa381288cbab16e6a': {'test_loss': 0.6920029462394067,
    'test_roc_auc': 0.6330078151716465,
    'test_accuracy': 0.5106317411402157,
    'test_num_examples': 3245}},
  {'user@integrate.ai:79704ac8c1a7416aa381288cbab16e6a': {'test_roc_auc': 0.6495852274713467,
    'test_loss': 0.6915351291650617,
    'test_accuracy': 0.5232665639445301,
    'test_num_examples': 3245}}]}
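From the metrics dict, the per-round federated loss can be pulled out for a quick sanity check that loss is decreasing across rounds. A small sketch using the example results above (abbreviated to the federated metrics):

```python
# Example results copied from above.
metrics = {
    "session_id": "498beb7e6a",
    "federated_metrics": [
        {"loss": 0.6927943530912943},
        {"loss": 0.6925891094472265},
        {"loss": 0.6921983339753467},
        {"loss": 0.6920029462394067},
        {"loss": 0.6915351291650617},
    ],
}

losses = [round_metrics["loss"] for round_metrics in metrics["federated_metrics"]]
is_decreasing = all(a > b for a, b in zip(losses, losses[1:]))
print(is_decreasing)  # True
```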

Plot the VFL training metrics.

fig = vfl_train_session.metrics().plot()

Example of plotted VFL training metrics (figure).

Back to VFL Model Training overview.