VFL Prediction Session Example¶
This example continues the workflow in the previous sections: PRL Session Example and VFL Training Session Example.
Create and start a VFL prediction session¶
To create a VFL prediction session, specify the PRL session ID (prl_session_id) and the VFL train session ID (training_session_id) from your previous successful PRL and VFL training sessions. Set the vfl_mode to predict.
# Example configuration of a VFL predict session
vfl_predict_session = client.create_vfl_session(
    name="Testing notebook - VFL Predict",
    description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
    prl_session_id=prl_session.id,
    training_session_id=vfl_train_session.id,
    vfl_mode="predict",
    data_config=data_config
).start()
vfl_predict_session.id
Specify the full path for the storage location for your predictions, including the file name.
Note: by default, the task runner creates a bucket for you to upload data into (e.g. s3://{aws_taskrunner_profile}-{aws_taskrunner_name}.integrate.ai). Only the default S3 bucket and other buckets ending in *integrate.ai are supported. If you are not using the default bucket created when the task runner was provisioned, ensure that your data is hosted in an S3 bucket with a URL ending in *integrate.ai; otherwise, the data will not be accessible to the task runner.
#Where to store VFL predictions - must be full path and file name
vfl_predict_active_storage_path = f's3://{base_aws_bucket}/vfl_predict/active_predictions'
vfl_predict_passive_storage_path = f's3://{base_aws_bucket}/vfl_predict/passive_predictions'
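If you point at a bucket other than the default one, a quick sanity check like the one below can catch an unsupported bucket name before you start the session. This is an illustrative sketch only; it assumes base_aws_bucket holds the bucket name (without the s3:// prefix), as used above.
# Illustrative sanity check: the task runner can only read buckets whose names
# end in "integrate.ai" (including the default bucket it created).
# Assumes base_aws_bucket holds the bucket name, without the "s3://" prefix.
assert base_aws_bucket.endswith("integrate.ai"), (
    f"Bucket '{base_aws_bucket}' may not be accessible to the task runner; "
    "use the default bucket or one ending in 'integrate.ai'."
)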
Create and start a task group for the prediction session¶
Create a task in the task group for each client. The number of client tasks in the task group must match the number of clients specified in the data_config
used to create the PRL and VFL train sessions.
# Build the full output path (including file name) for each client
active_storage_path = f"{vfl_predict_active_storage_path}_{vfl_predict_session.id}.csv"
passive_storage_path = f"{vfl_predict_passive_storage_path}_{vfl_predict_session.id}.csv"

vfl_predict_task_group_context = (SessionTaskGroup(vfl_predict_session)
    .add_task(iai_tb_aws.vfl_predict(
        client_name="active_client",
        dataset_path=active_test_path,
        raw_output=True,
        batch_size=1024,
        storage_path=active_storage_path))
    .add_task(iai_tb_aws.vfl_predict(
        client_name="passive_client",
        dataset_path=passive_test_path,
        raw_output=True,
        batch_size=1024,
        storage_path=passive_storage_path))
    .start())
Each client task takes the following parameters:
- client_name - must be the same as the client name specified in the PRL and VFL train sessions
- dataset_path - the name of a registered dataset
- batch_size - set to a default value
- raw_output (bool, optional) - whether the raw model output should be saved. Defaults to False, in which case a transformation corresponding to the ML task is applied.
- storage_path - must be the full path and file name for the predictions
Monitor submitted VFL prediction jobs¶
Each task in the task group kicks off a job in AWS Batch. You can monitor the jobs through the notebook, as shown.
# Example of monitoring tasks
for i in vfl_predict_task_group_context.contexts.values():
    print(json.dumps(i.status(), indent=4))
vfl_predict_task_group_context.monitor_task_logs()
# Wait for the tasks to complete (success = True)
vfl_predict_task_group_context.wait(60*5, 2)
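The comment above indicates that wait() returns True once all tasks succeed. If you are running this flow as a script rather than interactively, you can act on that return value; a minimal sketch, reusing the same timeout values:
# Minimal sketch: stop early if the prediction tasks did not complete successfully.
# Assumes wait() returns True on success, as noted in the comment above.
succeeded = vfl_predict_task_group_context.wait(60*5, 2)
if not succeeded:
    raise RuntimeError("VFL prediction tasks did not complete successfully; check the task logs above.")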
View VFL predictions¶
After the predict session completes successfully, you can view the predictions from the Active party and evaluate the performance.
# Retrieve the metrics
metrics = vfl_predict_session.metrics().as_dict()
metrics
presigned_result_urls = vfl_predict_session.prediction_result()
print(active_storage_path)
df_pred = pd.read_csv(presigned_result_urls.get(active_storage_path))
df_pred.head()
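If the Active party's test dataset includes ground-truth labels, you can also score the saved predictions locally. The following is an illustrative sketch for a binary-classification task; the label column name ("y") and prediction column name ("pred") are assumptions, not names defined by this example, so replace them with the columns actually present in your test data and prediction file.
# Illustrative only: score predictions against ground-truth labels.
# "y" (label column) and "pred" (prediction column) are assumed names;
# adjust them to match your test dataset and prediction output.
from sklearn.metrics import roc_auc_score

df_test = pd.read_csv(active_test_path)  # assumes the test set is readable with pandas (e.g. via s3fs for S3 paths)
auc = roc_auc_score(df_test["y"], df_pred["pred"])
print(f"Prediction AUC: {auc:.4f}")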
Back to VFL Model Training overview