diff --git a/db/sqlite.py b/db/sqlite.py index 8394006..3ba8d41 100644 --- a/db/sqlite.py +++ b/db/sqlite.py @@ -140,7 +140,7 @@ class DB: def get_user_private_info(self, user_id): return self._db().execute(''' - select name, about + select about from users where user_id = ? ''', @@ -158,6 +158,15 @@ class DB: ) db.commit() + def get_user_name_role(self, user_id): + return self._db().execute(''' + select name, role + from users + where user_id = ? + ''', + (user_id,) + ).fetchone() + def get_user_name(self, user_id): return self._db().execute(''' select name @@ -193,9 +202,13 @@ class DB: c.execute(''' delete from threads - where thread_id = ? and author_id = ? + -- 1 = moderator, 2 = admin + where thread_id = ? and ( + author_id = ? + or (select 1 from users where user_id = ? and (role = 1 or role = 2)) + ) ''', - (thread_id, user_id) + (thread_id, user_id, user_id) ) db.commit() return c.rowcount > 0 @@ -206,9 +219,16 @@ class DB: c.execute(''' delete from comments - where comment_id = ? and author_id = ? + where comment_id = ? + and ( + author_id = ? + -- 1 = moderator, 2 = admin + or (select 1 from users where user_id = ? and (role = 1 or role = 2)) + ) + -- Don't allow deleting comments with children + and (select 1 from comments where parent_id = ?) is null ''', - (comment_id, user_id) + (comment_id, user_id, user_id, comment_id) ) db.commit() return c.rowcount > 0 @@ -270,9 +290,13 @@ class DB: c.execute(''' update threads set title = ?, text = ?, modify_time = ? - where thread_id = ? and author_id = ? + where thread_id = ? and ( + author_id = ? + -- 1 = moderator, 2 = admin + or (select 1 from users where user_id = ? and (role = 1 or role = 2)) + ) ''', - (title, text, time, thread_id, user_id) + (title, text, time, thread_id, user_id, user_id) ) if c.rowcount > 0: db.commit() @@ -285,9 +309,13 @@ class DB: c.execute(''' update comments set text = ?, modify_time = ? - where comment_id = ? and author_id = ? + where comment_id = ? and ( + author_id = ? + -- 1 = moderator, 2 = admin + or (select 1 from users where user_id = ? and (role = 1 or role = 2)) + ) ''', - (text, time, comment_id, user_id) + (text, time, comment_id, user_id, user_id) ) if c.rowcount > 0: db.commit() diff --git a/main.py b/main.py index 371733d..e477f9d 100644 --- a/main.py +++ b/main.py @@ -16,9 +16,19 @@ captcha_key = 'piss off bots' app.jinja_env.trim_blocks = True app.jinja_env.lstrip_blocks = True +class Role: + USER = 0 + MODERATOR = 1 + ADMIN = 2 + @app.route('/') def index(): - return render_template('index.html', title = NAME, forums = db.get_forums()) + return render_template( + 'index.html', + title = NAME, + user = get_user(), + forums = db.get_forums() + ) @app.route('/forum//') def forum(forum_id): @@ -27,6 +37,7 @@ def forum(forum_id): return render_template( 'forum.html', title = title, + user = get_user(), forum_id = forum_id, description = description, threads = threads, @@ -40,6 +51,7 @@ def thread(thread_id): return render_template( 'thread.html', title = title, + user = get_user(), text = text, author = author, author_id = author_id, @@ -51,7 +63,6 @@ def thread(thread_id): @app.route('/comment//') def comment(comment_id): - user_id = session.get('user_id') thread_id, parent_id, title, comments = db.get_subcomments(comment_id) comments = create_comment_tree(comments) reply_comment, = comments @@ -60,6 +71,7 @@ def comment(comment_id): return render_template( 'comments.html', title = title, + user = get_user(), reply_comment = reply_comment, comments = comments, parent_id = parent_id, @@ -80,8 +92,11 @@ def login(): # Sleep to reduce effectiveness of bruteforce time.sleep(0.1) flash('Username or password is invalid', 'error') - return render_template('login.html', title = "Login") - return render_template('login.html', title = "Login") + return render_template( + 'login.html', + title = 'Login', + user = get_user() + ) @app.route('/logout/') def logout(): @@ -90,22 +105,21 @@ def logout(): @app.route('/user/', methods = ['GET', 'POST']) def user_edit(): - user_id = session.get('user_id') - if user_id is None: + user = get_user() + if user is None: return redirect(url_for('login')) if request.method == 'POST': about = request.form['about'].replace('\r', '') - db.set_user_private_info(user_id, about) - name, = db.get_user_name(user_id) + db.set_user_private_info(user.id, about) flash('Updated profile', 'success') else: - name, about = db.get_user_private_info(user_id) + about, = db.get_user_private_info(user.id) return render_template( 'user_edit.html', - name = name, title = 'Edit profile', + user = user, about = about ) @@ -115,6 +129,7 @@ def user_info(user_id): return render_template( 'user_info.html', title = 'Profile', + user = get_user(), name = name, about = about ) @@ -133,6 +148,7 @@ def new_thread(forum_id): return render_template( 'new_thread.html', title = 'Create new thread', + user = get_user(), ) @app.route('/thread//confirm_delete/') @@ -141,6 +157,7 @@ def confirm_delete_thread(thread_id): return render_template( 'confirm_delete_thread.html', title = 'Delete thread', + user = get_user(), thread_title = title, ) @@ -187,6 +204,7 @@ def confirm_delete_comment(comment_id): return render_template( 'confirm_delete_comment.html', title = 'Delete comment', + user = get_user(), thread_title = title, text = text, ) @@ -228,6 +246,7 @@ def edit_thread(thread_id): return render_template( 'edit_thread.html', title = 'Edit thread', + user = get_user(), thread_title = title, text = text, ) @@ -255,6 +274,7 @@ def edit_comment(comment_id): return render_template( 'edit_comment.html', title = 'Edit comment', + user = get_user(), thread_title = title, text = text, ) @@ -283,6 +303,7 @@ def register(): return render_template( 'register.html', title = 'Register', + user = get_user(), captcha = capt, answer = answer, ) @@ -300,6 +321,8 @@ class Comment: self.parent_id = parent_id def create_comment_tree(comments): + comments = [*comments] + print(comments) start = time.time(); # Collect comments first, then build the tree in case we encounter a child before a parent comment_map = { @@ -326,6 +349,26 @@ def create_comment_tree(comments): return root +class User: + def __init__(self, id, name, role): + self.id = id + self.name = name + self.role = role + + def is_moderator(self): + return self.role in (Role.ADMIN, Role.MODERATOR) + + def is_admin(self): + return self.role == Role.ADMIN + +def get_user(): + id = session.get('user_id') + if id is not None: + name, role = db.get_user_name_role(id) + return User(id, name, role) + return None + + @app.context_processor def utility_processor(): def format_since(t): diff --git a/templates/base.html b/templates/base.html index 331a161..662d81d 100644 --- a/templates/base.html +++ b/templates/base.html @@ -9,8 +9,8 @@