#!/usr/bin/env python3
from fastapi import FastAPI, UploadFile, File, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
import insightface
|
|
from insightface.app import FaceAnalysis
|
|
import numpy as np
|
|
import cv2
|
|
import uvicorn
|
|
import faiss
|
|
import argparse
|
|
import mysql.connector
|
|
from datetime import datetime, timedelta
|
|
|
|
# Command-line configuration. Unknown flags are tolerated so wrapper
# scripts can pass extra options without breaking startup.
parser = argparse.ArgumentParser()
for flag, options in (
    ("--db-host", {"default": "localhost"}),
    ("--db-user", {"default": "root"}),
    ("--db-pass", {"default": ""}),
    ("--port", {"default": 8001, "type": int}),
):
    parser.add_argument(flag, **options)
args, unknown = parser.parse_known_args()
|
|
|
|
# FastAPI application instance.
app = FastAPI()

# Fully permissive CORS policy: any origin, method and header is accepted.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)
|
|
|
|
# Initialise InsightFace (CPU-only)
# 'buffalo_l' is an InsightFace model-pack name; the models are fetched on
# first use if not already present locally.
face_app = FaceAnalysis(name='buffalo_l')
face_app.prepare(ctx_id=-1)  # Use CPU: ctx_id=-1 selects CPU inference

# Dictionary to hold FAISS index and metadata per database.
# Keyed by database name; values are dicts populated by load_vectors() with
# keys "index", "metadata", "vectors", "last_used" and "max_ref".
db_indexes = {}
|
|
|
|
# DB connection helper
def get_mysql_connection(db_name):
    """Open a fresh MySQL connection to *db_name* using the CLI credentials."""
    credentials = {
        "host": args.db_host,
        "database": db_name,
        "user": args.db_user,
        "password": args.db_pass,
    }
    return mysql.connector.connect(**credentials)
|
|
|
|
# Load vectors from MySQL for a given database
def load_vectors(db_name):
    """Load all face vectors for *db_name* from MySQL into a FAISS index.

    Reads every row of ``resource_face``, L2-normalises each embedding (so
    the inner-product index returns cosine similarity) and caches the index
    plus per-row metadata in the module-level ``db_indexes`` dict.

    Rows whose stored vector has zero norm are skipped instead of producing
    NaN entries in the index. Does nothing if the table is empty.
    """
    conn = get_mysql_connection(db_name)
    try:
        # finally-close so a failed query cannot leak the connection
        cursor = conn.cursor()
        cursor.execute("SELECT ref, resource, vector_blob, node FROM resource_face")
        results = cursor.fetchall()
    finally:
        conn.close()

    if not results:
        print(f"No face vectors found in database '{db_name}'.")
        return

    vectors = []
    index_to_metadata = []
    max_ref = 0

    for ref, resource, blob, node in results:
        vector = np.frombuffer(blob, dtype=np.float32)
        norm = np.linalg.norm(vector)
        if norm == 0:
            # A zero vector cannot be normalised; dividing would fill the
            # index with NaNs, so skip the row.
            print(f"Skipping ref={ref}: zero-norm vector.")
            continue
        vectors.append(vector / norm)
        index_to_metadata.append({"ref": ref, "resource": resource, "node": node})
        max_ref = max(max_ref, ref)

    if not vectors:
        print(f"No usable face vectors found in database '{db_name}'.")
        return

    # Build the float32 matrix once (the previous version materialised it
    # twice: once for index.add and once for the cache entry).
    matrix = np.array(vectors).astype('float32')
    index = faiss.IndexFlatIP(matrix.shape[1])
    index.add(matrix)

    db_indexes[db_name] = {
        "index": index,
        "metadata": index_to_metadata,
        "vectors": matrix,
        # naive UTC timestamp; compared against datetime.utcnow() elsewhere
        "last_used": datetime.utcnow(),
        "max_ref": max_ref,
    }
    print(f"Loaded {len(vectors)} vectors for database '{db_name}'.")
|
|
|
|
# Request model for similarity search
class FaceSearchRequest(BaseModel):
    # Primary key of the query face in the `resource_face` table.
    ref: int
    # Name of the MySQL database to search (one FAISS index is cached per db).
    db: str
    # Minimum cosine similarity a neighbour must reach to be returned.
    threshold: float = 0.0
    # Number of nearest neighbours requested (the query face itself is skipped).
    k: int = 10
