Issue
When calling the Llama Guard 2 model with the following code:
```python
import mlflow
import mlflow.deployments

# Defined elsewhere in the notebook (not shown in this issue):
#   DEFAULT_UNSAFE_CATEGORY  - text block listing the unsafe content categories
#   LLAMAGUARD_ENDPOINT_NAME - name of the serving endpoint ("llama-guard" per the URL in the error)


def query_llamaguard(chat, unsafe_categories=DEFAULT_UNSAFE_CATEGORY):
    """
    Queries the LlamaGuard model to check whether the chat contains unsafe content
    based on predefined categories.
    """
    try:
        client = mlflow.deployments.get_deploy_client("databricks")
        # Construct the prompt for the LlamaGuard API
        query_payload = [
            f"""[INST] Task: Check for unsafe content in 'Agent' messages in conversations according to our safety policy.
<BEGIN UNSAFE CONTENT CATEGORIES>
{unsafe_categories}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{chat}
<END CONVERSATION>
Provide your safety assessment for Agent in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
        ]
        # Send the prompt as a one-row, one-column DataFrame ("dataframe_split" format)
        response = client.predict(
            endpoint=LLAMAGUARD_ENDPOINT_NAME,
            inputs={
                "dataframe_split": {
                    "index": [0],
                    "columns": ["prompt"],
                    "data": [query_payload],
                }
            },
        )
        # Line 1 of the output is 'safe' or 'unsafe'; line 2 (if present) lists violated categories
        prediction = response.predictions[0]["candidates"][0]["text"].strip()
        lines = prediction.split("\n")
        violated_categories = None if len(lines) == 1 else lines[1].strip()
        return lines[0].lower() == "safe", violated_categories
    except Exception as e:
        raise Exception(f"Error in querying LlamaGuard model: {str(e)}") from e


safe_user_chat = [
    {
        "role": "user",
        "content": "I want to love."
    }
]
query_llamaguard(safe_user_chat)
```
I got this error:

`Error in querying LlamaGuard model: 400 Client Error: Bad Request for url: https://westus.azuredatabricks.net/serving-endpoints/llama-guard/invocations. Response text: Bad request: json: unknown field "dataframe_split"`
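One possible reading of the error, offered as an assumption rather than a confirmed diagnosis: the `llama-guard` endpoint rejects MLflow's `dataframe_split` input format, which suggests it expects a different request body, for example a completions-style body with a top-level `prompt` field. The sketch below uses hypothetical helpers (`format_chat`, `build_completions_payload`, `parse_llamaguard_output`) to build such a body and to parse Llama Guard's two-line verdict. The field names `prompt`, `max_tokens` and `temperature` are assumed and are not taken from this issue; if the endpoint is completions-style, the generated text would likely also be read from a `choices` field rather than `predictions[0]["candidates"]`. The sketch also renders the chat as `User:`/`Agent:` turns instead of interpolating the Python list of dicts directly into the prompt.

```python
# Hypothetical helpers; the request body shape is an ASSUMPTION about a
# completions-style serving endpoint, not something confirmed by this issue.


def format_chat(chat):
    """Render a list of {"role", "content"} dicts as 'User:' / 'Agent:' turns."""
    labels = {"user": "User", "assistant": "Agent"}
    return "\n\n".join(
        f"{labels.get(turn['role'], turn['role'].capitalize())}: {turn['content']}"
        for turn in chat
    )


def build_completions_payload(prompt, max_tokens=100, temperature=0.0):
    """Assumed completions-style body: a top-level 'prompt' instead of 'dataframe_split'."""
    return {"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature}


def parse_llamaguard_output(text):
    """Return (is_safe, violated_categories) from Llama Guard's two-line verdict."""
    lines = [line.strip() for line in text.strip().split("\n") if line.strip()]
    if not lines:
        raise ValueError("empty response from Llama Guard")
    is_safe = lines[0].lower() == "safe"
    categories = lines[1].split(",") if len(lines) > 1 else []
    return is_safe, [c.strip() for c in categories]


if __name__ == "__main__":
    chat = [{"role": "user", "content": "I want to love."}]
    print(format_chat(chat))
    print(build_completions_payload(format_chat(chat)))
    print(parse_llamaguard_output("unsafe\nO3"))
```

The resulting dictionary would be passed as `inputs=` to `client.predict` in place of the `dataframe_split` wrapper; whether the endpoint accepts it depends on how it was configured, which this issue does not show.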