Source code for diva.monitoring.api_service_reliability
import logging
import threading
import warnings

import streamlit as st
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from diva import parameters
# Silence Streamlit's "missing ScriptRunContext" warning: the endpoints below
# touch ``st.session_state`` from the API thread, outside any Streamlit run.
warnings.filterwarnings(
    action="ignore",
    message="missing ScriptRunContext! This warning can be ignored when running in bare mode.",
)

# Per-client rate limiter keyed on the caller's remote address. slowapi expects
# it on ``app.state.limiter``; the registered handler turns RateLimitExceeded
# into a 429 response.
limiter = Limiter(key_func=get_remote_address)

app = FastAPI()
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# ---------------- Middleware to block ip ----------------
[docs]
@app.middleware("http")
async def restrict_ip_middleware(request: Request, call_next):
    """Only let through requests whose client IP is in ``parameters.authorized_IP``.

    Unauthorized (or unidentifiable) clients receive a 403 JSON response;
    every other request is forwarded to the next handler unchanged.
    """
    # ``request.client`` can be None when the ASGI server does not report the
    # peer address; treat such requests as unauthorized instead of crashing.
    client_host = request.client.host if request.client is not None else None
    if client_host not in parameters.authorized_IP:
        # Do not raise HTTPException here: HTTP middleware runs outside
        # FastAPI's exception handlers, so a raised HTTPException would be
        # turned into a 500. Build the 403 explicitly, using the same body
        # shape FastAPI produces for HTTPException.
        return JSONResponse(
            status_code=403,
            content={"detail": "Forbidden: Unauthorized IP"},
        )
    return await call_next(request)
# ---------------- endpoint /llm_prompt ----------------
[docs]
class PromptRequest(BaseModel):
    # Request body for POST /llm_prompt.
    prompt: str  # text passed to ``llms.generator``
[docs]
@app.post("/llm_prompt")
@limiter.limit(parameters.limit_requests_fastapi)
async def llm_prompt(request: Request, payload: PromptRequest):  # slowapi needs `request` in the signature
    """Generate text for the given prompt with ``diva.llm.llms.generator``."""
    # NOTE(review): ``llms.generator`` is called synchronously inside an async
    # endpoint, so a long generation presumably blocks this event loop (and the
    # other endpoints) until it returns — confirm whether it is safe to offload
    # it to a worker thread.
    try:
        # Imported lazily, presumably to avoid loading the LLM stack when the
        # module is imported.
        from diva.llm import llms
        generated_text = llms.generator(payload.prompt)
        return JSONResponse(content={"response": generated_text}, status_code=200)
    except Exception as e:
        # Keep the traceback server-side and chain the cause; the client gets
        # a 500 with the error message.
        logging.getLogger(__name__).exception("/llm_prompt failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------- endpoint /get_data ----------------
[docs]
class DataRequest(BaseModel):
    # Request body for POST /get_data.
    # The three location fields are forwarded as one location dict to
    # ``DataCollection.apply_masks``.
    location_name_orig: str
    location_name: str
    addresstype: str
    # Dates are parsed with the "%Y-%m-%d" format (e.g. "2024-01-31").
    start_date: str
    end_date: str
[docs]
@app.post("/get_data")
@limiter.limit(parameters.limit_requests_fastapi)
async def get_data(request: Request, payload: DataRequest):
    """Return spatially aggregated ``t2m`` values for one location and date range.

    Dates must use the ``YYYY-MM-DD`` format. Malformed dates yield a 400
    response; any other failure yields a 500 response with the error message.
    """
    try:
        import pandas as pd
        from diva.data.dataset import DataCollection
        from diva.logging_.logger import Process

        # Make sure a logging Process exists in Streamlit session state; this
        # endpoint runs on the API thread, outside a normal Streamlit run.
        if "process_logging" not in st.session_state:
            st.session_state.process_logging = Process()

        try:
            time_intervals = [[
                pd.to_datetime(payload.start_date, format="%Y-%m-%d"),
                pd.to_datetime(payload.end_date, format="%Y-%m-%d"),
            ]]
        except (ValueError, TypeError) as e:
            # A malformed date is a client error, not an internal server error.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid date (expected YYYY-MM-DD): {e}",
            ) from e

        locs = [{
            "location_name_orig": payload.location_name_orig,
            "location_name": payload.location_name,
            "addresstype": payload.addresstype,
        }]
        dc = DataCollection(parameters.cache_b_collections_normal, 't2m')
        dc = dc.sample_time(time_intervals)
        dc = dc.apply_masks(locs)
        dc = dc.spatial_aggregation()
        raw_vals = dc.get_values()
        return JSONResponse(content={"data": str(raw_vals)}, status_code=200)
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("/get_data failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------- endpoint /generate_graph ----------------
[docs]
class GraphRequest(BaseModel):
    # Request body for POST /generate_graph. Every field is copied verbatim
    # into the ``params`` dict given to ServiceGeneratePlotlyGraph; accepted
    # values are defined by that service and are not validated here.
    starttime: str
    endtime: str
    location: str
    elementofinterest: str
    graph_type: str
    aggreg_type: str
[docs]
@app.post("/generate_graph")
@limiter.limit(parameters.limit_requests_fastapi)
async def generate_graph(request: Request, payload: GraphRequest):
    """Generate a temperature Plotly graph and return a string dump of its traces.

    Any failure yields a 500 response with the error message, like the other
    endpoints of this service.
    """
    try:
        from diva.logging_.logger import Process
        from diva.graphs.service_graph_generation import ServiceGeneratePlotlyGraph

        # Create the session-state entries this service initialises before
        # generating a graph; the endpoint runs on the API thread, outside a
        # normal Streamlit run.
        if "process_logging" not in st.session_state:
            st.session_state.process_logging = Process()
        if "plots_history" not in st.session_state:
            st.session_state.plots_history = {}
        if "messages" not in st.session_state:
            st.session_state.messages = {}

        params = {
            "starttime": payload.starttime,
            "endtime": payload.endtime,
            "location": payload.location,
            "elementofinterest": payload.elementofinterest,
            "graph_type": payload.graph_type,
            "aggreg_type": payload.aggreg_type,
            "climate_variable": "temperature",
        }
        gg = ServiceGeneratePlotlyGraph(
            params=params, langage="English", user_type="normal")
        gg.generate(show=False)
        return JSONResponse(content={"data": str(gg.fig.data)}, status_code=200)
    except Exception as e:
        # Log the full traceback server-side (the handler used to be disabled,
        # apparently to keep tracebacks visible) and return a JSON 500.
        logging.getLogger(__name__).exception("/generate_graph failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------- run api ----------------
[docs]
def run_api():
    """Serve ``app`` with uvicorn on the configured reliability host/port (blocks)."""
    uvicorn.run(
        app,
        port=parameters.port_reliability,
        host=parameters.url_reliability,
    )


# The API starts as soon as this module is imported. A daemon thread does not
# keep the process alive once the main program exits.
api_thread = threading.Thread(daemon=True, target=run_api)
api_thread.start()
# if __name__ == "__main__":
# uvicorn.run("api_service_reliability:app",
# host="0.0.0.0", port=8602, reload=True)