|
|
|
|
@app.post("/extract_faces")
async def extract_faces(file: UploadFile = File(...)):
    """Detect faces in an uploaded image.

    Returns one entry per detected face with its bounding box, embedding
    vector and detection score. Any failure (undecodable image included)
    surfaces as an HTTP 500 with the error message as detail.
    """
    try:
        raw = await file.read()
        image = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Could not decode image")

        return [
            {
                "bbox": face.bbox.tolist(),
                "embedding": face.embedding.tolist(),
                "det_score": float(face.det_score),
            }
            for face in face_app.get(image)
        ]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _max_ref_in_db(db_name):
    """Return the largest `ref` currently stored in resource_face, or 0."""
    conn = get_mysql_connection(db_name)
    try:
        # finally-close so a failed query cannot leak the connection
        cursor = conn.cursor()
        cursor.execute("SELECT MAX(ref) FROM resource_face")
        row = cursor.fetchone()
    finally:
        conn.close()
    return row[0] if row and row[0] is not None else 0


def _fetch_query_vector(db_name, ref):
    """Fetch and L2-normalise the stored vector for *ref*.

    Returns a (1, d) float32 array ready for FAISS, or None if no row exists.
    """
    conn = get_mysql_connection(db_name)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT vector_blob FROM resource_face WHERE ref = %s", (ref,))
        row = cursor.fetchone()
    finally:
        conn.close()
    if not row:
        return None
    vector = np.frombuffer(row[0], dtype=np.float32)
    vector = vector / np.linalg.norm(vector)
    return vector.reshape(1, -1)


@app.post("/find_similar_faces")
async def find_similar_faces(request: FaceSearchRequest):
    """Return faces similar to `request.ref`, sorted by cosine similarity.

    The per-database FAISS index is reloaded when it is missing, older than
    one hour, or when the table contains refs newer than the cached maximum.
    Matches below `request.threshold` and the query face itself are dropped.

    Raises 404 if the query ref has no stored vector, 500 if the index
    cannot be loaded.
    """
    db_name = request.db
    # naive UTC on purpose: load_vectors() caches naive utcnow() timestamps
    now = datetime.utcnow()

    # Determine if the cache should be refreshed.
    should_reload = db_name not in db_indexes
    if not should_reload:
        cache = db_indexes[db_name]
        if now - cache.get("last_used") > timedelta(hours=1):
            print(f"Cache for '{db_name}' is older than 1 hour. Refreshing.")
            should_reload = True
        elif _max_ref_in_db(db_name) > cache.get("max_ref", 0):
            print(f"New faces detected in '{db_name}'. Reloading vectors.")
            should_reload = True

    if should_reload:
        load_vectors(db_name)

    if db_name not in db_indexes:
        raise HTTPException(status_code=500, detail=f"Unable to load vector index for database '{db_name}'")

    db_indexes[db_name]["last_used"] = now

    query_vector = _fetch_query_vector(db_name, request.ref)
    if query_vector is None:
        raise HTTPException(status_code=404, detail="Face vector not found")

    face_index = db_indexes[db_name]["index"]
    metadata = db_indexes[db_name]["metadata"]

    # Ask for one extra neighbour so the query face itself can be dropped
    # while still returning up to k real matches.
    distances, indices = face_index.search(query_vector, request.k + 1)

    matches = []
    for dist, idx in zip(distances[0], indices[0]):
        # FAISS pads with -1 when fewer than k+1 vectors exist.
        if idx < 0 or idx >= len(metadata):
            continue

        match = metadata[idx].copy()
        if match["ref"] == request.ref:
            continue

        similarity = float(round(dist, 4))
        if similarity >= request.threshold:
            match["similarity"] = similarity
            matches.append(match)
            print(f"Match: ref={match['ref']} similarity={match['similarity']:.4f}")

    matches.sort(key=lambda x: -x["similarity"])
    return matches
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the import string assumes this module is saved as
    # faces_service.py — confirm the filename, otherwise uvicorn cannot
    # resolve "faces_service:app". Binds to loopback only; port comes from
    # the --port CLI flag (default 8001).
    uvicorn.run("faces_service:app", host="127.0.0.1", port=args.port)