Files
resourcespace/plugins/faces/scripts/faces_service.py
2025-07-18 16:20:14 +07:00

188 lines
5.7 KiB
Python
Executable File

from fastapi import FastAPI, UploadFile, File, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import insightface
from insightface.app import FaceAnalysis
import numpy as np
import cv2
import uvicorn
import faiss
import argparse
import mysql.connector
from datetime import datetime, timedelta
# Command-line options: MySQL connection settings and the HTTP port.
def _build_arg_parser():
    """Create the parser for the service's command-line options."""
    cli = argparse.ArgumentParser()
    # String-valued database options paired with their defaults.
    for option, fallback in (
        ("--db-host", "localhost"),
        ("--db-user", "root"),
        ("--db-pass", ""),
    ):
        cli.add_argument(option, default=fallback)
    cli.add_argument("--port", type=int, default=8001)
    return cli


parser = _build_arg_parser()
# parse_known_args() does not fail on unrecognised options; they end up in
# `unknown` instead.
args, unknown = parser.parse_known_args()
# Create the FastAPI application object that the route decorators attach to.
app = FastAPI()

# Cross-origin policy (optional): every origin, method and header is accepted
# and credentials are allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# very broad policy — confirm it is really wanted for this service.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Initialise InsightFace (CPU-only)
# 'buffalo_l' is the model name handed to FaceAnalysis; the analyser is created
# once at import time and shared by all requests.
face_app = FaceAnalysis(name='buffalo_l')
face_app.prepare(ctx_id=-1) # Use CPU
# Dictionary to hold FAISS index and metadata per database
# Keyed by database name. Each value is the dict written by load_vectors():
# "index", "metadata", "vectors", "last_used" and "max_ref".
db_indexes = {}
# DB connection helper
def get_mysql_connection(db_name):
    """Open a new MySQL connection to the database named *db_name*.

    Host, user and password are taken from the parsed command-line
    arguments (``args``). Every call creates a fresh connection; closing it
    is left to the caller.
    """
    connection_settings = {
        "host": args.db_host,
        "database": db_name,
        "user": args.db_user,
        "password": args.db_pass,
    }
    return mysql.connector.connect(**connection_settings)
# Load vectors from MySQL for a given database
def load_vectors(db_name):
    """Load all face vectors of *db_name* and (re)build its cached index.

    Every row of ``resource_face`` is read, its vector is L2-normalised, and
    an inner-product FAISS index plus per-row metadata is stored in
    ``db_indexes[db_name]``. Rows whose vector is missing, has a zero or
    non-finite norm, or has a different length than the first usable vector
    are skipped with a message instead of corrupting or crashing the load.

    When no usable vector exists a message is printed and ``db_indexes`` is
    left untouched.

    Args:
        db_name: Name of the MySQL database to read from.

    Returns:
        None. The result is written to the module-level ``db_indexes`` dict.
    """
    conn = get_mysql_connection(db_name)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT ref, resource, vector_blob, node FROM resource_face")
        results = cursor.fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()

    if not results:
        print(f"No face vectors found in database '{db_name}'.")
        return

    vectors = []
    index_to_metadata = []
    max_ref = 0
    dimension = None
    for ref, resource, blob, node in results:
        # Count skipped rows too, so that a comparison of MAX(ref) against
        # "max_ref" does not report them as new faces again and again.
        max_ref = max(max_ref, ref)

        if blob:
            vector = np.frombuffer(blob, dtype=np.float32)
        else:
            vector = np.empty(0, dtype=np.float32)
        norm = np.linalg.norm(vector)

        usable = vector.size > 0 and np.isfinite(norm) and norm > 0
        if usable and dimension is None:
            # The first usable vector fixes the dimensionality of the index.
            dimension = vector.size
        if not usable or vector.size != dimension:
            # Dividing by a zero norm would put NaNs into the index, and a
            # vector of another length cannot be stacked with the rest.
            print(f"Skipping unusable face vector ref={ref} in database '{db_name}'.")
            continue

        vectors.append(vector / norm)
        index_to_metadata.append({"ref": ref, "resource": resource, "node": node})

    if not vectors:
        print(f"No face vectors found in database '{db_name}'.")
        return

    # Build the float32 matrix once and reuse it for the index and the cache.
    matrix = np.array(vectors, dtype=np.float32)
    index = faiss.IndexFlatIP(dimension)
    index.add(matrix)
    db_indexes[db_name] = {
        "index": index,
        "metadata": index_to_metadata,
        "vectors": matrix,
        "last_used": datetime.utcnow(),
        "max_ref": max_ref,
    }
    print(f"Loaded {len(vectors)} vectors for database '{db_name}'.")
# Request model for similarity search
class FaceSearchRequest(BaseModel):
    """Request body accepted by the /find_similar_faces endpoint."""
    # ref of the resource_face row whose stored vector is used as the query.
    ref: int
    # Name of the MySQL database to search in.
    db: str
    # Minimum similarity a candidate must reach to be included in the result.
    threshold: float = 0.0
    # Number of nearest neighbours to look for.
    k: int = 10
