253 lines
7.7 KiB
Python
253 lines
7.7 KiB
Python
import os
|
|
import sqlite3
|
|
import sys
|
|
import time
|
|
from math import sqrt as sqlite_sqrt
|
|
|
|
import sqlite_vec
|
|
|
|
from .image import (
|
|
calculate_color_difference,
|
|
calculate_phash_distance,
|
|
calculate_shape_difference,
|
|
)
|
|
|
|
|
|
class DB:
    """Wrapper around the SQLite file that stores the image list.

    Constructing an instance creates the schema when the file does not exist,
    migrates it when it was written by an older version, and then opens the
    working connection (``self.conn``) with the custom SQL functions
    registered and the ``sqlite_vec`` extension loaded.
    """

    def __init__(self, sqlfile):
        """Prepare the database stored at *sqlfile* and connect to it.

        Note that ``migrate`` terminates the process with ``sys.exit(1)``
        after applying a schema migration.
        """
        self.sqlfile = sqlfile
        # Directory containing the database file; base for file2relative().
        self.root_path = os.path.dirname(os.path.realpath(sqlfile))
        self.migrated = False
        self.create_db()
        self.connect()

    def create_db(self):
        """Create the schema in a new file, or migrate an existing file."""
        # Function-scope import kept from the original code.
        # NOTE(review): presumably this avoids a circular import - confirm.
        from imagelist2 import __version__

        if os.path.exists(self.sqlfile):
            self.migrate(__version__)
            return

        db = sqlite3.connect(self.sqlfile, timeout=30)
        try:
            db.text_factory = str
            cursor = db.cursor()

            cursor.execute("CREATE TABLE list (file TEXT PRIMARY KEY,hash TEXT,date INTEGER,size INTEGER)")
            cursor.execute(
                """CREATE TABLE data (
hash TEXT PRIMARY KEY,
description TEXT,
portrait BOOLEAN,
width INTEGER,
height INTEGER,
p_hash TEXT,
sharpness NUMERIC,
R REAL, G REAL, B REAL, BR REAL, BG REAL, BB REAL,
broken BOOLEAN
)"""
            )
            cursor.execute("CREATE TABLE tags (hash TEXT,tag TEXT)")
            cursor.execute(
                """CREATE VIEW files AS
SELECT list.file, list.date, list.size, data.*
FROM list
LEFT JOIN data ON data.hash = list.hash"""
            )
            cursor.execute("CREATE TABLE config (key TEXT PRIMARY KEY, value TEXT)")
            cursor.execute("CREATE UNIQUE INDEX data_hash ON data(hash)")
            cursor.execute("CREATE UNIQUE INDEX list_file ON list(file)")
            db.commit()

            cursor.execute("INSERT INTO config (key,value) VALUES (?,?)", ("version", __version__))
            db.commit()
        finally:
            # The setup connection is not the working connection; always
            # release it, even when schema creation fails half-way.
            db.close()

    def migrate(self, running_version):
        """Upgrade an existing database file to the current schema.

        Returns silently when the version stored in the ``config`` table
        equals *running_version*.  Otherwise ``self.migrated`` is set, the
        known upgrade steps are applied in order, a notice is printed to
        stderr and the process exits with status 1 so it can be restarted.

        A database whose version cannot be read (no ``config`` table or no
        ``version`` row) is treated as version 0.0.6, the last release that
        did not store a version.
        """
        db = sqlite3.connect(self.sqlfile, timeout=30)
        try:
            try:
                stored = db.execute("SELECT value FROM config WHERE key = 'version'").fetchall()
                config_version = stored[0][0]
            except (sqlite3.Error, IndexError):
                # last version without config
                config_version = "0.0.6"
            else:
                if config_version == running_version:
                    # versions match
                    return

            self.migrated = True

            if config_version == "0.0.6":  # => 0.0.7
                try:
                    with db:
                        db.execute("ALTER TABLE data ADD p_hash TEXT;")
                except sqlite3.Error:
                    # Deliberate best effort: a failure here is ignored.
                    # NOTE(review): presumably some 0.0.6 files already have
                    # the p_hash column - confirm.
                    pass

                with db:  # commits on success, rolls back on error
                    cursor = db.cursor()
                    cursor.execute("ALTER TABLE data ADD broken BOOLEAN;")
                    cursor.execute("CREATE TABLE config (key TEXT PRIMARY KEY, value TEXT)")
                    config_version = "0.0.7"
                    cursor.execute("INSERT INTO config (key,value) VALUES (?,?)", ("version", config_version))
                    cursor.execute("UPDATE data SET broken = ?;", (False,))

            if config_version == "0.0.7":  # => 0.0.8
                with db:
                    config_version = "0.0.8"
                    db.execute("UPDATE config SET value = ? WHERE key = ?;", (config_version, "version"))

            print(f"Migrated to {config_version}. Restart", file=sys.stderr)
            sys.exit(1)
        finally:
            # Runs on the early return, on errors and on SystemExit alike, so
            # the migration connection is never leaked.
            db.close()

    def connect(self):
        """Open the working connection, store it in ``self.conn`` and return it.

        Registers the SQL functions SQRT, RELATIVE, PDISTANCE, COLORDIFF and
        SHAPEDIFF and loads the ``sqlite_vec`` extension.
        """
        conn = sqlite3.connect(self.sqlfile, timeout=30)
        try:
            conn.text_factory = str
            conn.create_function("SQRT", 1, sqlite_sqrt)
            conn.create_function("RELATIVE", 1, self.file2relative)
            conn.create_function("PDISTANCE", 2, calculate_phash_distance)
            conn.create_function("COLORDIFF", 6, calculate_color_difference)
            conn.create_function("SHAPEDIFF", 4, calculate_shape_difference)
            conn.enable_load_extension(True)
            try:
                sqlite_vec.load(conn)
            finally:
                # Never leave extension loading switched on, even if the
                # load itself failed.
                conn.enable_load_extension(False)
        except Exception:
            # Do not leak a half-configured connection.
            conn.close()
            raise

        self.conn = conn
        return conn

    def cursor(self):
        """Return a new cursor on the working connection."""
        return self.conn.cursor()

    def _count(self, query, params):
        """Run a ``SELECT COUNT(...)`` style query and return its single value."""
        return self.cursor().execute(query, params).fetchone()[0]

    def get_folder_contents(self, path):
        """Return the files listed directly under *path*.

        A file qualifies when its name starts with *path* and the remainder
        contains no ``/``; entries in sub-folders are therefore excluded.
        """
        rows = self.cursor().execute("SELECT file FROM list where file LIKE ?", (f"{path}%",))
        prefix_length = len(path)
        # LIKE is only a coarse pre-filter: it ignores ASCII case and treats
        # "_" / "%" inside *path* as wildcards, so the prefix is re-checked
        # exactly before the remainder is inspected.
        return [
            name
            for (name,) in rows
            if name.startswith(path) and "/" not in name[prefix_length:]
        ]

    def is_time_mismatch(self, image):
        """Return True when no row matches the image's filename and time.

        *image* must provide ``filename`` and ``get_time()``.
        """
        matches = self._count(
            "SELECT COUNT(1) FROM list WHERE file = ? AND date = ?",
            (image.filename, image.get_time()),
        )
        return matches == 0

    def is_hash_mismatch(self, image):
        """Return True when no row matches the image's filename and hash.

        *image* must provide ``filename`` and ``get_hash()``.
        """
        matches = self._count(
            "SELECT COUNT(1) FROM list WHERE file = ? AND hash = ?",
            (image.filename, image.get_hash()),
        )
        return matches == 0

    def hash2file(self, hash):
        """Return the list of file names stored with the given *hash*."""
        rows = self.cursor().execute("SELECT file FROM LIST WHERE hash = ?", (hash,)).fetchall()
        return [row[0] for row in rows]

    def file2hash(self, file):
        """Return the hash stored for *file*, or None when it is not listed.

        Database errors during the lookup are also reported as None.
        """
        try:
            row = self.cursor().execute("SELECT hash FROM LIST WHERE file = ?", (file,)).fetchone()
        except sqlite3.Error:
            return None
        return None if row is None else row[0]

    def file2relative(self, file):
        """Return *file* as a path relative to the database's directory."""
        return os.path.relpath(file, self.root_path)
