VFL Predictions
After you have completed a successful PRL session and a VFL training session, you can use these two sessions to create a prediction session.
In the example, the sessions are run in sequence, so the session IDs for the PRL and VFL train sessions are readily available to use in the predict session. If you instead run a PRL session and want to reuse it later in a different VFL session, save the session ID (prl_session.id). You can then provide the saved session ID directly in the predict session setup instead of relying on the variable. All three sessions (PRL, VFL train, and VFL predict) must use the same client_name and datasets to run successfully.
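When the sessions are not run in sequence, the IDs have to be stored somewhere between runs. The following is a minimal sketch, assuming a local JSON file is acceptable storage. The file name and the helpers `save_session_ids` and `load_session_ids` are hypothetical and are not part of the SDK.

```python
import json
from pathlib import Path

# Hypothetical file name; any durable location works.
SESSION_FILE = Path("vfl_session_ids.json")


def save_session_ids(prl_session_id, training_session_id, path=SESSION_FILE):
    """Write the two session IDs to a JSON file for later reuse."""
    payload = {
        "prl_session_id": str(prl_session_id),
        "training_session_id": str(training_session_id),
    }
    Path(path).write_text(json.dumps(payload, indent=2))


def load_session_ids(path=SESSION_FILE):
    """Read the IDs back as (prl_session_id, training_session_id)."""
    payload = json.loads(Path(path).read_text())
    return payload["prl_session_id"], payload["training_session_id"]
```

After the train session finishes, you could call `save_session_ids(prl_session1.id, fl_train_session1.id)`, and later pass the loaded values as `prl_session_id` and `training_session_id` when creating the predict session.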
To create a VFL prediction session:
1. Specify the PRL session ID (prl_session_id) from your previous successful PRL session.
2. Specify the VFL train session ID (training_session_id) from your previous successful VFL train session.
3. Set the vfl_mode to predict.
# Create and start (load) a VFL-GLM prediction session
vfl_predict_session = client.create_vfl_session(
    name="Testing notebook - VFL Predict",
    description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
    prl_session_id=prl_session1.id,
    training_session_id=fl_train_session1.id,
    vfl_mode="predict",
    data_config=data_config
).start()

vfl_predict_session.id  # Prints the session ID for reference
Create and start VFL Predict session
Create and start a task group with one task for each of the clients joining the session.
# Create the task builders, one per task runner
from integrate_ai_sdk.taskgroup.taskbuilder.integrate_ai import IntegrateAiTaskBuilder
from integrate_ai_sdk.taskgroup.base import SessionTaskGroup

iai_tb_aws_consumer = IntegrateAiTaskBuilder(client=client, task_runner_id="")
iai_tb_aws_provider = IntegrateAiTaskBuilder(client=client, task_runner_id="")

# Create and start the task group, with one prediction task per client
vfl_predict_task_group_context = (
    SessionTaskGroup(vfl_predict_session)
    .add_task(iai_tb_aws_consumer.vfl_predict(
        client_name=consumer_train_name,
        dataset_name=consumer_train_name,
        raw_output=True,
        batch_size=1024))
    .add_task(iai_tb_aws_provider.vfl_predict(
        client_name=acme_data,
        dataset_name=provider_train_name,
        batch_size=1024,
        raw_output=True))
).start()
Each client task takes the following parameters:
client_name - must be the same as the client_name specified in the PRL and VFL train sessions
dataset_name - the variable name of a registered dataset
batch_size - set to 1024 by default, but can be increased for larger batches
raw_output (bool, optional) - whether the raw model output should be saved. Defaults to False, in which case a transformation corresponding to the ml_task is applied.
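This page does not state which transformation is applied for each ml_task. As an illustration only, the sketch below assumes a binary classification task with a logit link, where the raw output is a linear predictor and the transformed output is a probability. The `sigmoid` helper and the sample values are hypothetical and do not come from the SDK.

```python
import math


def sigmoid(raw_value):
    """Map a raw linear-predictor value to a probability between 0 and 1."""
    # Numerically stable form: never calls exp() on a large positive number.
    if raw_value >= 0:
        return 1.0 / (1.0 + math.exp(-raw_value))
    exp_value = math.exp(raw_value)
    return exp_value / (1.0 + exp_value)


raw_outputs = [-2.0, 0.0, 2.0]  # made-up values for illustration
probabilities = [sigmoid(value) for value in raw_outputs]
```

Under that assumption, setting raw_output=True would leave you with values like `raw_outputs`, which you could transform yourself if you need probabilities later.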
Sessions may take some time to run depending on the compute environment. You can check the session status in the workspace, on the Sessions page. Wait for the status to update to Completed.
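If you prefer to wait from a notebook rather than watch the Sessions page, a generic polling loop can help. The sketch below is not an SDK feature: `get_status` is a placeholder for whatever callable returns the current session status in your environment, and the "Failed" status name is an assumption.

```python
import time


def wait_for_status(get_status, target="Completed", failed=("Failed",),
                    poll_seconds=30.0, timeout_seconds=3600.0,
                    sleep=time.sleep):
    """Poll get_status() until it returns target; raise on failure or timeout."""
    waited = 0.0
    while True:
        status = get_status()
        if status == target:
            return status
        if status in failed:
            raise RuntimeError(f"Session ended with status {status!r}")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Still {status!r} after {waited:.0f} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
```

The `sleep` argument is injectable so the loop can be exercised without real delays.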
View VFL predictions
After the predict session completes successfully, you can view the predictions from the Active party and evaluate the performance.
# Retrieve the metrics
metrics = vfl_predict_session.metrics().as_dict()
metrics

# Load the active party's predictions from the presigned result URL
import pandas as pd

presigned_result_urls = vfl_predict_session.prediction_result()
df_pred = pd.read_csv(presigned_result_urls.get(vfl_predict_active_storage_path))
df_pred.head()
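To evaluate performance, you can compare the predictions against known labels. The following is a minimal sketch for a binary task, assuming you have the true 0/1 labels and the predicted scores as plain Python lists. The `binary_accuracy` helper is hypothetical, and the column layout of df_pred is not described on this page.

```python
def binary_accuracy(labels, scores, threshold=0.5):
    """Fraction of rows whose thresholded score matches the 0/1 label."""
    if len(labels) != len(scores):
        raise ValueError("labels and scores must have the same length")
    if len(labels) == 0:
        raise ValueError("cannot compute accuracy on empty input")
    correct = sum(
        1
        for label, score in zip(labels, scores)
        if int(score >= threshold) == int(label)
    )
    return correct / len(labels)
```

You could pass columns converted with `.tolist()` from df_pred and from your own label data. The right threshold depends on whether the scores are raw outputs or transformed values.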