@app.post("/extract_faces")
async def extract_faces(file: UploadFile = File(...)):
    """Detect the faces in an uploaded image.

    Args:
        file: The uploaded image file.

    Returns:
        A list with one dict per detected face, holding its "bbox" (list),
        "embedding" (list) and "det_score" (float).

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image; 500 when anything else fails while processing it.
    """
    # NOTE(review): face_app.get() is called synchronously inside this
    # coroutine, so other requests wait while an image is analysed.
    try:
        contents = await file.read()
        image = None
        if contents:
            # An empty buffer is not handed to imdecode at all; undecodable
            # data makes imdecode return None.
            image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            # A bad upload is the client's error, not a server failure.
            raise HTTPException(status_code=400, detail="Could not decode image")
        faces = face_app.get(image)
        return [
            {
                "bbox": face.bbox.tolist(),
                "embedding": face.embedding.tolist(),
                "det_score": float(face.det_score),
            }
            for face in faces
        ]
    except HTTPException:
        # Keep the status code chosen above instead of rewrapping it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def _fetch_one_row(db_name, query, params=None):
    """Run *query* against *db_name* and return its first row (or None)."""
    conn = get_mysql_connection(db_name)
    try:
        cursor = conn.cursor()
        cursor.execute(query, params)
        return cursor.fetchone()
    finally:
        # Always release the connection, including when the query raises.
        conn.close()


def _cache_needs_reload(db_name, now):
    """Tell whether the cached index of *db_name* is missing or out of date."""
    cached = db_indexes.get(db_name)
    if cached is None:
        return True

    # "last_used" is refreshed on every request, so this measures how long
    # the cache has been idle.
    last_used = cached.get("last_used")
    if last_used is None:
        return True
    if now - last_used > timedelta(hours=1):
        print(f"Cache for '{db_name}' is older than 1 hour. Refreshing.")
        return True

    # Check if there are new face entries. Only a higher MAX(ref) is noticed;
    # deleted rows do not trigger a reload here.
    row = _fetch_one_row(db_name, "SELECT MAX(ref) FROM resource_face")
    latest_ref = row[0] if row and row[0] is not None else 0
    if latest_ref > cached.get("max_ref", 0):
        print(f"New faces detected in '{db_name}'. Reloading vectors.")
        return True
    return False


@app.post("/find_similar_faces")
async def find_similar_faces(request: FaceSearchRequest):
    """Return the faces most similar to the stored face ``request.ref``.

    The per-database index is (re)loaded when it is not cached, has been idle
    for more than an hour, or the table contains a higher ref than the cache
    knows. The stored vector of ``request.ref`` is normalised and searched
    against the index; the face itself is excluded from the result.

    Args:
        request: Search parameters (ref, db, threshold, k).

    Returns:
        At most ``request.k`` metadata dicts ("ref", "resource", "node") with
        an added "similarity" rounded to 4 decimals, sorted by descending
        similarity. An empty list when the stored vector has no usable norm.

    Raises:
        HTTPException: 400 when k is negative; 404 when no vector is stored
            for the ref; 500 when no index could be loaded for the database
            or the stored vector does not match the index dimension.
    """
    if request.k < 0:
        # A negative neighbour count would otherwise fail inside the search.
        raise HTTPException(status_code=400, detail="k must not be negative")

    db_name = request.db
    now = datetime.utcnow()

    # Determine if the cache should be refreshed
    if _cache_needs_reload(db_name, now):
        load_vectors(db_name)
    if db_name not in db_indexes:
        raise HTTPException(status_code=500, detail=f"Unable to load vector index for database '{db_name}'")
    db_indexes[db_name]["last_used"] = now

    row = _fetch_one_row(
        db_name,
        "SELECT vector_blob FROM resource_face WHERE ref = %s",
        (request.ref,),
    )
    if not row or row[0] is None:
        raise HTTPException(status_code=404, detail="Face vector not found")

    query_vector = np.frombuffer(row[0], dtype=np.float32)
    norm = np.linalg.norm(query_vector)
    if not np.isfinite(norm) or norm == 0:
        # Normalising would yield NaNs, which can never reach any threshold.
        return []
    query_vector = (query_vector / norm).reshape(1, -1)

    face_index = db_indexes[db_name]["index"]
    metadata = db_indexes[db_name]["metadata"]
    if query_vector.shape[1] != face_index.d:
        raise HTTPException(
            status_code=500,
            detail=f"Face vector {request.ref} does not match the index dimension",
        )

    # Ask for one extra neighbour because the query face itself is normally
    # among the hits and gets dropped below.
    distances, indices = face_index.search(query_vector, request.k + 1)

    matches = []
    for dist, idx in zip(distances[0], indices[0]):
        # Slots without a neighbour carry a negative index.
        if idx < 0 or idx >= len(metadata):
            continue
        match = metadata[idx].copy()
        if match["ref"] == request.ref:
            continue
        # Convert to a Python float before rounding: rounding in float32 and
        # widening afterwards gives values such as 0.8999999762 for 0.9.
        similarity = round(float(dist), 4)
        if similarity >= request.threshold:
            match["similarity"] = similarity
            matches.append(match)

    matches.sort(key=lambda x: -x["similarity"])
    # When the query face was not among the hits, k + 1 candidates remain;
    # never return more than the k that were requested.
    matches = matches[:request.k]
    for match in matches:
        print(f"Match: ref={match['ref']} similarity={match['similarity']:.4f}")
    return matches
if __name__ == "__main__":
    # Hand uvicorn the app object itself. The import string
    # "faces_service:app" makes uvicorn import this file a second time under
    # the name faces_service, repeating all module-level setup (argument
    # parsing and loading the face model) in the same process.
    # The server only listens on the loopback interface.
    uvicorn.run(app, host="127.0.0.1", port=args.port)