arslan911's picture
update
725063e verified
raw
history blame contribute delete
981 Bytes
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
# Create the FastAPI application object that the route decorators attach to.
app = FastAPI()
# Load the serialized model once, at import time, and keep it in the
# module-level name `model`.
# The path is relative, so it is resolved against the process's current
# working directory.
# NOTE(review): joblib.load unpickles the file; only load model files that
# come from a trusted source.
_MODEL_FILE = "linear_regression_model.pkl"
model = joblib.load(_MODEL_FILE)
# Input schema for prediction requests: a payload with one numeric field.
class PredictionInput(BaseModel):
    # The single input value, declared as a float.
    feature1: float
# Define the prediction endpoint
@app.post("/predict")
def predict(input_data: PredictionInput):
    """Return the model's prediction for a single ``feature1`` value.

    Args:
        input_data: Request body carrying the ``feature1`` value.

    Returns:
        A dict ``{"prediction": [...]}`` whose value is the model output
        converted to a plain list.

    Raises:
        HTTPException: 422 when ``feature1`` is NaN or infinite; 500 when the
            model call or the conversion of its output fails.
    """
    # A plain `float` annotation does not by itself exclude NaN/Infinity.
    # Such values cannot produce a meaningful prediction, so reject them as a
    # client error up front instead of letting them end in a generic 500.
    if not np.isfinite(input_data.feature1):
        raise HTTPException(
            status_code=422,
            detail="feature1 must be a finite number",
        )

    # Convert input into model-compatible format (a 2D array: one row, one column)
    input_features = np.array([[input_data.feature1]])

    # Keep the try body limited to the model call and output conversion.
    try:
        prediction = model.predict(input_features).tolist()
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to the client; kept
        # so existing clients see the same payload. Consider logging the
        # exception and returning a generic message instead.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"prediction": prediction}  # Prediction as a list
# Basic greeting endpoint (optional): answers GET "/" with a fixed welcome message.
@app.get("/")
def greet_json():
    welcome = {"message": "Welcome to the Linear Regression API!"}
    return welcome