Skip to content

Commit d6a4002

Browse files
committed
[chores] refactoring API
1 parent 7190c23 commit d6a4002

File tree

12 files changed

+196
-230
lines changed

12 files changed

+196
-230
lines changed

src/api/main.py

Lines changed: 9 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,17 @@
66
import os
77
from contextlib import asynccontextmanager
88
from pathlib import Path
9-
from typing import Annotated, List
9+
from typing import Annotated
1010

1111
import mlflow
1212
import yaml
1313
from fastapi import Depends, FastAPI
1414
from fastapi.middleware.cors import CORSMiddleware
1515
from fastapi.security import HTTPBasicCredentials
16-
from pydantic import BaseModel
1716

18-
from api.models.forms import BatchForms, SingleForm
17+
from api.routes import predict_batch, predict_single
1918
from utils.logging import configure_logging
20-
from utils.preprocessing import preprocess_inputs
21-
from utils.utils import (
22-
get_credentials,
23-
get_model,
24-
process_response,
25-
)
19+
from utils.security import get_credentials
2620

2721

2822
@asynccontextmanager
@@ -40,7 +34,9 @@ async def lifespan(app: FastAPI):
4034
logger = logging.getLogger(__name__)
4135
logger.info("🚀 Starting API lifespan")
4236

43-
app.state.model = get_model(os.environ["MLFLOW_MODEL_NAME"], os.environ["MLFLOW_MODEL_VERSION"])
37+
app.state.model = mlflow.pyfunc.load_model(
38+
model_uri=f"models:/{os.environ['MLFLOW_MODEL_NAME']}/{os.environ['MLFLOW_MODEL_VERSION']}"
39+
)
4440
run_data = mlflow.get_run(app.state.model.metadata.run_id).data.params
4541
app.state.training_names = [
4642
run_data["text_feature"],
@@ -54,58 +50,16 @@ async def lifespan(app: FastAPI):
5450
logger.info("🛑 Shutting down API lifespan")
5551

5652

57-
class BatchPredictionRequest(BaseModel):
58-
"""
59-
Pydantic BaseModel for representing the input data for the API.
60-
61-
This BaseModel defines the structure of the input data required
62-
for the API's "/predict-batch" endpoint.
63-
64-
Attributes:
65-
description_activity (List[str]): The text description.
66-
other_nature_activity (List[str]): Other nature of the activity.
67-
precision_act_sec_agricole (List[str]): Precision of the activity in the agricultural sector.
68-
type_form (List[str]): The type of the form CERFA.
69-
nature (List[str]): The nature of the activity.
70-
surface (List[str]): The surface of activity.
71-
cj (List[str]): The legal category code.
72-
activity_permanence_status (List[str]): The activity permanence status (permanent or seasonal).
73-
74-
"""
75-
76-
description_activity: List[str]
77-
other_nature_activity: List[str]
78-
precision_act_sec_agricole: List[str]
79-
type_form: List[str]
80-
nature: List[str]
81-
surface: List[str]
82-
cj: List[str]
83-
activity_permanence_status: List[str]
84-
85-
class Config:
86-
json_schema_extra = {
87-
"example": {
88-
"description_activity": [
89-
("LOUEUR MEUBLE NON PROFESSIONNEL EN RESIDENCE DE SERVICES (CODE APE 6820A Location de logements)")
90-
],
91-
"other_nature_activity": [""],
92-
"precision_act_sec_agricole": [""],
93-
"type_form": ["I"],
94-
"nature": [""],
95-
"surface": [""],
96-
"cj": [""],
97-
"activity_permanence_status": [""],
98-
}
99-
}
100-
101-
10253
app = FastAPI(
10354
lifespan=lifespan,
10455
title="Prédiction code APE",
10556
description="Application de prédiction pour l'activité principale de l'entreprise (APE)",
10657
version="0.0.1",
10758
)
10859

60+
app.include_router(predict_single.router)
61+
app.include_router(predict_batch.router)
62+
10963

11064
app.add_middleware(
11165
CORSMiddleware,
@@ -128,71 +82,3 @@ def show_welcome_page(
12882
"Model_name": f"{os.environ['MLFLOW_MODEL_NAME']}",
12983
"Model_version": f"{os.environ['MLFLOW_MODEL_VERSION']}",
13084
}
131-
132-
133-
@app.post("/predict", tags=["Predict"])
134-
async def predict(
135-
credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
136-
form: SingleForm,
137-
nb_echos_max: int = 5,
138-
prob_min: float = 0.01,
139-
):
140-
"""
141-
Predict code APE.
142-
143-
This endpoint accepts input data as query parameters and uses the loaded
144-
ML model to predict the code APE based on the input data.
145-
146-
Args:
147-
nb_echos_max (int): Maximum number of echoes to consider. Default is 5.
148-
prob_min (float): Minimum probability threshold. Default is 0.01.
149-
150-
Returns:
151-
dict: Response containing APE codes.
152-
"""
153-
154-
query = preprocess_inputs(app.state.training_names, [form])
155-
156-
predictions = app.state.model.predict(query, params={"k": nb_echos_max})
157-
158-
response = process_response(predictions, 0, nb_echos_max, prob_min, app.state.libs)
159-
160-
# Logging
161-
query_to_log = {key: value[0] for key, value in query.items()}
162-
logging.info(f"{{'Query': {query_to_log}, 'Response': {response}}}")
163-
164-
return response
165-
166-
167-
@app.post("/predict-batch", tags=["Predict"])
168-
async def predict_batch(
169-
credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
170-
forms: BatchForms,
171-
nb_echos_max: int = 5,
172-
prob_min: float = 0.01,
173-
):
174-
"""
175-
Endpoint for predicting batches of data.
176-
177-
Args:
178-
credentials (HTTPBasicCredentials): The credentials for authentication.
179-
forms (Forms): The input data in the form of Forms object.
180-
nb_echos_max (int, optional): The maximum number of predictions to return. Defaults to 5.
181-
prob_min (float, optional): The minimum probability threshold for predictions. Defaults to 0.01.
182-
183-
Returns:
184-
list: The list of predicted responses.
185-
"""
186-
query = preprocess_inputs(app.state.training_names, forms.forms)
187-
188-
predictions = app.state.model.predict(query, params={"k": nb_echos_max})
189-
190-
response = [process_response(predictions, i, nb_echos_max, prob_min, app.state.libs) for i in range(len(predictions[0]))]
191-
192-
# Logging
193-
for line in range(len(query[app.state.training_names[0]])):
194-
query_line = {key: value[line] for key, value in query.items()}
195-
response_line = response[line]
196-
logging.info(f"{{'Query': {query_line}, 'Response': {response_line}}}")
197-
198-
return response

src/api/models/forms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class BatchForms(BaseModel):
3939

4040
@model_validator(mode="after")
4141
def check_description_not_empty(cls, values):
42-
forms = values.get("forms", [])
42+
forms = values.forms
4343
missing_indexes = [
4444
idx for idx, form in enumerate(forms) if not form.description_activity or form.description_activity.strip() == ""
4545
]

src/api/models/responses.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Dict, Union

from pydantic import BaseModel, RootModel


class Prediction(BaseModel):
    """A single APE-code prediction returned by the model.

    Attributes:
        code: The predicted APE code.
        probabilite: The model's probability for this code.
        libelle: The human-readable label associated with the code.
    """

    code: str
    probabilite: float
    libelle: str


class PredictionResponse(RootModel[Dict[str, Union[Prediction, float]]]):
    """Response body mapping rank keys ("1", "2", ...) to predictions.

    The mapping may also carry float entries (e.g. an "IC" confidence value),
    hence the ``Union[Prediction, float]`` value type.
    """

src/api/routes/batch.py

Whitespace-only changes.

src/api/routes/predict.py

Whitespace-only changes.

src/api/routes/predict_batch.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.security import HTTPBasicCredentials

from api.models.forms import BatchForms
from api.models.responses import PredictionResponse
from utils.logging import log_prediction
from utils.prediction import process_response
from utils.preprocessing import preprocess_inputs
from utils.security import get_credentials

router = APIRouter(prefix="/batch", tags=["Predict a batch of activity"])


# The endpoint returns one response per form, so the declared model is a list.
@router.post("/predict", response_model=list[PredictionResponse])
async def predict(
    credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
    request: Request,
    forms: BatchForms,
    nb_echos_max: int = 5,
    prob_min: float = 0.01,
):
    """
    Endpoint for predicting batches of data.

    Args:
        credentials (HTTPBasicCredentials): The credentials for authentication.
        request (Request): The incoming request, used to reach app state
            (model, training feature names, label library).
        forms (BatchForms): The input data in the form of a Forms object.
        nb_echos_max (int, optional): The maximum number of predictions to return. Defaults to 5.
        prob_min (float, optional): The minimum probability threshold for predictions. Defaults to 0.01.

    Returns:
        list: The list of predicted responses, one per input form.
    """
    query = preprocess_inputs(request.app.state.training_names, forms.forms)

    predictions = request.app.state.model.predict(query, params={"k": nb_echos_max})

    # Process each batch entry exactly once, logging it as we go.
    # (A previous version also built the whole list in a comprehension first,
    # doing every process_response call twice.)
    responses = []
    for i in range(len(predictions[0])):
        response = process_response(predictions, i, nb_echos_max, prob_min, request.app.state.libs)
        log_prediction(query, response, i)
        responses.append(response)

    return responses

src/api/routes/predict_single.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.security import HTTPBasicCredentials

from api.models.forms import SingleForm
from api.models.responses import PredictionResponse
from utils.logging import log_prediction
from utils.prediction import process_response
from utils.preprocessing import preprocess_inputs
from utils.security import get_credentials

router = APIRouter(prefix="/single", tags=["Predict an activity"])


@router.post("/predict", response_model=PredictionResponse)
async def predict(
    credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
    request: Request,
    form: SingleForm,
    nb_echos_max: int = 5,
    prob_min: float = 0.01,
):
    """
    Predict code APE.

    This endpoint accepts input data as query parameters and uses the loaded
    ML model to predict the code APE based on the input data.

    Args:
        credentials (HTTPBasicCredentials): The credentials for authentication.
        request (Request): The incoming request, used to reach app state.
        form (SingleForm): The single input form to classify.
        nb_echos_max (int): Maximum number of echoes to consider. Default is 5.
        prob_min (float): Minimum probability threshold. Default is 0.01.

    Returns:
        dict: Response containing APE codes.
    """
    state = request.app.state

    # Wrap the single form in a list so the shared preprocessing applies.
    features = preprocess_inputs(state.training_names, [form])
    raw_predictions = state.model.predict(features, params={"k": nb_echos_max})
    result = process_response(raw_predictions, 0, nb_echos_max, prob_min, state.libs)

    log_prediction(features, result, 0)

    return result

src/utils/logging.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22

3+
from api.models.responses import PredictionResponse
4+
35

46
def configure_logging():
57
logging.basicConfig(
@@ -10,3 +12,9 @@ def configure_logging():
1012
logging.StreamHandler(),
1113
],
1214
)
15+
16+
17+
def log_prediction(query: dict, response: "PredictionResponse", index: int = 0) -> None:
    """Log one query/response pair at INFO level.

    Args:
        query: Mapping of feature name to a list of per-entry values; only the
            value at position ``index`` of each list is logged.
        response: The prediction produced for that entry.
        index: Position of the entry within the batch. Defaults to 0.
    """
    query_line = {key: value[index] for key, value in query.items()}
    # Lazy %-style arguments: the message is only formatted if INFO is enabled
    # (same rendered text as the previous f-string).
    logging.info("{'Query': %s, 'Response': %s}", query_line, response)
    # TODO : response.model_dump() ?

src/utils/prediction.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from fastapi import HTTPException

from api.models.responses import Prediction, PredictionResponse


def process_response(
    predictions: tuple,
    liasse_nb: int,
    nb_echos_max: int,
    prob_min: float,
    libs: dict,
) -> PredictionResponse:
    """
    Process model predictions into a structured response.

    Args:
        predictions: ``(labels, probs)`` pair returned by the model, each
            indexed first by batch entry, then by prediction rank.
        liasse_nb: Index of the entry within the batch.
        nb_echos_max: Maximum number of predictions to keep.
        prob_min: Minimum probability a prediction must reach to be kept.
        libs: Mapping from APE code to its human-readable label.

    Returns:
        PredictionResponse: Rank keys ("1", "2", ...) mapped to predictions,
        plus an "IC" float entry (gap between the two best probabilities).

    Raises:
        HTTPException: 400 when no prediction reaches ``prob_min``.
    """
    labels, probs = predictions
    pred_labels = labels[liasse_nb]
    pred_probs = probs[liasse_nb]

    # NOTE(review): taking the first k entries assumes pred_probs is sorted in
    # descending order — confirm this holds for the model's output.
    valid_indices = [i for i, p in enumerate(pred_probs) if p >= prob_min]
    k = min(nb_echos_max, len(valid_indices)) if valid_indices else 0

    if k == 0:
        raise HTTPException(
            status_code=400,
            detail="No prediction exceeds the given minimum probability threshold.",
        )

    response_data = {}
    for i in range(k):
        code = pred_labels[i].replace("__label__", "")
        response_data[str(i + 1)] = Prediction(
            code=code,
            probabilite=float(pred_probs[i]),
            libelle=libs[code],
        )

    # Confidence gap between the top-1 and top-2 probabilities (0.0 when only
    # a single prediction is kept).
    ic = response_data["1"].probabilite - float(pred_probs[1]) if k > 1 else 0.0
    response_data["IC"] = ic

    # pydantic v2 RootModel takes the root value positionally; the v1-style
    # ``__root__=`` keyword raises at runtime under v2 (where RootModel lives).
    return PredictionResponse(response_data)

src/utils/preprocessing.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@ def preprocess_inputs(training_names: list, inputs: list[SingleForm]) -> dict:
88
Preprocess both single and batch inputs using shared logic.
99
"""
1010
df = pd.DataFrame([form.model_dump() for form in inputs])
11-
df.fillna("NaN", inplace=True)
11+
12+
for feature in training_names[:2]: # textual features
13+
df[feature] = df[feature].fillna(value="")
14+
for feature in training_names[2:]: # categorical features
15+
df[feature] = df[feature].fillna(value="NaN")
16+
1217
df = df.astype(str)
1318

1419
mapping = {

0 commit comments

Comments
 (0)