first commit

plugins/clip/scripts/clip_service.py · 478 lines · Normal file

@@ -0,0 +1,478 @@
import argparse
import getpass
import hashlib
import io
import json
import os
import tempfile
import threading
import time
from typing import Optional

import clip
import faiss
import mysql.connector
import numpy as np
import requests
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image

# Command-line arguments
parser = argparse.ArgumentParser(description="CLIP search service for ResourceSpace")
parser.add_argument("--dbuser", help="MySQL username")
parser.add_argument("--dbpass", help="MySQL password")
parser.add_argument("--host", default="0.0.0.0", help="Host to run on")
parser.add_argument("--port", type=int, default=8000, help="Port to run on")
args = parser.parse_args()

if not args.dbuser:
    args.dbuser = input("Enter MySQL username (or use --dbuser): ")

if not args.dbpass:
    args.dbpass = getpass.getpass("Enter MySQL password (or use --dbpass): ")

# Global DB credentials (used for every per-database connection below)
DB_CONFIG = {
    "host": "localhost",
    "user": args.dbuser,
    "password": args.dbpass,
}

# Set up FastAPI app and in-memory caches
app = FastAPI()
device = "cpu"
print("🔌 Loading CLIP model...")
model, preprocess = clip.load("ViT-B/32", device=device)
print("✅ Model loaded.")

cached_vectors = {}         # { db_name: (vectors_np, resource_ids) }
loaded_max_ref = {}         # { db_name: max_ref }
faiss_indexes = {}          # { db_name: faiss.IndexFlatIP }
tag_vector_cache = {}       # { url: (tag_list, tag_vectors) }
tag_faiss_index_cache = {}  # { url: faiss.IndexFlatIP }

def load_vectors_for_db(db_name, force_reload=False):
    global cached_vectors, loaded_max_ref, faiss_indexes

    if db_name in cached_vectors and not force_reload:
        return cached_vectors[db_name]

    print(f"🔄 Loading vectors from DB: {db_name}")
    try:
        conn = mysql.connector.connect(**DB_CONFIG, database=db_name)
        cursor = conn.cursor()

        # Incremental load: only fetch rows newer than the highest ref already cached.
        last_ref = loaded_max_ref.get(db_name, 0)

        start = time.time()
        cursor.execute(
            "SELECT ref, resource, vector_blob FROM resource_clip_vector WHERE is_text=0 AND ref > %s ORDER BY resource",
            (last_ref,)
        )
        rows = cursor.fetchall()
        conn.close()
        elapsed_ms = round((time.time() - start) * 1000)
        print(f"📥 Vector load from MySQL took {elapsed_ms}ms")
        start = time.time()

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"DB error: {e}")

    if not rows and db_name in cached_vectors:
        return cached_vectors[db_name]

    new_vectors = []
    new_ids = []

    for ref, resource, blob in rows:
        if resource is None:
            print(f"❌ Skipping ref {ref}: resource is None")
            continue

        # A valid blob is 512 float32 values, i.e. exactly 2048 bytes.
        if len(blob) != 2048:
            print(f"❌ Skipping ref {ref} (resource {resource}): blob is {len(blob)} bytes, expected 2048")
            continue

        try:
            vector = np.frombuffer(blob, dtype=np.float32).copy()
            if vector.shape != (512,):
                print(f"❌ Skipping ref {ref} (resource {resource}): vector shape {vector.shape}, expected (512,)")
                continue

            norm = np.linalg.norm(vector)
            if norm == 0 or np.isnan(norm):
                print(f"⚠️ Skipping ref {ref} (resource {resource}): invalid norm ({norm})")
                continue

            vector /= norm
            new_vectors.append(vector)
            new_ids.append(resource)

        except Exception as e:
            print(f"❌ Exception parsing vector for ref {ref} (resource {resource}): {e}")
            continue

    if db_name in cached_vectors:
        old_vectors, old_ids = cached_vectors[db_name]
        # np.vstack would fail on an empty list, so only stack when new rows survived validation.
        vectors = np.vstack([old_vectors, np.stack(new_vectors)]) if new_vectors else old_vectors
        ids = old_ids + new_ids
    else:
        vectors = np.stack(new_vectors) if new_vectors else np.empty((0, 512), dtype=np.float32)
        ids = new_ids

    cached_vectors[db_name] = (vectors, ids)
    # Advance the high-water mark whenever rows were fetched, even if every row
    # was skipped, so invalid rows are not re-fetched on each refresh.
    if rows:
        loaded_max_ref[db_name] = max(ref for ref, _, _ in rows)

    # Rebuild or update FAISS index
    if db_name not in faiss_indexes:
        index = faiss.IndexFlatIP(512)
        if len(vectors) > 0:
            index.add(vectors)
        faiss_indexes[db_name] = index
    else:
        if new_vectors:
            faiss_indexes[db_name].add(np.stack(new_vectors))

    elapsed_ms = round((time.time() - start) * 1000)
    print(f"⚙️ Vector processing and indexing took {elapsed_ms}ms")
    print(f"✅ Cached {len(ids)} vectors for DB: {db_name}")
    return cached_vectors[db_name]

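# For reference, a minimal sketch of how a compatible vector_blob could be
# produced on the write side (illustrative only; the real writer lives
# elsewhere in the plugin):
#
#   vec = embedding.astype(np.float32)   # `embedding`: a 512-dim CLIP vector
#   blob = vec.tobytes()                 # 512 * 4 bytes = 2048 bytes
#   assert np.array_equal(np.frombuffer(blob, dtype=np.float32), vec)
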
@app.post("/vector")
|
||||
async def generate_vector(
|
||||
image: Optional[UploadFile] = File(None),
|
||||
text: Optional[str] = Form(None),
|
||||
):
|
||||
if image is None and text is None:
|
||||
raise HTTPException(status_code=400, detail="Provide either 'image' or 'text'")
|
||||
|
||||
try:
|
||||
if image:
|
||||
contents = await image.read()
|
||||
img = Image.open(io.BytesIO(contents)).convert("RGB")
|
||||
img_input = preprocess(img).unsqueeze(0).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
vector = model.encode_image(img_input)
|
||||
else:
|
||||
tokens = clip.tokenize([text]).to(device)
|
||||
with torch.no_grad():
|
||||
vector = model.encode_text(tokens)
|
||||
|
||||
# Normalise and return vector
|
||||
vector = vector / vector.norm(dim=-1, keepdim=True)
|
||||
vector_np = vector.cpu().numpy().flatten().tolist()
|
||||
return JSONResponse(content=vector_np)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Vector generation error: {e}")
|
||||
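# Example request (illustrative; assumes the service listens on the default
# --host/--port, and the hostname is a placeholder):
#
#   curl -X POST http://localhost:8000/vector -F "text=a red bicycle"
#
# The response is a JSON array of 512 floats: the normalised CLIP embedding.
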
@app.post("/search")
|
||||
async def search(
|
||||
db: str = Form(...),
|
||||
text: str = Form(None),
|
||||
image: UploadFile = File(None),
|
||||
resource: int = Form(None),
|
||||
ref: str = Form(None),
|
||||
top_k: int = Form(5)
|
||||
):
|
||||
if not any([text, image, resource, ref]):
|
||||
raise HTTPException(status_code=400, detail="Provide one of: text, image, resource, or ref")
|
||||
|
||||
print(f"▶️ SEARCH: db={db}, top_k={top_k}")
|
||||
vectors, resource_ids = load_vectors_for_db(db)
|
||||
print(f"🧠 Vectors loaded: {len(resource_ids)} resources")
|
||||
|
||||
if len(resource_ids) == 0:
|
||||
return JSONResponse(content=[])
|
||||
|
||||
try:
|
||||
index = faiss_indexes.get(db)
|
||||
if not index:
|
||||
raise HTTPException(status_code=500, detail="FAISS index not found")
|
||||
|
||||
# --- Create query vector ---
|
||||
if text:
|
||||
print("🔤 Text query")
|
||||
tokens = clip.tokenize([text]).to(device)
|
||||
with torch.no_grad():
|
||||
query_vector = model.encode_text(tokens)
|
||||
|
||||
elif image:
|
||||
print("🖼️ Image query")
|
||||
contents = await image.read()
|
||||
img = Image.open(io.BytesIO(contents)).convert("RGB")
|
||||
img_input = preprocess(img).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
query_vector = model.encode_image(img_input)
|
||||
|
||||
elif resource is not None or ref is not None:
|
||||
# Determine which field to query
|
||||
column = "resource" if resource is not None else "ref"
|
||||
value = resource if resource is not None else ref
|
||||
print(f"🔁 Query from DB vector: {column} = {value}")
|
||||
|
||||
conn = mysql.connector.connect(**DB_CONFIG, database=db)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"SELECT vector_blob FROM resource_clip_vector WHERE {column} = %s AND is_text=0", (value,))
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
if not row or not row[0] or len(row[0]) != 2048:
|
||||
raise HTTPException(status_code=404, detail=f"Valid vector_blob not found for {column}={value}")
|
||||
|
||||
query_vector = np.frombuffer(row[0], dtype=np.float32).copy()
|
||||
if query_vector.shape != (512,):
|
||||
raise HTTPException(status_code=400, detail="Malformed vector shape")
|
||||
norm = np.linalg.norm(query_vector)
|
||||
if norm == 0 or np.isnan(norm):
|
||||
raise HTTPException(status_code=400, detail="Invalid vector norm")
|
||||
query_vector = torch.tensor(query_vector).unsqueeze(0)
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid input combination")
|
||||
|
||||
# --- Search ---
|
||||
query_vector = query_vector / query_vector.norm(dim=-1, keepdim=True)
|
||||
query_np = query_vector.cpu().numpy().flatten()
|
||||
print("✅ Query vector created")
|
||||
|
||||
print("🔍 Performing FAISS search")
|
||||
print(f"FAISS index size: {index.ntotal}")
|
||||
start = time.time()
|
||||
D, I = index.search(query_np.reshape(1, -1), top_k + 1)
|
||||
elapsed_ms = round((time.time() - start) * 1000)
|
||||
print(f"FAISS search took {elapsed_ms}ms")
|
||||
print(f"🎯 Search results: {I[0]}")
|
||||
|
||||
results = []
|
||||
ref_int = int(ref) if ref not in (None, "") else None
|
||||
|
||||
for i, score in zip(I[0], D[0]):
|
||||
if i < 0:
|
||||
continue
|
||||
candidate_id = int(resource_ids[i])
|
||||
|
||||
# Skip self-match for resource/ref queries
|
||||
if (resource is not None and candidate_id == resource) or \
|
||||
(ref_int is not None and candidate_id == ref_int):
|
||||
continue
|
||||
|
||||
|
||||
results.append({
|
||||
"resource": candidate_id,
|
||||
"score": float(score)
|
||||
})
|
||||
|
||||
if len(results) == top_k:
|
||||
break
|
||||
|
||||
print("Returning", len(results), "results")
|
||||
return JSONResponse(content=results)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Exception in /search: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Search error: {e}")
|
||||
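# Example requests (illustrative; host, port and database name are placeholders):
#
#   curl -X POST http://localhost:8000/search -F "db=resourcespace" -F "text=sunset over water"
#   curl -X POST http://localhost:8000/search -F "db=resourcespace" -F "resource=1234" -F "top_k=10"
#
# Each result has the form {"resource": <id>, "score": <cosine similarity>}.
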
@app.post("/duplicates")
|
||||
async def find_duplicates(
|
||||
db: str = Form(...),
|
||||
threshold: float = Form(0.9)
|
||||
):
|
||||
vectors, resource_ids = load_vectors_for_db(db)
|
||||
|
||||
if len(resource_ids) == 0:
|
||||
return JSONResponse(content=[])
|
||||
|
||||
try:
|
||||
index = faiss_indexes.get(db)
|
||||
if not index:
|
||||
raise HTTPException(status_code=500, detail="FAISS index not found")
|
||||
|
||||
results = []
|
||||
top_k = 50 # You can increase this if you want more candidates
|
||||
|
||||
D, I = index.search(vectors, top_k) # batch search all against all
|
||||
|
||||
for i, (distances, indices) in enumerate(zip(D, I)):
|
||||
for j, score in zip(indices, distances):
|
||||
if score >= threshold and i != j:
|
||||
a = int(resource_ids[i])
|
||||
b = int(resource_ids[j])
|
||||
if a < b: # avoid duplicates (a,b) vs (b,a)
|
||||
results.append({
|
||||
"resource": a,
|
||||
"resource_match": b,
|
||||
"score": float(score)
|
||||
})
|
||||
|
||||
return JSONResponse(content=results)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Duplicate detection error: {e}")
|
||||
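# Example request (illustrative; host, port and database name are placeholders):
#
#   curl -X POST http://localhost:8000/duplicates -F "db=resourcespace" -F "threshold=0.95"
#
# Note: this searches every cached vector against the full flat index, so the
# cost grows quadratically with collection size.
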
@app.post("/tag")
|
||||
async def tag_search(
|
||||
db: str = Form(...),
|
||||
url: str = Form(...),
|
||||
top_k: int = Form(5),
|
||||
resource: Optional[int] = Form(None),
|
||||
vector: Optional[str] = Form(None)
|
||||
):
|
||||
CACHE_DIR = os.path.join(tempfile.gettempdir(), "clip_tag_cache")
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
|
||||
def get_cache_filename(url):
|
||||
hash = hashlib.sha256(url.encode()).hexdigest()
|
||||
return os.path.join(CACHE_DIR, f"{hash}.tagdb")
|
||||
|
||||
cache_path = get_cache_filename(url)
|
||||
cache_expiry_secs = 30 * 86400 # 30 days
|
||||
|
||||
use_disk_cache = (
|
||||
os.path.exists(cache_path) and
|
||||
(time.time() - os.path.getmtime(cache_path)) < cache_expiry_secs
|
||||
)
|
||||
|
||||
if not use_disk_cache:
|
||||
try:
|
||||
start = time.time()
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
with open(cache_path, 'w', encoding='utf-8') as f:
|
||||
f.write(response.text)
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
print(f"📡 Downloaded tag database from URL in {elapsed}ms: {url}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to download tag vectors: {e}")
|
||||
|
||||
if url not in tag_vector_cache:
|
||||
try:
|
||||
start = time.time()
|
||||
with open(cache_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.read().strip().split('\n')
|
||||
|
||||
tags = []
|
||||
vectors = []
|
||||
for line in lines:
|
||||
parts = line.strip().split()
|
||||
if len(parts) != 513:
|
||||
continue
|
||||
tag = parts[0]
|
||||
vector_arr = np.array([float(x) for x in parts[1:]], dtype=np.float32)
|
||||
norm = np.linalg.norm(vector_arr)
|
||||
if norm == 0 or np.isnan(norm):
|
||||
continue
|
||||
vector_arr /= norm
|
||||
tags.append(tag)
|
||||
vectors.append(vector_arr)
|
||||
|
||||
if not vectors:
|
||||
raise ValueError("No valid tag vectors found.")
|
||||
|
||||
tag_vectors = np.stack(vectors)
|
||||
tag_vector_cache[url] = (tags, tag_vectors)
|
||||
index = faiss.IndexFlatIP(512)
|
||||
index.add(tag_vectors)
|
||||
tag_faiss_index_cache[url] = index
|
||||
elapsed = int((time.time() - start) * 1000)
|
||||
print(f"💾 Loaded tag database from disk cache in {elapsed}ms: {url}")
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load tag vectors from cache: {e}")
|
||||
|
||||
tags, tag_vectors = tag_vector_cache[url]
|
||||
index = tag_faiss_index_cache[url]
|
||||
|
||||
if vector:
|
||||
try:
|
||||
vector_list = json.loads(vector)
|
||||
resource_vector = np.array(vector_list, dtype=np.float32)
|
||||
if resource_vector.shape != (512,):
|
||||
raise HTTPException(status_code=400, detail="Malformed input vector shape")
|
||||
norm = np.linalg.norm(resource_vector)
|
||||
if norm == 0 or np.isnan(norm):
|
||||
raise HTTPException(status_code=400, detail="Invalid input vector norm")
|
||||
resource_vector /= norm
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid 'vector' input: {e}")
|
||||
|
||||
elif resource is not None:
|
||||
try:
|
||||
conn = mysql.connector.connect(**DB_CONFIG, database=db)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT vector_blob FROM resource_clip_vector WHERE resource = %s AND is_text = 0",
|
||||
(resource,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
conn.close()
|
||||
if not row or not row[0] or len(row[0]) != 2048:
|
||||
raise HTTPException(status_code=404, detail="Valid vector_blob not found for the specified resource.")
|
||||
resource_vector = np.frombuffer(row[0], dtype=np.float32).copy()
|
||||
if resource_vector.shape != (512,):
|
||||
raise HTTPException(status_code=400, detail="Malformed vector shape.")
|
||||
norm = np.linalg.norm(resource_vector)
|
||||
if norm == 0 or np.isnan(norm):
|
||||
raise HTTPException(status_code=400, detail="Invalid vector norm.")
|
||||
resource_vector /= norm
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error retrieving resource vector: {e}")
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Either 'resource' or 'vector' must be provided.")
|
||||
|
||||
try:
|
||||
D, I = index.search(resource_vector.reshape(1, -1), top_k)
|
||||
results = []
|
||||
for idx, score in zip(I[0], D[0]):
|
||||
if idx < 0 or idx >= len(tags):
|
||||
continue
|
||||
results.append({
|
||||
"tag": tags[idx],
|
||||
"score": float(score)
|
||||
})
|
||||
return JSONResponse(content=results)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error during tagging: {e}")
|
||||
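# The tag database fetched from `url` is expected to be plain text with one
# tag per line: the tag token followed by its 512 vector components, all
# whitespace-separated. An illustrative (not real) line:
#
#   sunset 0.0123 -0.0456 0.0789 ... (512 floats in total)
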
def background_vector_loader():
    # Periodically refresh every database that has been loaded at least once,
    # picking up rows added since the stored high-water mark.
    while True:
        time.sleep(30)
        try:
            # Iterate over a snapshot: a concurrent request may register a new
            # DB and would otherwise change the dict size mid-iteration.
            for db_name in list(cached_vectors.keys()):
                load_vectors_for_db(db_name, force_reload=True)
        except Exception as e:
            print(f"⚠️ Background update failed: {e}")


@app.on_event("startup")
def start_background_task():
    print("🌀 Starting background vector refresher thread...")
    thread = threading.Thread(target=background_vector_loader, daemon=True)
    thread.start()

# Start the server
if __name__ == "__main__":
    uvicorn.run(app, host=args.host, port=args.port, log_level="debug")
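
# Example launch (illustrative; credential values are placeholders and can
# also be entered interactively at the prompts):
#
#   python clip_service.py --dbuser rs_user --dbpass secret --port 8000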