first commit

2025-07-18 16:20:14 +07:00
commit 98af45c018
16382 changed files with 3148096 additions and 0 deletions

@@ -0,0 +1,478 @@
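"""
CLIP search service for ResourceSpace.

A FastAPI application that keeps per-database CLIP vectors from the
resource_clip_vector table in memory, indexes them with FAISS (inner product
over L2-normalised vectors), and exposes four endpoints: /vector (encode text
or an image), /search (nearest resources), /duplicates (near-duplicate pairs)
and /tag (nearest tags from an external tag database). A background thread
refreshes the cached vectors every 30 seconds.
"""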
import argparse
import getpass
import uvicorn
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import numpy as np
import io
import mysql.connector
from typing import Optional
import json
import time
import faiss
import threading
import torch
import clip
import requests
import hashlib
import os
import tempfile
# Command-line arguments
parser = argparse.ArgumentParser(description="CLIP search service for ResourceSpace")
parser.add_argument("--dbuser", help="MySQL username")
parser.add_argument("--dbpass", help="MySQL password")
parser.add_argument("--host", default="0.0.0.0", help="Host to run on")
parser.add_argument("--port", type=int, default=8000, help="Port to run on")
args = parser.parse_args()
if not args.dbuser:
args.dbuser = input("Enter MySQL username (or use --dbuser): ")
if not args.dbpass:
args.dbpass = getpass.getpass("Enter MySQL password (or use --dbpass): ")
# Global DB credentials (used later)
DB_CONFIG = {
"host": "localhost",
"user": args.dbuser,
"password": args.dbpass
}
# Set up FastAPI app and in-memory cache
app = FastAPI()
device = "cpu"
print("🔌 Loading CLIP model...")
model, preprocess = clip.load("ViT-B/32", device=device)
print("✅ Model loaded.")
cached_vectors = {} # { db_name: (vectors_np, resource_ids) }
loaded_max_ref = {} # { db_name: max_ref }
faiss_indexes = {} # { db_name: faiss.IndexFlatIP }
tag_vector_cache = {} # { url: (tag_list, tag_vectors) }
tag_faiss_index_cache = {} # { url: faiss.IndexFlatIP }
def load_vectors_for_db(db_name, force_reload=False):
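    """
    Load CLIP vectors for one database into the in-memory cache and FAISS index.

    The query is incremental: only rows with ref greater than the last ref seen
    for this database are fetched, appended to the cached array and added to the
    per-database IndexFlatIP index. force_reload only skips the early cache
    return; the fetch itself stays incremental. Returns (vectors, resource_ids).
    """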
global cached_vectors, loaded_max_ref, faiss_indexes
if db_name in cached_vectors and not force_reload:
return cached_vectors[db_name]
print(f"🔄 Loading vectors from DB: {db_name}")
try:
conn = mysql.connector.connect(**DB_CONFIG, database=db_name)
cursor = conn.cursor()
last_ref = loaded_max_ref.get(db_name, 0)
start = time.time()
cursor.execute(
"SELECT ref, resource, vector_blob FROM resource_clip_vector WHERE is_text=0 AND ref > %s ORDER BY resource",
(last_ref,)
)
rows = cursor.fetchall()
conn.close()
elapsed_ms = round((time.time() - start) * 1000)
print(f"📥 Vector load from MySQL took {elapsed_ms}ms")
start = time.time()
except Exception as e:
raise HTTPException(status_code=500, detail=f"DB error: {e}")
if not rows and db_name in cached_vectors:
return cached_vectors[db_name]
new_vectors = []
new_ids = []
for ref, resource, blob in rows:
if resource is None:
print(f"❌ Skipping ref {ref}: resource is None")
continue
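        # 512 float32 values * 4 bytes = 2048 bytes expected per vector blob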
if len(blob) != 2048:
print(f"❌ Skipping ref {ref} (resource {resource}): blob is {len(blob)} bytes, expected 2048")
continue
try:
vector = np.frombuffer(blob, dtype=np.float32).copy()
if vector.shape != (512,):
print(f"❌ Skipping ref {ref} (resource {resource}): vector shape {vector.shape}, expected (512,)")
continue
norm = np.linalg.norm(vector)
if norm == 0 or np.isnan(norm):
print(f"⚠️ Skipping ref {ref} (resource {resource}): invalid norm ({norm})")
continue
vector /= norm
new_vectors.append(vector)
new_ids.append(resource)
except Exception as e:
print(f"❌ Exception parsing vector for ref {ref} (resource {resource}): {e}")
continue
    if db_name in cached_vectors:
        old_vectors, old_ids = cached_vectors[db_name]
        # Stack only when new vectors were parsed; vstack on an empty list would raise
        vectors = np.vstack([old_vectors, np.stack(new_vectors)]) if new_vectors else old_vectors
        ids = old_ids + new_ids
else:
vectors = np.stack(new_vectors) if new_vectors else np.empty((0, 512), dtype=np.float32)
ids = new_ids
cached_vectors[db_name] = (vectors, ids)
if ids:
loaded_max_ref[db_name] = max(ref for ref, _, _ in rows)
# Rebuild or update FAISS index
if db_name not in faiss_indexes:
index = faiss.IndexFlatIP(512)
if len(vectors) > 0:
index.add(vectors)
faiss_indexes[db_name] = index
else:
if new_vectors:
faiss_indexes[db_name].add(np.stack(new_vectors))
elapsed_ms = round((time.time() - start) * 1000)
print(f"⚙️ Vector processing and indexing took {elapsed_ms}ms")
print(f"✅ Cached {len(ids)} vectors for DB: {db_name}")
return cached_vectors[db_name]
@app.post("/vector")
async def generate_vector(
image: Optional[UploadFile] = File(None),
text: Optional[str] = Form(None),
):
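    """
    Encode an uploaded image or a text string with CLIP.

    Returns the L2-normalised 512-dimensional vector as a JSON array of floats.
    """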
if image is None and text is None:
raise HTTPException(status_code=400, detail="Provide either 'image' or 'text'")
try:
if image:
contents = await image.read()
img = Image.open(io.BytesIO(contents)).convert("RGB")
img_input = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
vector = model.encode_image(img_input)
else:
tokens = clip.tokenize([text]).to(device)
with torch.no_grad():
vector = model.encode_text(tokens)
# Normalise and return vector
vector = vector / vector.norm(dim=-1, keepdim=True)
vector_np = vector.cpu().numpy().flatten().tolist()
return JSONResponse(content=vector_np)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Vector generation error: {e}")
@app.post("/search")
async def search(
db: str = Form(...),
text: str = Form(None),
image: UploadFile = File(None),
resource: int = Form(None),
ref: str = Form(None),
top_k: int = Form(5)
):
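    """
    Semantic search within one ResourceSpace database.

    Supply 'db' plus one of 'text', 'image', 'resource' or 'ref' (if several are
    given, text takes precedence over image, then resource/ref). Returns the
    top_k nearest resources by inner product over the L2-normalised CLIP
    vectors; self-matches for resource/ref queries are skipped.
    """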
if not any([text, image, resource, ref]):
raise HTTPException(status_code=400, detail="Provide one of: text, image, resource, or ref")
print(f"▶️ SEARCH: db={db}, top_k={top_k}")
vectors, resource_ids = load_vectors_for_db(db)
print(f"🧠 Vectors loaded: {len(resource_ids)} resources")
if len(resource_ids) == 0:
return JSONResponse(content=[])
try:
index = faiss_indexes.get(db)
if not index:
raise HTTPException(status_code=500, detail="FAISS index not found")
# --- Create query vector ---
if text:
print("🔤 Text query")
tokens = clip.tokenize([text]).to(device)
with torch.no_grad():
query_vector = model.encode_text(tokens)
elif image:
print("🖼️ Image query")
contents = await image.read()
img = Image.open(io.BytesIO(contents)).convert("RGB")
img_input = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
query_vector = model.encode_image(img_input)
elif resource is not None or ref is not None:
# Determine which field to query
column = "resource" if resource is not None else "ref"
value = resource if resource is not None else ref
print(f"🔁 Query from DB vector: {column} = {value}")
conn = mysql.connector.connect(**DB_CONFIG, database=db)
cursor = conn.cursor()
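            # 'column' is restricted to 'resource' or 'ref' above, so the f-string below is safe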
cursor.execute(f"SELECT vector_blob FROM resource_clip_vector WHERE {column} = %s AND is_text=0", (value,))
row = cursor.fetchone()
conn.close()
if not row or not row[0] or len(row[0]) != 2048:
raise HTTPException(status_code=404, detail=f"Valid vector_blob not found for {column}={value}")
query_vector = np.frombuffer(row[0], dtype=np.float32).copy()
if query_vector.shape != (512,):
raise HTTPException(status_code=400, detail="Malformed vector shape")
norm = np.linalg.norm(query_vector)
if norm == 0 or np.isnan(norm):
raise HTTPException(status_code=400, detail="Invalid vector norm")
query_vector = torch.tensor(query_vector).unsqueeze(0)
else:
raise HTTPException(status_code=400, detail="Invalid input combination")
# --- Search ---
query_vector = query_vector / query_vector.norm(dim=-1, keepdim=True)
query_np = query_vector.cpu().numpy().flatten()
print("✅ Query vector created")
print("🔍 Performing FAISS search")
print(f"FAISS index size: {index.ntotal}")
start = time.time()
D, I = index.search(query_np.reshape(1, -1), top_k + 1)
elapsed_ms = round((time.time() - start) * 1000)
print(f"FAISS search took {elapsed_ms}ms")
print(f"🎯 Search results: {I[0]}")
results = []
ref_int = int(ref) if ref not in (None, "") else None
for i, score in zip(I[0], D[0]):
if i < 0:
continue
candidate_id = int(resource_ids[i])
# Skip self-match for resource/ref queries
if (resource is not None and candidate_id == resource) or \
(ref_int is not None and candidate_id == ref_int):
continue
results.append({
"resource": candidate_id,
"score": float(score)
})
if len(results) == top_k:
break
print("Returning", len(results), "results")
return JSONResponse(content=results)
    except HTTPException:
        # Let deliberate HTTP errors (400/404) pass through unchanged
        raise
    except Exception as e:
        print(f"❌ Exception in /search: {e}")
        raise HTTPException(status_code=500, detail=f"Search error: {e}")
@app.post("/duplicates")
async def find_duplicates(
db: str = Form(...),
threshold: float = Form(0.9)
):
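    """
    Find near-duplicate resources in one database.

    Runs a batch FAISS search of every cached vector against the index (50
    candidates per resource) and returns pairs whose inner-product similarity is
    at or above 'threshold', each pair reported once with resource < resource_match.
    """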
vectors, resource_ids = load_vectors_for_db(db)
if len(resource_ids) == 0:
return JSONResponse(content=[])
try:
index = faiss_indexes.get(db)
if not index:
raise HTTPException(status_code=500, detail="FAISS index not found")
results = []
top_k = 50 # You can increase this if you want more candidates
D, I = index.search(vectors, top_k) # batch search all against all
for i, (distances, indices) in enumerate(zip(D, I)):
for j, score in zip(indices, distances):
if score >= threshold and i != j:
a = int(resource_ids[i])
b = int(resource_ids[j])
if a < b: # avoid duplicates (a,b) vs (b,a)
results.append({
"resource": a,
"resource_match": b,
"score": float(score)
})
return JSONResponse(content=results)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Duplicate detection error: {e}")
@app.post("/tag")
async def tag_search(
db: str = Form(...),
url: str = Form(...),
top_k: int = Form(5),
resource: Optional[int] = Form(None),
vector: Optional[str] = Form(None)
):
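    """
    Suggest tags for a resource.

    'url' points to a plain-text tag database (one tag followed by its 512-float
    vector per line); it is cached on disk for 30 days and in memory per URL.
    The query vector comes either from the 'vector' JSON array or from the
    resource's stored vector_blob, and the top_k closest tags are returned.
    """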
CACHE_DIR = os.path.join(tempfile.gettempdir(), "clip_tag_cache")
os.makedirs(CACHE_DIR, exist_ok=True)
    def get_cache_filename(url):
        url_hash = hashlib.sha256(url.encode()).hexdigest()
        return os.path.join(CACHE_DIR, f"{url_hash}.tagdb")
cache_path = get_cache_filename(url)
cache_expiry_secs = 30 * 86400 # 30 days
use_disk_cache = (
os.path.exists(cache_path) and
(time.time() - os.path.getmtime(cache_path)) < cache_expiry_secs
)
if not use_disk_cache:
try:
start = time.time()
response = requests.get(url)
response.raise_for_status()
with open(cache_path, 'w', encoding='utf-8') as f:
f.write(response.text)
elapsed = int((time.time() - start) * 1000)
print(f"📡 Downloaded tag database from URL in {elapsed}ms: {url}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to download tag vectors: {e}")
if url not in tag_vector_cache:
try:
start = time.time()
with open(cache_path, 'r', encoding='utf-8') as f:
lines = f.read().strip().split('\n')
tags = []
vectors = []
for line in lines:
parts = line.strip().split()
if len(parts) != 513:
continue
tag = parts[0]
vector_arr = np.array([float(x) for x in parts[1:]], dtype=np.float32)
norm = np.linalg.norm(vector_arr)
if norm == 0 or np.isnan(norm):
continue
vector_arr /= norm
tags.append(tag)
vectors.append(vector_arr)
if not vectors:
raise ValueError("No valid tag vectors found.")
tag_vectors = np.stack(vectors)
tag_vector_cache[url] = (tags, tag_vectors)
index = faiss.IndexFlatIP(512)
index.add(tag_vectors)
tag_faiss_index_cache[url] = index
elapsed = int((time.time() - start) * 1000)
print(f"💾 Loaded tag database from disk cache in {elapsed}ms: {url}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to load tag vectors from cache: {e}")
tags, tag_vectors = tag_vector_cache[url]
index = tag_faiss_index_cache[url]
if vector:
try:
vector_list = json.loads(vector)
resource_vector = np.array(vector_list, dtype=np.float32)
if resource_vector.shape != (512,):
raise HTTPException(status_code=400, detail="Malformed input vector shape")
norm = np.linalg.norm(resource_vector)
if norm == 0 or np.isnan(norm):
raise HTTPException(status_code=400, detail="Invalid input vector norm")
resource_vector /= norm
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid 'vector' input: {e}")
elif resource is not None:
try:
conn = mysql.connector.connect(**DB_CONFIG, database=db)
cursor = conn.cursor()
cursor.execute(
"SELECT vector_blob FROM resource_clip_vector WHERE resource = %s AND is_text = 0",
(resource,)
)
row = cursor.fetchone()
conn.close()
if not row or not row[0] or len(row[0]) != 2048:
raise HTTPException(status_code=404, detail="Valid vector_blob not found for the specified resource.")
resource_vector = np.frombuffer(row[0], dtype=np.float32).copy()
if resource_vector.shape != (512,):
raise HTTPException(status_code=400, detail="Malformed vector shape.")
norm = np.linalg.norm(resource_vector)
if norm == 0 or np.isnan(norm):
raise HTTPException(status_code=400, detail="Invalid vector norm.")
resource_vector /= norm
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving resource vector: {e}")
else:
raise HTTPException(status_code=400, detail="Either 'resource' or 'vector' must be provided.")
try:
D, I = index.search(resource_vector.reshape(1, -1), top_k)
results = []
for idx, score in zip(I[0], D[0]):
if idx < 0 or idx >= len(tags):
continue
results.append({
"tag": tags[idx],
"score": float(score)
})
return JSONResponse(content=results)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error during tagging: {e}")
def background_vector_loader():
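    """Poll every 30 seconds and incrementally reload vectors for each cached database."""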
while True:
time.sleep(30)
try:
for db_name in cached_vectors.keys():
load_vectors_for_db(db_name, force_reload=True)
except Exception as e:
print(f"⚠️ Background update failed: {e}")
@app.on_event("startup")
def start_background_task():
print("🌀 Starting background vector refresher thread...")
thread = threading.Thread(target=background_vector_loader, daemon=True)
thread.start()
# Start the server
if __name__ == "__main__":
uvicorn.run(app, host=args.host, port=args.port, log_level="debug")
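
# Illustrative client call (a sketch only: host, database name and values below
# are placeholders, not defined by this service):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/search",
#       data={"db": "resourcespace", "text": "sunset over mountains", "top_k": 10},
#   )
#   print(resp.json())  # [{"resource": ..., "score": ...}, ...]
#
# /vector, /duplicates and /tag accept form fields the same way, matching the
# Form/File parameters declared above.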