Repos / pytaku / d5447c3fd3
commit d5447c3fd3f9358a8f1e9890df8d67f0515b26d7
Author: Bùi Thành Nhân <hi@imnhan.com>
Date:   Thu Jul 30 18:01:00 2020 +0700

    implement poor man's migrator

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4631d69
--- /dev/null
+++ b/README.md
@@ -0,0 +1,7 @@
+```sh
+poetry install
+pip install --upgrade pip
+pip install https://github.com/rogerbinns/apsw/releases/download/3.32.2-r1/apsw-3.32.2-r1.zip \
+      --global-option=fetch --global-option=--version --global-option=3.32.2 --global-option=--all \
+      --global-option=build --global-option=--enable-all-extensions
+```
diff --git a/pyproject.toml b/pyproject.toml
index 512dbea..8f64077 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,9 @@ description = ""
 authors = ["Bùi Thành Nhân <hi@imnhan.com>"]
 license = "AGPL-3.0-only"
 
+[tool.poetry.scripts]
+pytaku-migrate = "pytaku:migrate"
+
 [tool.poetry.dependencies]
 python = "^3.7"
 
diff --git a/src/pytaku/__init__.py b/src/pytaku/__init__.py
index fd290f0..53b1563 100644
--- a/src/pytaku/__init__.py
+++ b/src/pytaku/__init__.py
@@ -1 +1,14 @@
-print("henlo")
+def migrate():
+    import argparse
+    from .database.migrator import migrate
+
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument(
+        "-d",
+        "--dev",
+        action="store_true",
+        help="dev mode: overwrites latest_schema.sql on success",
+    )
+    args = argparser.parse_args()
+
+    migrate(overwrite_latest_schema=args.dev)
diff --git a/src/pytaku/database/common.py b/src/pytaku/database/common.py
new file mode 100644
index 0000000..3acae0c
--- /dev/null
+++ b/src/pytaku/database/common.py
@@ -0,0 +1,7 @@
+import apsw
+
+DBNAME = "db.sqlite3"
+
+
+def get_conn():
+    return apsw.Connection(DBNAME)
diff --git a/src/pytaku/database/migrations/__init__.py b/src/pytaku/database/migrations/__init__.py
new file mode 100644
index 0000000..361789d
--- /dev/null
+++ b/src/pytaku/database/migrations/__init__.py
@@ -0,0 +1,2 @@
+# importlib.resources.path() won't return anything
+# if this file doesn't exist
diff --git a/src/pytaku/database/migrations/latest_schema.sql b/src/pytaku/database/migrations/latest_schema.sql
new file mode 100644
index 0000000..3f12ee9
--- /dev/null
+++ b/src/pytaku/database/migrations/latest_schema.sql
@@ -0,0 +1,12 @@
+-- This file is auto-generated by the migration script
+-- for reference purposes only. DO NOT EDIT.
+
+CREATE TABLE user (
+    id integer primary key,
+    username text unique,
+    password text
+);
+CREATE TABLE token (
+    user_id integer,
+    content text
+);
diff --git a/src/pytaku/database/migrations/m0001.sql b/src/pytaku/database/migrations/m0001.sql
new file mode 100644
index 0000000..d80fe7d
--- /dev/null
+++ b/src/pytaku/database/migrations/m0001.sql
@@ -0,0 +1,12 @@
+PRAGMA journal_mode=WAL;
+
+create table user (
+    id integer primary key,
+    username text unique,
+    password text
+);
+
+create table token (
+    user_id integer,
+    content text
+);
diff --git a/src/pytaku/database/migrator.py b/src/pytaku/database/migrator.py
new file mode 100644
index 0000000..507972f
--- /dev/null
+++ b/src/pytaku/database/migrator.py
@@ -0,0 +1,80 @@
+import subprocess
+from importlib import resources
+from pathlib import Path
+
+from . import migrations
+from .common import DBNAME, get_conn
+
+
+"""
+Forward-only DB migration scheme held together by duct tape.
+
+- Uses `user_version` pragma to figure out what migrations are pending.
+- Migrations files are in the form `./migrations/mXXXX.sql`.
+"""
+
+
+def _get_current_version():
+    conn = get_conn()
+    cur = conn.cursor()
+    cur.execute("PRAGMA user_version;")
+    version = int(cur.fetchone()[0])
+    conn.close()
+    return version
+
+
+def _get_version(migration: Path):
+    return int(migration.name[len("m") : -len(".sql")])
+
+
+def _get_pending_migrations(migrations_dir: Path):
+    current_version = _get_current_version()
+    migrations = sorted(migrations_dir.glob("m*.sql"))
+    return [
+        migration
+        for migration in migrations
+        if _get_version(migration) > current_version
+    ]
+
+
+def _read_migrations(paths):
+    """Returns list of (version, sql_text) tuples"""
+    results = []
+    for path in paths:
+        with open(path, "r") as sql_file:
+            results.append((_get_version(path), sql_file.read()))
+    return results
+
+
+def _write_db_schema_script(migrations_dir: Path):
+    schema = subprocess.run(
+        ["sqlite3", DBNAME, ".schema"], capture_output=True, check=True
+    ).stdout
+    with open(migrations_dir / Path("latest_schema.sql"), "wb") as f:
+        f.write(b"-- This file is auto-generated by the migration script\n")
+        f.write(b"-- for reference purposes only. DO NOT EDIT.\n\n")
+        f.write(schema)
+
+
+def migrate(overwrite_latest_schema=True):
+    with resources.path(migrations, "") as migrations_dir:
+        pending_migrations = _get_pending_migrations(migrations_dir)
+        if not pending_migrations:
+            print("Nothing to migrate.")
+            exit()
+        print(f"There are {len(pending_migrations)} pending migrations.")
+        migration_contents = _read_migrations(pending_migrations)
+
+        conn = get_conn()
+        cursor = conn.cursor()
+
+        with conn:  # apsw provides automatic rollback for free here
+            for version, sql in migration_contents:
+                print("Migrating version", version, "...")
+                cursor.execute(sql)
+                cursor.execute(f"PRAGMA user_version = {version};")
+
+        if overwrite_latest_schema:
+            _write_db_schema_script(migrations_dir)
+
+        print("All done. Current version:", _get_current_version())