Files
q-tools/py-packages/imagelist2/imagelist2/db.py
2025-06-09 14:18:41 +03:00

253 lines
7.7 KiB
Python

import os
import sqlite3
import sys
import time
from contextlib import closing
from math import sqrt as sqlite_sqrt

import sqlite_vec

from .image import (
    calculate_color_difference,
    calculate_phash_distance,
    calculate_shape_difference,
)
class DB:
    """SQLite database holding the imagelist2 index.

    Tables: ``list`` (file, hash, date, size), ``data`` (per-hash image
    metadata), ``tags`` (hash, tag) and ``config`` (key/value pairs,
    including the schema ``version``). The ``files`` view left-joins
    ``list`` with ``data`` on the hash.
    """

    def __init__(self, sqlfile):
        """Open the database at ``sqlfile``, creating or migrating it first.

        Args:
            sqlfile: path to the SQLite database file.
        """
        self.sqlfile = sqlfile
        # Directory of the database file; RELATIVE() paths are computed against it.
        self.root_path = os.path.dirname(os.path.realpath(sqlfile))
        # Set to True by migrate() when a schema migration was started.
        self.migrated = False
        self.create_db()
        self.connect()

    def create_db(self):
        """Create the schema in a new file, or migrate an existing one.

        If ``self.sqlfile`` exists, :meth:`migrate` is called with the
        running package version (it may exit the process). Otherwise the
        tables, view and indexes are created and the running version is
        stored in ``config``.
        """
        # Imported here rather than at module level, presumably to avoid a
        # circular import with the package __init__ — TODO confirm.
        from imagelist2 import __version__

        if os.path.exists(self.sqlfile):
            self.migrate(__version__)
            return
        # closing() releases the setup connection; the long-lived one is
        # opened separately by connect().
        with closing(sqlite3.connect(self.sqlfile, timeout=30)) as db:
            db.text_factory = str
            cursor = db.cursor()
            cursor.execute("CREATE TABLE list (file TEXT PRIMARY KEY,hash TEXT,date INTEGER,size INTEGER)")
            cursor.execute(
                """CREATE TABLE data (
hash TEXT PRIMARY KEY,
description TEXT,
portrait BOOLEAN,
width INTEGER,
height INTEGER,
p_hash TEXT,
sharpness NUMERIC,
R REAL, G REAL, B REAL, BR REAL, BG REAL, BB REAL,
broken BOOLEAN
)"""
            )
            cursor.execute("CREATE TABLE tags (hash TEXT,tag TEXT)")
            cursor.execute(
                """CREATE VIEW files AS
SELECT list.file, list.date, list.size, data.*
FROM list
LEFT JOIN data ON data.hash = list.hash"""
            )
            cursor.execute("CREATE TABLE config (key TEXT PRIMARY KEY, value TEXT)")
            cursor.execute("CREATE UNIQUE INDEX data_hash ON data(hash)")
            cursor.execute("CREATE UNIQUE INDEX list_file ON list(file)")
            db.commit()
            cursor = db.cursor()
            cursor.execute("INSERT INTO config (key,value) VALUES (?,?)", ("version", __version__))
            db.commit()

    def migrate(self, running_version):
        """Bring an existing database file up to the current schema.

        Reads the version stored in the ``config`` table. A file without a
        readable version row (created before ``config`` existed) is treated
        as 0.0.6. If the stored version equals ``running_version``, nothing
        is changed. Otherwise the migration steps below are applied in
        order, ``self.migrated`` is set, and a message is printed to stderr.
        The process then exits with status 1 so the program is restarted.

        Args:
            running_version: version string of the running package.

        Raises:
            SystemExit: whenever the stored version differs from
                ``running_version``.
        """
        try:
            with closing(sqlite3.connect(self.sqlfile, timeout=30)) as db:
                config_version = db.execute("SELECT value FROM config WHERE key = 'version'").fetchall()[0][0]
        except (sqlite3.Error, IndexError):
            # last version without config (no config table or no version row)
            config_version = "0.0.6"
        if config_version == running_version:
            # versions match
            return
        self.migrated = True
        if config_version == "0.0.6":  # => 0.0.7
            with closing(sqlite3.connect(self.sqlfile, timeout=30)) as db:
                try:
                    with db:  # transaction: commit on success, rollback on error
                        db.execute("ALTER TABLE data ADD p_hash TEXT;")
                except sqlite3.Error:
                    # Best effort: presumably some 0.0.6 files already have
                    # p_hash, so a failing ALTER is ignored.
                    pass
                with db:
                    db.execute("ALTER TABLE data ADD broken BOOLEAN;")
                    db.execute("CREATE TABLE config (key TEXT PRIMARY KEY, value TEXT)")
                    config_version = "0.0.7"
                    db.execute("INSERT INTO config (key,value) VALUES (?,?)", ("version", config_version))
                    db.execute("UPDATE data SET broken = ?;", (False,))
        if config_version == "0.0.7":  # => 0.0.8
            with closing(sqlite3.connect(self.sqlfile, timeout=30)) as db:
                with db:
                    config_version = "0.0.8"
                    db.execute("UPDATE config SET value = ? WHERE key = ?;", (config_version, "version"))
        # NOTE(review): a stored version with no matching step above (a newer
        # file, or a release that added no step) still reaches this exit, so
        # the program would exit on every start. Each version bump needs a step.
        print(f"Migrated to {config_version}. Restart", file=sys.stderr)
        sys.exit(1)

    def connect(self):
        """Open and store the long-lived connection used by :meth:`cursor`.

        Registers SQRT, RELATIVE, PDISTANCE, COLORDIFF and SHAPEDIFF as SQL
        functions and loads the sqlite_vec extension.

        Returns:
            The new :class:`sqlite3.Connection`, also kept as ``self.conn``.
        """
        conn = sqlite3.connect(self.sqlfile, timeout=30)
        conn.text_factory = str
        conn.create_function("SQRT", 1, sqlite_sqrt)
        conn.create_function("RELATIVE", 1, self.file2relative)
        conn.create_function("PDISTANCE", 2, calculate_phash_distance)
        conn.create_function("COLORDIFF", 6, calculate_color_difference)
        conn.create_function("SHAPEDIFF", 4, calculate_shape_difference)
        # Extension loading is only enabled while sqlite_vec is loaded.
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)
        self.conn = conn
        return conn

    def cursor(self):
        """Return a new cursor on the connection opened by :meth:`connect`."""
        return self.conn.cursor()

    def get_folder_contents(self, path):
        """Return the files stored directly in folder ``path`` (non-recursive).

        ``path`` is used as a literal, case-sensitive prefix of
        ``list.file``. A file is included when the remainder after that
        prefix contains no "/". Callers presumably pass a folder path ending
        in "/" — TODO confirm.
        """
        # Escape LIKE wildcards so "_" / "%" in folder names match literally.
        escaped = path.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        res = self.cursor().execute(
            "SELECT file FROM list WHERE file LIKE ? ESCAPE '\\'",
            (escaped + "%",),
        )
        files = []
        for (file,) in res:
            # LIKE is ASCII case-insensitive in SQLite, so re-check the exact prefix.
            if file.startswith(path) and "/" not in file[len(path):]:
                files.append(file)
        return files

    def is_time_mismatch(self, image):
        """Return True unless ``list`` has ``image.filename`` with date ``image.get_time()``."""
        count = (
            self.cursor()
            .execute(
                "SELECT COUNT(1) FROM list WHERE file = ? AND date = ?",
                (image.filename, image.get_time()),
            )
            .fetchone()[0]
        )
        return count == 0

    def is_hash_mismatch(self, image):
        """Return True unless ``list`` has ``image.filename`` with hash ``image.get_hash()``."""
        count = (
            self.cursor()
            .execute(
                "SELECT COUNT(1) FROM list WHERE file = ? AND hash = ?",
                (image.filename, image.get_hash()),
            )
            .fetchone()[0]
        )
        return count == 0

    def hash2file(self, hash):
        """Return every file in ``list`` whose hash equals ``hash`` (possibly empty)."""
        rows = self.cursor().execute("SELECT file FROM LIST WHERE hash = ?", (hash,)).fetchall()
        return [row[0] for row in rows]

    def file2hash(self, file):
        """Return the hash stored for ``file`` in ``list``, or None if it is not listed."""
        row = self.cursor().execute("SELECT hash FROM LIST WHERE file = ?", (file,)).fetchone()
        return None if row is None else row[0]

    def file2relative(self, file):
        """Return ``file`` relative to the database directory (``self.root_path``)."""
        return os.path.relpath(file, self.root_path)
