Repos / pytaku / 7167519d77
commit 7167519d773aa93a77c3e3e96149b85b82dd8b08
Author: Bùi Thành Nhân <hi@imnhan.com>
Date:   Thu Aug 6 20:56:11 2020 +0700

    return db rows as dicts instead of tuples
    
    also use run_sql whenever possible

diff --git a/src/pytaku/database/common.py b/src/pytaku/database/common.py
index 6701431..34d82bc 100644
--- a/src/pytaku/database/common.py
+++ b/src/pytaku/database/common.py
@@ -7,10 +7,20 @@
 
 def get_conn():
     global _conn
+
     if not _conn:
         _conn = apsw.Connection(DBNAME)
+
         # Apparently you need to enable this pragma _per connection_
         _conn.cursor().execute("PRAGMA foreign_keys = ON;")
+
+        # Return rows as dicts instead of tuples
+        _conn.setrowtrace(
+            lambda cursor, row: {
+                k[0]: row[i] for i, k in enumerate(cursor.getdescription())
+            }
+        )
+
     return _conn
 
 
diff --git a/src/pytaku/database/migrator.py b/src/pytaku/database/migrator.py
index 670ed7d..c63519c 100644
--- a/src/pytaku/database/migrator.py
+++ b/src/pytaku/database/migrator.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 from . import migrations
-from .common import DBNAME, get_conn
+from .common import DBNAME, get_conn, run_sql
 
 
 """
@@ -15,11 +15,7 @@
 
 
 def _get_current_version():
-    conn = get_conn()
-    cur = conn.cursor()
-    cur.execute("PRAGMA user_version;")
-    version = int(cur.fetchone()[0])
-    return version
+    return run_sql("PRAGMA user_version;")[0]["user_version"]
 
 
 def _get_version(migration: Path):
@@ -58,8 +54,7 @@ def _write_db_schema_script(migrations_dir: Path):
 def migrate(overwrite_latest_schema=True):
     # If there's no existing db, create one with the correct pragmas
     if not Path(DBNAME).is_file():
-        conn = get_conn()
-        conn.cursor().execute("PRAGMA journal_mode = WAL;")
+        run_sql("PRAGMA journal_mode = WAL;")
 
     with resources.path(migrations, "") as migrations_dir:
         pending_migrations = _get_pending_migrations(migrations_dir)
diff --git a/src/pytaku/persistence.py b/src/pytaku/persistence.py
index ab3a2be..6792d4a 100644
--- a/src/pytaku/persistence.py
+++ b/src/pytaku/persistence.py
@@ -3,12 +3,11 @@
 import apsw
 import argon2
 
-from .database.common import get_conn, run_sql
+from .database.common import run_sql
 
 
 def save_title(title):
-    conn = get_conn()
-    conn.cursor().execute(
+    run_sql(
         """
     INSERT INTO title (
         id,
@@ -48,18 +47,15 @@ def save_title(title):
 
 
 def load_title(site, title_id, user_id=None):
-    conn = get_conn()
-    result = list(
-        conn.cursor().execute(
-            """
-    SELECT id, name, site, cover_ext, chapters, alt_names, descriptions
-    FROM title
-    WHERE id = ?
-      AND site = ?
-      AND datetime(updated_at) > datetime('now', '-6 hours');
-    """,
-            (title_id, site),
-        )
+    result = run_sql(
+        """
+        SELECT id, name, site, cover_ext, chapters, alt_names, descriptions
+        FROM title
+        WHERE id = ?
+          AND site = ?
+          AND datetime(updated_at) > datetime('now', '-6 hours');
+        """,
+        (title_id, site),
     )
     if not result:
         return None
@@ -68,29 +64,21 @@ def load_title(site, title_id, user_id=None):
     else:
         title = result[0]
 
-        return_val = {
-            "id": title[0],
-            "name": title[1],
-            "site": title[2],
-            "cover_ext": title[3],
-            "chapters": json.loads(title[4]),
-            "alt_names": json.loads(title[5]),
-            "descriptions": json.loads(title[6]),
-        }
+        for field in ["chapters", "alt_names", "descriptions"]:
+            title[field] = json.loads(title[field])
 
         if user_id is not None:
-            return_val["is_following"] = bool(
+            title["is_following"] = bool(
                 run_sql(
                     "SELECT 1 FROM follow WHERE user_id=? AND site=? AND title_id=?;",
-                    (user_id, site, return_val["id"]),
+                    (user_id, site, title["id"]),
                 )
             )
-        return return_val
+        return title
 
 
 def save_chapter(chapter):
-    conn = get_conn()
-    conn.cursor().execute(
+    run_sql(
         """
     INSERT INTO chapter (
         id,
@@ -129,33 +117,20 @@ def save_chapter(chapter):
 
 
 def load_chapter(site, chapter_id):
-    conn = get_conn()
-    result = list(
-        conn.cursor().execute(
-            """
-    SELECT id, title_id, num_major, num_minor, name, pages, groups, is_webtoon
-    FROM chapter
-    WHERE id = ? AND site=?;
-    """,
-            (chapter_id, site),
-        )
+    result = run_sql(
+        """
+        SELECT id, title_id, num_major, num_minor, name, pages, groups, is_webtoon
+        FROM chapter
+        WHERE id = ? AND site=?;
+        """,
+        (chapter_id, site),
     )
     if not result:
         return None
     elif len(result) > 1:
         raise Exception(f"Found multiple results for chapter_id {chapter_id}!")
     else:
-        chapter = result[0]
-        return {
-            "id": chapter[0],
-            "title_id": chapter[1],
-            "num_major": chapter[2],
-            "num_minor": chapter[3],
-            "name": chapter[4],
-            "pages": json.loads(chapter[5]),
-            "groups": json.loads(chapter[6]),
-            "is_webtoon": chapter[7],
-        }
+        return result[0]
 
 
 def get_prev_next_chapters(title, chapter):
@@ -181,7 +156,7 @@ def register_user(username, password):
     hasher = argon2.PasswordHasher()
     hashed_password = hasher.hash(password)
     try:
-        get_conn().cursor().execute(
+        run_sql(
             "INSERT INTO user (username, password) VALUES (?, ?);",
             (username, hashed_password),
         )
@@ -193,18 +168,13 @@ def register_user(username, password):
 
 
 def verify_username_password(username, password):
-    data = list(
-        get_conn()
-        .cursor()
-        .execute("SELECT id, password FROM user WHERE username = ?;", (username,))
-    )
+    data = run_sql("SELECT id, password FROM user WHERE username = ?;", (username,))
     if len(data) != 1:
         print(f"User {username} doesn't exist.")
         return None
 
-    user_id = data[0][0]
-    hashed_password = data[0][1]
-
+    user_id = data[0]["id"]
+    hashed_password = data[0]["password"]
     hasher = argon2.PasswordHasher()
     try:
         hasher.verify(hashed_password, password)
@@ -215,14 +185,14 @@ def verify_username_password(username, password):
 
 
 def follow(user_id, site, title_id):
-    get_conn().cursor().execute(
+    run_sql(
         "INSERT INTO follow (user_id, site, title_id) VALUES (?, ?, ?);",
         (user_id, site, title_id),
     )
 
 
 def unfollow(user_id, site, title_id):
-    get_conn().cursor().execute(
+    run_sql(
         "DELETE FROM follow WHERE user_id=? AND site=? AND title_id=?;",
         (user_id, site, title_id),
     )
@@ -239,11 +209,9 @@ def get_followed_titles(user_id):
         """,
         (user_id,),
     )
-    keys = ("id", "site", "name", "cover_ext", "chapters")
     title_dicts = []
     for t in titles:
-        title = {key: t[i] for i, key in enumerate(keys)}
-        title["chapters"] = json.loads(title["chapters"])
-        title_dicts.append(title)
+        t["chapters"] = json.loads(t["chapters"])
+        title_dicts.append(t)
 
     return title_dicts