VFL Predictions

After you have completed a successful PRL session and a VFL training session, you can use these two sessions to create a prediction session.

In the example, the sessions run in sequence, so the session IDs for the PRL and VFL train sessions are readily available to the predict session. If you instead run a PRL session and want to reuse it later in a different VFL session, save the session ID (prl_session.id) and provide it directly in the predict session setup instead of relying on the variable. The three sessions must use the same client_name and datasets in order to run successfully.
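
If the PRL and VFL train sessions are created in an earlier run, one way to keep their IDs is to write them to a small JSON file. The following is a minimal sketch only: the helper names save_session_ids and load_session_ids and the file layout are hypothetical and are not part of the integrate.ai SDK.

```python
import json
from pathlib import Path


def save_session_ids(path, prl_session_id, training_session_id):
    """Write the two session IDs to a JSON file for later reuse (hypothetical helper)."""
    payload = {
        "prl_session_id": str(prl_session_id),
        "training_session_id": str(training_session_id),
    }
    Path(path).write_text(json.dumps(payload, indent=2))


def load_session_ids(path):
    """Read the session IDs back; raises KeyError if either one is missing."""
    payload = json.loads(Path(path).read_text())
    return payload["prl_session_id"], payload["training_session_id"]
```

For example, you might call save_session_ids with prl_session1.id and fl_train_session1.id right after those sessions complete, and later pass the loaded values as prl_session_id and training_session_id when creating the predict session.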

To create a VFL prediction session:

  1. Specify the PRL session ID (prl_session_id)

  2. Specify the VFL train session ID (training_session_id) from your previous successful PRL and VFL sessions.

  3. Set the vfl_mode to predict.

# Create and start (load) a VFL-GLM prediction session

vfl_predict_session = client.create_vfl_session(
    name="Testing notebook - VFL Predict",
    description="I am testing VFL Predict session creation with an AWS task runner through a notebook",
    prl_session_id=prl_session1.id,
    training_session_id=fl_train_session1.id,
    vfl_mode="predict",
    data_config=data_config
).start()

vfl_predict_session.id  # Prints the session ID for reference

Create and start VFL Predict session

Create and start a task group with one task for each of the clients joining the session.

# Create the task builder
from integrate_ai_sdk.taskgroup.taskbuilder.integrate_ai import IntegrateAiTaskBuilder
from integrate_ai_sdk.taskgroup.base import SessionTaskGroup

iai_tb_aws_consumer = IntegrateAiTaskBuilder(client=client, task_runner_id="")
iai_tb_aws_provider = IntegrateAiTaskBuilder(client=client, task_runner_id="")

# Create and start the task group (one task per client)
vfl_predict_task_group_context = (
    SessionTaskGroup(vfl_predict_session)
    .add_task(iai_tb_aws_consumer.vfl_predict(
        client_name=consumer_train_name,
        dataset_name=consumer_train_name,
        raw_output=True,
        batch_size=1024))
    .add_task(iai_tb_aws_provider.vfl_predict(
        client_name=acme_data,
        dataset_name=provider_train_name,
        batch_size=1024,
        raw_output=True))
    .start()
)

The following parameters are used for each client task:

  • client_name - must be the same as the client_name specified in the PRL and VFL train sessions

  • dataset_name - the name of a registered dataset

  • batch_size - the number of rows processed per batch. Defaults to 1024; increase it to process larger batches.

  • raw_output (bool, optional) - whether the raw model output should be saved. Defaults to False, in which case a transformation corresponding to the ml_task is applied.
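
With raw_output=True, as in the example, any transformation is left to you. The source does not say which transformation each ml_task uses, so the following sketch rests on an assumption: for a logistic (binary classification) GLM, the raw output is the linear predictor, and a sigmoid maps it to a probability. The function name and sample values are illustrative only.

```python
import math


def sigmoid(raw_value):
    """Map a raw linear-predictor value to a probability between 0 and 1."""
    # Assumption: the model is a logistic GLM, so the sigmoid is the matching transform.
    # Branching on the sign avoids overflow in math.exp for large magnitudes.
    if raw_value >= 0:
        return 1.0 / (1.0 + math.exp(-raw_value))
    exp_value = math.exp(raw_value)
    return exp_value / (1.0 + exp_value)


# Illustrative raw outputs, not real session results
raw_outputs = [-2.0, 0.0, 2.0]
probabilities = [sigmoid(value) for value in raw_outputs]
```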

Sessions may take some time to run depending on the compute environment. You can check the session status in the workspace, on the Sessions page. Wait for the status to update to Completed.
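
If you prefer to wait in the notebook rather than watch the Sessions page, a generic polling loop is one option. This is a sketch under stated assumptions: get_status is a hypothetical callable that you supply to return the current session status as a string, and the "completed" and "failed" status values are assumed. The source does not describe how the SDK exposes session status.

```python
import time


def wait_for_completion(get_status, poll_seconds=30.0, timeout_seconds=3600.0,
                        done_states=("completed",), failed_states=("failed",)):
    """Poll get_status() until it reports a done state; raise on failure or timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        # Compare case-insensitively, since the workspace shows "Completed"
        status = str(get_status()).lower()
        if status in done_states:
            return status
        if status in failed_states:
            raise RuntimeError(f"Session ended with status: {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Session still '{status}' after {timeout_seconds} seconds")
        time.sleep(poll_seconds)
```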

View VFL predictions

After the predict session completes successfully, you can view the predictions from the Active party and evaluate the performance.

# Retrieve the metrics
metrics = vfl_predict_session.metrics().as_dict()
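metrics

# Retrieve the presigned URLs for the prediction results
import pandas as pd
presigned_result_urls = vfl_predict_session.prediction_result()

# Load the Active party's predictions (vfl_predict_active_storage_path must already be defined)
df_pred = pd.read_csv(presigned_result_urls.get(vfl_predict_active_storage_path))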

df_pred.head()
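
Because .get() returns None when the storage path is not among the results, pd.read_csv can fail with an unhelpful error. A small lookup helper gives a clearer message. This is a sketch: resolve_result_url is a hypothetical name, and it assumes that prediction_result() returns a dict-like mapping from storage path to URL, as the .get() call above suggests.

```python
def resolve_result_url(result_urls, storage_path):
    """Return the URL stored under storage_path, or raise a descriptive KeyError."""
    url = result_urls.get(storage_path)
    if url is None:
        # List the keys that are present to make a mistyped path easy to spot
        available = ", ".join(sorted(str(key) for key in result_urls)) or "(none)"
        raise KeyError(
            f"No prediction result for {storage_path!r}; available: {available}"
        )
    return url
```

The returned URL can then be passed to pd.read_csv in place of the direct .get() call.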