Deploying a multi-framework model
Within Verta, a "Model" can be any function: a traditional ML model (e.g., scikit-learn, PyTorch, TensorFlow); a plain function (e.g., squaring a number or making a DB call); or a mixture of the above (e.g., pre-processing code, a DB call, and then a model application).
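As a minimal sketch of the "plain function" case, the class below squares numbers. It uses the same two methods as the `CensusTwoStep` class in this tutorial, `__init__(self, artifacts)` and `predict(self, batch_input)`. The name `SquareModel` is hypothetical, and the `VertaModelBase` base class it would extend for deployment is left out so the snippet runs without `verta` installed.

```python
class SquareModel:
    """Hypothetical function-only "model": it has no trained weights.

    For deployment on Verta it would extend verta.registry.VertaModelBase;
    the base class is left out here so the sketch runs without verta installed.
    """

    def __init__(self, artifacts):
        # A pure function needs no artifacts, so there is nothing to load.
        self.artifacts = artifacts

    def predict(self, batch_input):
        # Square each number in the batch; a plain list is JSON-friendly.
        return [x * x for x in batch_input]


model = SquareModel(artifacts=None)
print(model.predict([1, 2, 3]))  # [1, 4, 9]
```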
This tutorial shows how to deploy a model that uses multiple frameworks on Verta. Here, we consider a model that combines scikit-learn and XGBoost.
The key concept in Verta for model deployment is an Endpoint: a URL where a deployed model becomes available for use. Deploying a model is therefore a two-step process:
1. Create an endpoint.
2. Update the endpoint with a model.
We'll look at these in turn.
```python
# Assumes `client` is a connected verta.Client, e.g. client = verta.Client()
endpoint = client.create_endpoint(path="/some-path")
```
To deploy a model that uses multiple frameworks, we wrap its logic in a class that extends VertaModelBase.
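```python
from verta.registry import VertaModelBase


class CensusTwoStep(VertaModelBase):
    def __init__(self, artifacts):
        import cloudpickle

        # `artifacts` maps each key passed to create_standard_model() to a file path
        with open(artifacts["hpw_model"], "rb") as f:
            self.hpw_model = cloudpickle.load(f)  # first stage: scikit-learn model
        with open(artifacts["income_model"], "rb") as f:
            self.income_model = cloudpickle.load(f)  # second stage: XGBoost model

    def predict(self, batch_input):
        import numpy as np

        results = []
        for one_input in batch_input:
            # Step 1: scikit-learn returns a 1-D array with one prediction per row
            hpw = self.hpw_model.predict(one_input)
            # Append those predictions as an extra feature column for XGBoost
            features = np.concatenate(
                (np.array(one_input), np.reshape(hpw, (-1, 1))), axis=1
            )
            # Step 2: convert the NumPy output to a list so it can be returned as JSON
            results.append(self.income_model.predict(features).tolist())
        return results
```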
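At deployment time, the `artifacts` argument of `__init__` maps each key given in `artifacts=` (when the model version is created) to a file path, which is why the keys read in `__init__` must match those keys exactly. The sketch below exercises that contract locally: it pickles stand-in objects to temporary files and loads one back by key. `StubModel` and `save_artifacts` are hypothetical helpers, and the standard `pickle` module stands in for Verta's own serialization; `cloudpickle.load` is an alias of `pickle.load`, so it reads such files too.

```python
import os
import pickle
import tempfile


class StubModel:
    """Hypothetical stand-in for a trained model with a predict() method."""

    def predict(self, rows):
        return [sum(row) for row in rows]


def save_artifacts(objects, directory):
    """Hypothetical helper: pickle each object to its own file, return {key: path}."""
    paths = {}
    for key, obj in objects.items():
        path = os.path.join(directory, key + ".pkl")
        with open(path, "wb") as f:
            pickle.dump(obj, f)
        paths[key] = path
    return paths


with tempfile.TemporaryDirectory() as tmp:
    # Use the same keys as the artifacts= dict passed to create_standard_model().
    artifacts = save_artifacts(
        {"hpw_model": StubModel(), "income_model": StubModel()}, tmp
    )
    # Load one artifact back by key, opening the file path as __init__ does.
    with open(artifacts["hpw_model"], "rb") as f:
        hpw_model = pickle.load(f)

print(sorted(artifacts))                    # ['hpw_model', 'income_model']
print(hpw_model.predict([[1, 2], [3, 4]]))  # [3, 7]
```

A mismatched key (for example looking up `"sklearn_model"` when the artifact was saved as `"hpw_model"`) surfaces here as a `KeyError` instead of a failed deployment.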
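Note that different frameworks often expect inputs and produce outputs in different formats, and your class must convert between them. Here, the scikit-learn model returns a 1-D array that has to be reshaped into a column before it can be appended to the XGBoost model's input. Predictions must also be JSON-serializable, so NumPy arrays should be converted to lists before they are returned.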
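The conversions between the two stages can be checked in isolation with NumPy alone. In this sketch the feature values and first-stage predictions are made-up numbers standing in for real model output; it shows why the 1-D prediction array is reshaped before `np.concatenate(..., axis=1)` and why `.tolist()` is needed before the result can be serialized as JSON.

```python
import json

import numpy as np

# Hypothetical input: one batch element of 5 rows x 3 features,
# shaped like X_train_hpw.values.tolist()[:5]
one_input = [[float(r + c) for c in range(3)] for r in range(5)]

# Made-up first-stage output: scikit-learn's predict() returns shape (n_rows,)
hpw_pred = np.array([40.0, 38.5, 45.0, 20.0, 50.0])

# np.concatenate(..., axis=1) needs both arrays to be 2-D,
# so reshape the predictions into a (5, 1) column first.
features = np.concatenate(
    (np.array(one_input), np.reshape(hpw_pred, (-1, 1))), axis=1
)
print(features.shape)  # (5, 4)

# json.dumps() rejects NumPy arrays; .tolist() yields plain Python floats.
try:
    json.dumps(features)
except TypeError:
    print("ndarray is not JSON-serializable")
payload = json.dumps(features.tolist())
```

`np.column_stack` would accept the 1-D array directly; the explicit reshape is kept here because it mirrors the tutorial's class.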
Once the class has been defined, we create a Registered Model Version with it and update the endpoint.
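```python
from verta.environment import Python

# hpw_model and income_model are the trained scikit-learn and XGBoost models, and
# X_train_hpw is the training feature DataFrame, all from earlier in the workflow.
registered_model = client.get_or_create_registered_model(name="census-multiple")

model_version = registered_model.create_standard_model(
    model_cls=CensusTwoStep,
    environment=Python(requirements=["scikit-learn", "xgboost"]),
    # These keys must match the ones CensusTwoStep.__init__ reads from `artifacts`
    artifacts={
        "hpw_model": hpw_model,
        "income_model": income_model,
    },
    name="v6",
)

census_multiple_endpoint = client.get_or_create_endpoint("census-multiple")
census_multiple_endpoint.update(model_version, wait=True)  # wait for the update to finish

deployed_model = census_multiple_endpoint.get_deployed_model()
# One batch element holding five rows, so the result holds one list of predictions
deployed_model.predict([X_train_hpw.values.tolist()[:5]])
```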