class DBCachedWriter:
    """Queue write queries for a DB and flush them to its connection in batches.

    Queued queries are executed and committed on ``DB.conn`` in three cases:
    more than ``writemax`` are queued, ``writeout`` seconds have passed since
    the last flush attempt, or commit()/close()/garbage collection runs.
    """

    def __init__(self, DB):
        """DB = instance of the DB object"""
        self.db = DB
        self.cache = []  # pending {"query": ..., "values": ...} entries
        self.cache_time = time.time()  # time of the last flush attempt
        self.writeout = 30  # seconds between time-triggered flushes
        self.writemax = 499  # flush once more than this many queries are queued
        self.max_retries = 5  # consecutive failed flushes tolerated before raising
        self.try_count = 0  # consecutive failed flushes so far

    def __del__(self):
        self.close()

    def commit(self):
        """Flush all queued queries now."""
        self.write_cache()

    def close(self):
        """Flush any queued queries."""
        if self.cache:
            self.write_cache()

    def execute(self, query, values):
        """Queue ``query`` with bound ``values``, flushing if a limit is reached."""
        self.cache.append({"query": query, "values": values})
        if time.time() > self.cache_time + self.writeout or len(self.cache) > self.writemax:
            self.write_cache()

    def write_cache(self):
        """Execute every queued query on the DB connection, then commit.

        On :class:`sqlite3.OperationalError` (presumably transient, such as a
        locked database) the partial batch is rolled back and the queue is
        kept, so the next flush replays it from a clean state. After more
        than ``max_retries`` consecutive failures the error is re-raised.
        """
        if not self.cache:
            return
        current = None  # the query being executed when a failure happens
        try:
            cursor = self.db.cursor()
            for current in self.cache:
                cursor.execute(current["query"], current["values"])
            self.db.conn.commit()
        except sqlite3.OperationalError as e:
            # Without this rollback, rows executed before the failure stay in
            # the open transaction and the retry applies them a second time.
            if self.db.conn.in_transaction:
                self.db.conn.rollback()
            print(e, file=sys.stderr)
            print("Writing failed, waiting for next writeout...", file=sys.stderr)
            self.cache_time = time.time()
            self.try_count += 1
            if self.try_count > self.max_retries:
                if current is not None:
                    print(f"Failed\nQuery: {current['query']}\nValues: {current['values']}", file=sys.stderr)
                raise
            return
        self.try_count = 0
        self.cache = []
        self.cache_time = time.time()
def sqlite_square(x):
    """Return ``x`` multiplied by itself.

    The name suggests use as an SQL function, but DB.connect does not
    register it.
    """
    squared = x * x
    return squared