|
|
|
|
|
|
class DBCachedWriter:
    """Queue write statements and send them to the database in batches.

    Statements are collected with ``execute`` and written in one
    transaction when the queue is older than ``writeout`` seconds, holds
    more than ``writemax`` entries, or ``commit``/``close`` is called.
    """

    def __init__(self, DB):
        """DB = instance of the DB object"""
        self.db = DB
        self.cache = []  # pending {"query": ..., "values": ...} entries
        self.cache_time = time.time()  # time of the last flush attempt
        self.writeout = 30  # seconds before a flush is due
        self.writemax = 499  # flush once the queue is longer than this
        self.max_retries = 5  # failed flushes tolerated before raising
        self.try_count = 0

    def __del__(self):
        # __init__ may not have finished, so the attribute can be missing.
        if getattr(self, "cache", None):
            self.close()

    def commit(self):
        """Write all queued statements now."""
        self.write_cache()

    def close(self):
        """Flush whatever is still queued."""
        if self.cache:
            self.write_cache()

    def execute(self, query, values):
        """Queue *query* with its *values*; flush when the queue is due or full."""
        self.cache.append({"query": query, "values": values})

        flush_due = time.time() > self.cache_time + self.writeout
        if flush_due or len(self.cache) > self.writemax:
            self.write_cache()

    def write_cache(self):
        """Execute every queued statement and commit them together.

        On ``sqlite3.OperationalError`` the partial batch is rolled back and
        the queue is kept for the next attempt.  After more than
        ``max_retries`` consecutive failures the error is re-raised.
        """
        if self.cache:
            entry = None  # statement being executed, for the error report
            try:
                cursor = self.db.cursor()
                for entry in self.cache:
                    cursor.execute(entry["query"], entry["values"])
                self.db.conn.commit()
            except sqlite3.OperationalError as e:
                # Undo the statements that already ran: the whole queue is
                # replayed on the next attempt, and without the rollback they
                # would be applied a second time.
                # NOTE(review): this also discards any other uncommitted work
                # on the shared connection - confirm nothing else writes
                # through it between flushes.
                try:
                    self.db.conn.rollback()
                except sqlite3.Error as rollback_error:
                    print(rollback_error, file=sys.stderr)

                print(e, file=sys.stderr)
                print("Writing failed, waiting for next writeout...", file=sys.stderr)
                self.cache_time = time.time()
                self.try_count += 1
                if self.try_count > self.max_retries:
                    if entry is not None:
                        print(f"Failed\nQuery: {entry['query']}\nValues: {entry['values']}", file=sys.stderr)
                    raise
                return

        self.try_count = 0
        self.cache = []
        self.cache_time = time.time()
|
|
|
|
|
|
def sqlite_square(x):
    """Return *x* multiplied by itself."""
    # NOTE(review): the name suggests use as a SQL function, but nothing in
    # the visible code registers it - confirm.
    squared = x * x
    return squared
|