66import os
77from contextlib import asynccontextmanager
88from pathlib import Path
9- from typing import Annotated , List
9+ from typing import Annotated
1010
1111import mlflow
1212import yaml
1313from fastapi import Depends , FastAPI
1414from fastapi .middleware .cors import CORSMiddleware
1515from fastapi .security import HTTPBasicCredentials
16- from pydantic import BaseModel
1716
18- from api .models . forms import BatchForms , SingleForm
17+ from api .routes import predict_batch , predict_single
1918from utils .logging import configure_logging
20- from utils .preprocessing import preprocess_inputs
21- from utils .utils import (
22- get_credentials ,
23- get_model ,
24- process_response ,
25- )
19+ from utils .security import get_credentials
2620
2721
2822@asynccontextmanager
@@ -40,7 +34,9 @@ async def lifespan(app: FastAPI):
4034 logger = logging .getLogger (__name__ )
4135 logger .info ("🚀 Starting API lifespan" )
4236
43- app .state .model = get_model (os .environ ["MLFLOW_MODEL_NAME" ], os .environ ["MLFLOW_MODEL_VERSION" ])
37+ app .state .model = mlflow .pyfunc .load_model (
38+ model_uri = f"models:/{ os .environ ['MLFLOW_MODEL_NAME' ]} /{ os .environ ['MLFLOW_MODEL_VERSION' ]} "
39+ )
4440 run_data = mlflow .get_run (app .state .model .metadata .run_id ).data .params
4541 app .state .training_names = [
4642 run_data ["text_feature" ],
@@ -54,58 +50,16 @@ async def lifespan(app: FastAPI):
5450 logger .info ("🛑 Shutting down API lifespan" )
5551
5652
57- class BatchPredictionRequest (BaseModel ):
58- """
59- Pydantic BaseModel for representing the input data for the API.
60-
61- This BaseModel defines the structure of the input data required
62- for the API's "/predict-batch" endpoint.
63-
64- Attributes:
65- description_activity (List[str]): The text description.
66- other_nature_activity (List[str]): Other nature of the activity.
67- precision_act_sec_agricole (List[str]): Precision of the activity in the agricultural sector.
68- type_form (List[str]): The type of the form CERFA.
69- nature (List[str]): The nature of the activity.
70- surface (List[str]): The surface of activity.
71- cj (List[str]): The legal category code.
72- activity_permanence_status (List[str]): The activity permanence status (permanent or seasonal).
73-
74- """
75-
76- description_activity : List [str ]
77- other_nature_activity : List [str ]
78- precision_act_sec_agricole : List [str ]
79- type_form : List [str ]
80- nature : List [str ]
81- surface : List [str ]
82- cj : List [str ]
83- activity_permanence_status : List [str ]
84-
85- class Config :
86- json_schema_extra = {
87- "example" : {
88- "description_activity" : [
89- ("LOUEUR MEUBLE NON PROFESSIONNEL EN RESIDENCE DE SERVICES (CODE APE 6820A Location de logements)" )
90- ],
91- "other_nature_activity" : ["" ],
92- "precision_act_sec_agricole" : ["" ],
93- "type_form" : ["I" ],
94- "nature" : ["" ],
95- "surface" : ["" ],
96- "cj" : ["" ],
97- "activity_permanence_status" : ["" ],
98- }
99- }
100-
101-
10253app = FastAPI (
10354 lifespan = lifespan ,
10455 title = "Prédiction code APE" ,
10556 description = "Application de prédiction pour l'activité principale de l'entreprise (APE)" ,
10657 version = "0.0.1" ,
10758)
10859
60+ app .include_router (predict_single .router )
61+ app .include_router (predict_batch .router )
62+
10963
11064app .add_middleware (
11165 CORSMiddleware ,
@@ -128,71 +82,3 @@ def show_welcome_page(
12882 "Model_name" : f"{ os .environ ['MLFLOW_MODEL_NAME' ]} " ,
12983 "Model_version" : f"{ os .environ ['MLFLOW_MODEL_VERSION' ]} " ,
13084 }
131-
132-
133- @app .post ("/predict" , tags = ["Predict" ])
134- async def predict (
135- credentials : Annotated [HTTPBasicCredentials , Depends (get_credentials )],
136- form : SingleForm ,
137- nb_echos_max : int = 5 ,
138- prob_min : float = 0.01 ,
139- ):
140- """
141- Predict code APE.
142-
143- This endpoint accepts input data as query parameters and uses the loaded
144- ML model to predict the code APE based on the input data.
145-
146- Args:
147- nb_echos_max (int): Maximum number of echoes to consider. Default is 5.
148- prob_min (float): Minimum probability threshold. Default is 0.01.
149-
150- Returns:
151- dict: Response containing APE codes.
152- """
153-
154- query = preprocess_inputs (app .state .training_names , [form ])
155-
156- predictions = app .state .model .predict (query , params = {"k" : nb_echos_max })
157-
158- response = process_response (predictions , 0 , nb_echos_max , prob_min , app .state .libs )
159-
160- # Logging
161- query_to_log = {key : value [0 ] for key , value in query .items ()}
162- logging .info (f"{{'Query': { query_to_log } , 'Response': { response } }}" )
163-
164- return response
165-
166-
167- @app .post ("/predict-batch" , tags = ["Predict" ])
168- async def predict_batch (
169- credentials : Annotated [HTTPBasicCredentials , Depends (get_credentials )],
170- forms : BatchForms ,
171- nb_echos_max : int = 5 ,
172- prob_min : float = 0.01 ,
173- ):
174- """
175- Endpoint for predicting batches of data.
176-
177- Args:
178- credentials (HTTPBasicCredentials): The credentials for authentication.
179- forms (Forms): The input data in the form of Forms object.
180- nb_echos_max (int, optional): The maximum number of predictions to return. Defaults to 5.
181- prob_min (float, optional): The minimum probability threshold for predictions. Defaults to 0.01.
182-
183- Returns:
184- list: The list of predicted responses.
185- """
186- query = preprocess_inputs (app .state .training_names , forms .forms )
187-
188- predictions = app .state .model .predict (query , params = {"k" : nb_echos_max })
189-
190- response = [process_response (predictions , i , nb_echos_max , prob_min , app .state .libs ) for i in range (len (predictions [0 ]))]
191-
192- # Logging
193- for line in range (len (query [app .state .training_names [0 ]])):
194- query_line = {key : value [line ] for key , value in query .items ()}
195- response_line = response [line ]
196- logging .info (f"{{'Query': { query_line } , 'Response': { response_line } }}" )
197-
198- return response
0 commit comments