commit 7167519d773aa93a77c3e3e96149b85b82dd8b08
Author: Bùi Thành Nhân <hi@imnhan.com>
Date: Thu Aug 6 20:56:11 2020 +0700
return db rows as dicts instead of tuples
also use run_sql whenever possible
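Note: the call sites below all funnel through a run_sql helper defined in src/pytaku/database/common.py, which this diff does not show. As a rough sketch only (the argument names and exact behaviour here are assumptions, not the code in this commit), it can be thought of as:

    # Hypothetical sketch -- the real run_sql lives in common.py and is not
    # part of this diff. Assumed behaviour: run the query on the shared
    # connection and materialize the rows, which (after this commit) are
    # dicts keyed by column name.
    def run_sql(query, params=()):
        return list(get_conn().cursor().execute(query, params))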
diff --git a/src/pytaku/database/common.py b/src/pytaku/database/common.py
index 6701431..34d82bc 100644
--- a/src/pytaku/database/common.py
+++ b/src/pytaku/database/common.py
@@ -7,10 +7,20 @@
def get_conn():
global _conn
+
if not _conn:
_conn = apsw.Connection(DBNAME)
+
# Apparently you need to enable this pragma _per connection_
_conn.cursor().execute("PRAGMA foreign_keys = ON;")
+
+ # Return rows as dicts instead of tuples
+ _conn.setrowtrace(
+ lambda cursor, row: {
+ k[0]: row[i] for i, k in enumerate(cursor.getdescription())
+ }
+ )
+
return _conn
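With the row tracer installed, every query on this connection yields dicts keyed by column name instead of positional tuples. For illustration (the query below is an example, not code from this commit):

    # Example only: rows now come back as dicts, so columns are accessed by
    # name rather than by index.
    for row in get_conn().cursor().execute("SELECT id, username FROM user;"):
        print(row["id"], row["username"])   # previously row[0], row[1]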
diff --git a/src/pytaku/database/migrator.py b/src/pytaku/database/migrator.py
index 670ed7d..c63519c 100644
--- a/src/pytaku/database/migrator.py
+++ b/src/pytaku/database/migrator.py
@@ -3,7 +3,7 @@
from pathlib import Path
from . import migrations
-from .common import DBNAME, get_conn
+from .common import DBNAME, get_conn, run_sql
"""
@@ -15,11 +15,7 @@
def _get_current_version():
- conn = get_conn()
- cur = conn.cursor()
- cur.execute("PRAGMA user_version;")
- version = int(cur.fetchone()[0])
- return version
+ return run_sql("PRAGMA user_version;")[0]["user_version"]
def _get_version(migration: Path):
@@ -58,8 +54,7 @@ def _write_db_schema_script(migrations_dir: Path):
def migrate(overwrite_latest_schema=True):
# If there's no existing db, create one with the correct pragmas
if not Path(DBNAME).is_file():
- conn = get_conn()
- conn.cursor().execute("PRAGMA journal_mode = WAL;")
+ run_sql("PRAGMA journal_mode = WAL;")
with resources.path(migrations, "") as migrations_dir:
pending_migrations = _get_pending_migrations(migrations_dir)
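PRAGMA statements go through the same path, so their single-row results are also dicts keyed by the pragma name (the version number below is made up for illustration):

    # Illustrative result shape only.
    run_sql("PRAGMA user_version;")   # -> [{"user_version": 3}]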
diff --git a/src/pytaku/persistence.py b/src/pytaku/persistence.py
index ab3a2be..6792d4a 100644
--- a/src/pytaku/persistence.py
+++ b/src/pytaku/persistence.py
@@ -3,12 +3,11 @@
import apsw
import argon2
-from .database.common import get_conn, run_sql
+from .database.common import run_sql
def save_title(title):
- conn = get_conn()
- conn.cursor().execute(
+ run_sql(
"""
INSERT INTO title (
id,
@@ -48,18 +47,15 @@ def save_title(title):
def load_title(site, title_id, user_id=None):
- conn = get_conn()
- result = list(
- conn.cursor().execute(
- """
- SELECT id, name, site, cover_ext, chapters, alt_names, descriptions
- FROM title
- WHERE id = ?
- AND site = ?
- AND datetime(updated_at) > datetime('now', '-6 hours');
- """,
- (title_id, site),
- )
+ result = run_sql(
+ """
+ SELECT id, name, site, cover_ext, chapters, alt_names, descriptions
+ FROM title
+ WHERE id = ?
+ AND site = ?
+ AND datetime(updated_at) > datetime('now', '-6 hours');
+ """,
+ (title_id, site),
)
if not result:
return None
@@ -68,29 +64,21 @@ def load_title(site, title_id, user_id=None):
else:
title = result[0]
- return_val = {
- "id": title[0],
- "name": title[1],
- "site": title[2],
- "cover_ext": title[3],
- "chapters": json.loads(title[4]),
- "alt_names": json.loads(title[5]),
- "descriptions": json.loads(title[6]),
- }
+ for field in ["chapters", "alt_names", "descriptions"]:
+ title[field] = json.loads(title[field])
if user_id is not None:
- return_val["is_following"] = bool(
+ title["is_following"] = bool(
run_sql(
"SELECT 1 FROM follow WHERE user_id=? AND site=? AND title_id=?;",
- (user_id, site, return_val["id"]),
+ (user_id, site, title["id"]),
)
)
- return return_val
+ return title
def save_chapter(chapter):
- conn = get_conn()
- conn.cursor().execute(
+ run_sql(
"""
INSERT INTO chapter (
id,
@@ -129,33 +117,20 @@ def save_chapter(chapter):
def load_chapter(site, chapter_id):
- conn = get_conn()
- result = list(
- conn.cursor().execute(
- """
- SELECT id, title_id, num_major, num_minor, name, pages, groups, is_webtoon
- FROM chapter
- WHERE id = ? AND site=?;
- """,
- (chapter_id, site),
- )
+ result = run_sql(
+ """
+ SELECT id, title_id, num_major, num_minor, name, pages, groups, is_webtoon
+ FROM chapter
+ WHERE id = ? AND site=?;
+ """,
+ (chapter_id, site),
)
if not result:
return None
elif len(result) > 1:
raise Exception(f"Found multiple results for chapter_id {chapter_id}!")
else:
- chapter = result[0]
- return {
- "id": chapter[0],
- "title_id": chapter[1],
- "num_major": chapter[2],
- "num_minor": chapter[3],
- "name": chapter[4],
- "pages": json.loads(chapter[5]),
- "groups": json.loads(chapter[6]),
- "is_webtoon": chapter[7],
- }
+ return result[0]
def get_prev_next_chapters(title, chapter):
@@ -181,7 +156,7 @@ def register_user(username, password):
hasher = argon2.PasswordHasher()
hashed_password = hasher.hash(password)
try:
- get_conn().cursor().execute(
+ run_sql(
"INSERT INTO user (username, password) VALUES (?, ?);",
(username, hashed_password),
)
@@ -193,18 +168,13 @@ def register_user(username, password):
def verify_username_password(username, password):
- data = list(
- get_conn()
- .cursor()
- .execute("SELECT id, password FROM user WHERE username = ?;", (username,))
- )
+ data = run_sql("SELECT id, password FROM user WHERE username = ?;", (username,))
if len(data) != 1:
print(f"User {username} doesn't exist.")
return None
- user_id = data[0][0]
- hashed_password = data[0][1]
-
+ user_id = data[0]["id"]
+ hashed_password = data[0]["password"]
hasher = argon2.PasswordHasher()
try:
hasher.verify(hashed_password, password)
@@ -215,14 +185,14 @@ def verify_username_password(username, password):
def follow(user_id, site, title_id):
- get_conn().cursor().execute(
+ run_sql(
"INSERT INTO follow (user_id, site, title_id) VALUES (?, ?, ?);",
(user_id, site, title_id),
)
def unfollow(user_id, site, title_id):
- get_conn().cursor().execute(
+ run_sql(
"DELETE FROM follow WHERE user_id=? AND site=? AND title_id=?;",
(user_id, site, title_id),
)
@@ -239,11 +209,9 @@ def get_followed_titles(user_id):
""",
(user_id,),
)
- keys = ("id", "site", "name", "cover_ext", "chapters")
title_dicts = []
for t in titles:
- title = {key: t[i] for i, key in enumerate(keys)}
- title["chapters"] = json.loads(title["chapters"])
- title_dicts.append(title)
+ t["chapters"] = json.loads(t["chapters"])
+ title_dicts.append(t)
return title_dicts
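A caller-side sketch of the end result (the site name, title id, and user id are made-up values for illustration):

    # Illustration only: load_title now returns a plain dict with its JSON
    # columns already decoded, so callers index by column name.
    title = load_title("mangadex", "some-title-id", user_id=1)
    if title is not None:
        print(title["name"], len(title["chapters"]), title.get("is_following"))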