diff --git a/db/sqlite.py b/db/sqlite.py index 7bbba3a..586493e 100644 --- a/db/sqlite.py +++ b/db/sqlite.py @@ -6,13 +6,13 @@ class DB: pass def get_subforums(self): - return self._db().execute('select forum_id, name, description from subforums').fetchall() + return self._db().execute('select forum_id, name, description from subforums') def get_subforum(self, subforum): return self._db().execute('select name, description from subforums where forum_id = ?', (subforum,)).fetchone() def get_threads(self, subforum): - return self._db().execute('select thread_id, title from threads where forum_id = ?', (subforum,)).fetchall() + return self._db().execute('select thread_id, title from threads where forum_id = ?', (subforum,)) def get_thread(self, thread): db = self._db() @@ -29,7 +29,7 @@ class DB: where thread_id = ? and author_id = user_id ''', (thread,) - ).fetchall() + ) return title, text, author, author_id, comments def get_thread_title(self, thread_id): @@ -48,15 +48,30 @@ class DB: where thread_id = ? ''', (thread,) - ).fetchall() + ) - def get_comment_tree(self, comment): + def get_subcomments(self, comment_id): db = self._db() - parent = db.execute('select text from comments where comment_id = ?', (comment,)).fetchall() - children = db.execute('select text from comments where parent_id = ?', (comment,)).fetchall() - print(parent, children) - return str(parent) + str(children) - return parent + thread_id, parent_id, title = db.execute(''' + select threads.thread_id, parent_id, title + from threads, comments + where comment_id = ? and threads.thread_id = comments.thread_id + ''', + (comment_id,) + ).fetchone() + # Recursive CTE, see https://www.sqlite.org/lang_with.html + return thread_id, parent_id, title, db.execute(''' + with recursive + descendant_of(id) as ( + select comment_id from comments where comment_id = ? + union + select comment_id from descendant_of, comments where id = parent_id + ) + select id, parent_id, name, text from descendant_of, comments, users + where id = comment_id and user_id = author_id + ''', + (comment_id,) + ) def get_user_password(self, username): return self._db().execute(''' @@ -127,5 +142,37 @@ class DB: ) db.commit() + def add_comment_to_thread(self, thread_id, author_id, text, time): + db = self._db() + c = db.cursor() + c.execute(''' + insert into comments(thread_id, author_id, text, create_time, modify_time) + select ?, ?, ?, ?, ? + from threads + where thread_id = ? + ''', + (thread_id, author_id, text, time, time, thread_id) + ) + rowid = c.lastrowid + db.commit() + return rowid is not None + + def add_comment_to_comment(self, parent_id, author_id, text, time): + db = self._db() + c = db.cursor() + print(c.lastrowid, parent_id) + c.execute(''' + insert into comments(thread_id, parent_id, author_id, text, create_time, modify_time) + select thread_id, ?, ?, ?, ?, ? + from comments + where comment_id = ? + ''', + (parent_id, author_id, text, time, time, parent_id) + ) + print(c.lastrowid) + rowid = c.lastrowid + db.commit() + return rowid is not None + def _db(self): return sqlite3.connect(self.conn) diff --git a/main.py b/main.py index 477b640..9532faa 100644 --- a/main.py +++ b/main.py @@ -43,8 +43,16 @@ def thread(thread_id): @app.route('/comment//') def comment(comment_id): - #return str(db.get_comment_tree(comment_id)[0]) - return str(db.get_comment_tree(comment_id)) + user_id = session.get('user_id') + thread_id, parent_id, title, comments = db.get_subcomments(comment_id) + comments = create_comment_tree(comments) + return render_template( + 'comments.html', + title = title, + comments = comments, + parent_id = parent_id, + thread_id = thread_id, + ) @app.route('/login/', methods = ['GET', 'POST']) def login(): @@ -129,21 +137,50 @@ def delete_thread(thread_id): flash('Thread has been deleted', 'success') return redirect(url_for('index')) +@app.route('/thread//comment/', methods = ['POST']) +def add_comment(thread_id): + user_id = session.get('user_id') + if user_id is None: + return redirect(url_for('login')) + + if db.add_comment_to_thread(thread_id, user_id, request.form['text'], time.time_ns()): + flash('Added comment', 'success') + else: + flash('Failed to add comment', 'error') + return redirect(url_for('thread', thread_id = thread_id)) + +@app.route('/comment//comment/', methods = ['POST']) +def add_comment_parent(comment_id): + user_id = session.get('user_id') + if user_id is None: + return redirect(url_for('login')) + + if db.add_comment_to_comment(comment_id, user_id, request.form['text'], time.time_ns()): + flash('Added comment', 'success') + else: + flash('Failed to add comment', 'error') + return redirect(url_for('comment', comment_id = comment_id)) + + class Comment: - def __init__(self, author, text): + def __init__(self, id, author, text): + self.id = id self.author = author self.text = text self.children = [] def create_comment_tree(comments): + # Collect comments first, then build the tree in case we encounter a child before a parent + comment_map = { + comment_id: (Comment(comment_id, author, text), parent_id) + for comment_id, parent_id, author, text + in comments + } root = [] - comment_map = {} - for comment_id, parent_id, author, text in comments: - comment = Comment(author, text) + for comment, parent_id in comment_map.values(): parent = comment_map.get(parent_id) if parent is not None: - parent.children.append(comment) + parent[0].children.append(comment) else: root.append(comment) - comment_map[comment_id] = comment return root diff --git a/templates/comment.html b/templates/comment.html new file mode 100644 index 0000000..b988c96 --- /dev/null +++ b/templates/comment.html @@ -0,0 +1,19 @@ +{% macro render_comment(comment) %} +
+

{{ comment.author }}

+

{{ comment.text }}

+ reply + {% for c in comment.children %} + {{ render_comment(c) }} + {% endfor %} +
+{% endmacro %} + +{% macro reply() %} +{% if 'user_id' in session %} +
+ + +
+{% endif %} +{% endmacro %} diff --git a/templates/comments.html b/templates/comments.html new file mode 100644 index 0000000..731b2e6 --- /dev/null +++ b/templates/comments.html @@ -0,0 +1,15 @@ +{% extends 'base.html' %} +{% from 'comment.html' import render_comment, reply %} + +{% block content %} +thread +{% if parent_id %} +parent +{% endif %} + +{{ reply() }} + +{% for c in comments %} +{{ render_comment(c) }} +{% endfor %} +{% endblock %} diff --git a/templates/thread.html b/templates/thread.html index a3d2d5e..aa977e0 100644 --- a/templates/thread.html +++ b/templates/thread.html @@ -1,14 +1,5 @@ {% extends 'base.html' %} - -{% macro render_comment(comment) %} -
-

{{ comment.author }}

-

{{ comment.text }}

- {% for c in comment.children %} - {{ render_comment(c) }} - {% endfor %} -
-{% endmacro %} +{% from 'comment.html' import render_comment, reply %} {% block content %} {% if manage %} @@ -20,6 +11,9 @@ {% endif %}

{{ author }} - rjgoire

{{ text }}

+ +{{ reply() }} + {% for c in comments %} {{ render_comment(c) }} {% endfor %}