VFL Prediction Session Example

This example continues the workflow in the previous sections: PRL Session Example and VFL Training Session Example.

Create and start a VFL prediction session

To create a VFL prediction session, specify the PRL session ID (prl_session_id) and the VFL training session ID (training_session_id) from your previous successful PRL and VFL training sessions.

Set the vfl_mode to predict.

# Example configuration of a VFL predict session

vfl_predict_session = client.create_vfl_session(
    name="Testing notebook - VFL Predict",
    description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
    prl_session_id=prl_session.id,
    training_session_id=vfl_train_session.id,
    vfl_mode="predict",
    data_config=data_config
).start()

# Display the new session ID
vfl_predict_session.id

Specify the full path for the storage location for your predictions, including the file name.

Note: by default, the task runner creates a bucket for you to upload data into (e.g. s3://{aws_taskrunner_profile}-{aws_taskrunner_name}.integrate.ai). Only the default S3 bucket and other buckets ending in *integrate.ai are supported. If you are not using the default bucket created when the task runner was provisioned, ensure that your data is hosted in an S3 bucket with a URL ending in *integrate.ai; otherwise, the data will not be accessible to the task runner.
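
The snippets below reference base_aws_bucket. As a minimal sketch, and assuming the aws_taskrunner_profile and aws_taskrunner_name values from the earlier sections are still in scope, the default bucket name can be composed to match the pattern above:

# Minimal sketch: compose the default bucket name created by the task runner
# (assumes aws_taskrunner_profile and aws_taskrunner_name are still in scope
# from the earlier sections; both are set when the task runner is provisioned)
base_aws_bucket = f"{aws_taskrunner_profile}-{aws_taskrunner_name}.integrate.ai"
print(base_aws_bucket)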

# Where to store VFL predictions - must be full path and file name
vfl_predict_active_storage_path = f's3://{base_aws_bucket}/vfl_predict/active_predictions'
vfl_predict_passive_storage_path = f's3://{base_aws_bucket}/vfl_predict/passive_predictions'

Create and start a task group for the prediction session

Create a task in the task group for each client. The number of client tasks in the task group must match the number of clients specified in the data_config used to create the PRL and VFL train sessions.
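
As a rough sanity check (this is a hedged sketch: it assumes data_config is a dictionary keyed by client name, as set up in the VFL Training Session Example - adjust it if your configuration is structured differently), you can confirm that you are adding one task per configured client:

# Hedged sanity check: one vfl_predict task is added below for each of these clients.
# Assumes data_config is keyed by client name; see the VFL Training Session Example
# for the exact structure used when the sessions were created.
client_names = ["active_client", "passive_client"]
assert set(client_names) == set(data_config.keys()), "client tasks must match data_config"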

# Build the per-client output paths for this prediction session
active_storage_path = f"{vfl_predict_active_storage_path}_{vfl_predict_session.id}.csv"
passive_storage_path = f"{vfl_predict_passive_storage_path}_{vfl_predict_session.id}.csv"

vfl_predict_task_group_context = (
    SessionTaskGroup(vfl_predict_session)
    .add_task(iai_tb_aws.vfl_predict(
        client_name="active_client",
        dataset_path=active_test_path,
        raw_output=True,
        batch_size=1024,
        storage_path=active_storage_path))
    .add_task(iai_tb_aws.vfl_predict(
        client_name="passive_client",
        dataset_path=passive_test_path,
        batch_size=1024,
        raw_output=True,
        storage_path=passive_storage_path))
    .start()
)

The following parameters are required for each client task:

  • client_name - must be the same as the client name specified in the PRL and VFL train sessions

  • dataset_path - the name of a registered dataset

  • batch_size - the number of records to process per batch (1024 in the example above)

  • raw_output (bool, optional) - whether the raw model output should be saved. Defaults to False, in which case a transformation corresponding to the ML task is applied (see the sketch after this list).

  • storage_path - must be the full path and file name for that client's prediction output
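
As a point of reference, for a binary classification task the transformed output is a probability rather than a raw score. The sketch below is illustrative only and is not integrate.ai SDK code; it assumes a sigmoid transformation, which is what a binary classification task would typically apply:

# Illustration only (not SDK code): the kind of transformation implied by raw_output=False,
# assuming a binary classification task where a sigmoid is applied to the raw output.
import numpy as np

raw_score = 0.85                                  # what raw_output=True would save
probability = 1.0 / (1.0 + np.exp(-raw_score))    # what raw_output=False would save
print(raw_score, round(probability, 4))           # 0.85 0.7006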

Monitor submitted VFL prediction jobs

Each task in the task group kicks off a job in AWS Batch. You can monitor the jobs through the notebook, as shown.

# Example of monitoring tasks
import json

for i in vfl_predict_task_group_context.contexts.values():
    print(json.dumps(i.status(), indent=4))

vfl_predict_task_group_context.monitor_task_logs()

# Wait for the tasks to complete (success = True)

vfl_predict_task_group_context.wait(60*5, 2)
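
If the group does not finish cleanly, you can re-inspect the individual task statuses with the same calls used above. A brief sketch, assuming wait() returns True only once every task in the group has completed successfully (as the comment above indicates):

# Hedged follow-up: assumes wait() returns True only when all tasks succeeded.
all_succeeded = vfl_predict_task_group_context.wait(60*5, 2)
if not all_succeeded:
    for ctx in vfl_predict_task_group_context.contexts.values():
        print(json.dumps(ctx.status(), indent=4))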

View VFL predictions

After the predict session completes successfully, you can view the predictions from the Active party and evaluate the performance.

# Retrieve the metrics
import pandas as pd

metrics = vfl_predict_session.metrics().as_dict()
metrics
presigned_result_urls = vfl_predict_session.prediction_result()

print(active_storage_path)
df_pred = pd.read_csv(presigned_result_urls.get(active_storage_path))

df_pred.head()
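
If the active party's test labels are available, you can also score the predictions directly. The sketch below is hedged: it assumes the test set at active_test_path is readable with pandas (e.g. s3fs is installed for S3 paths), that its label column is named "y", and that the last column of the prediction file holds the prediction - adjust these assumptions to your data.

# Hedged evaluation sketch - the label column name and prediction column position
# below are assumptions, not part of the integrate.ai output contract.
from sklearn.metrics import roc_auc_score

df_labels = pd.read_csv(active_test_path)    # assumes the test set is readable here
y_true = df_labels["y"]                      # assumed label column name
y_score = df_pred.iloc[:, -1]                # assumed prediction column position
print("AUC:", roc_auc_score(y_true, y_score))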

Example output:

VFL prediction results
