#!/usr/bin/python
import argparse
import datetime
import os
import sqlite3

import boto3
import botocore
import humanize
import psycopg2
import tqdm
import yaml

# Schema for our sqlite database cache. get_sqlite_conn() applies it with
# executescript() every time the cache is opened, which is why every statement
# uses IF NOT EXISTS.
SCHEMA = """
    CREATE TABLE IF NOT EXISTS media (
        origin TEXT NOT NULL,  -- empty string if local media
        media_id TEXT NOT NULL,
        filesystem_id TEXT NOT NULL,
        -- Type is "local" or "remote"
        type TEXT NOT NULL,
        -- indicates whether the media and all its thumbnails have been deleted from the
        -- local cache
        known_deleted BOOLEAN NOT NULL
    );

    CREATE UNIQUE INDEX IF NOT EXISTS media_id_idx ON media(origin, media_id);
    CREATE INDEX IF NOT EXISTS deleted_idx ON media(known_deleted);
"""

# Whether to wrap long-running loops in a tqdm progress bar. main() sets this
# to False when --no-progress is passed on the command line.
progress = True


def parse_duration(string):
    """Parse a duration string into the point in time that far in the past.

    The string is an integer followed by a one-letter unit suffix: ``s``
    (seconds), ``h`` (hours), ``d`` (days), ``m`` (months, counted as 30 days)
    or ``y`` (years, counted as 365 days).

    Args:
        string: the duration to parse, e.g. ``"2d"`` or ``"1m"``.

    Returns:
        A naive ``datetime.datetime`` equal to ``datetime.datetime.now()``
        minus the given duration.

    Raises:
        argparse.ArgumentTypeError: if the string is empty, the number is not
            an integer, the suffix is unknown, or the duration is too large to
            be represented.
    """
    # Slicing (rather than indexing) keeps an empty string from raising
    # IndexError, which argparse would not turn into a usage error.
    suffix = string[-1:]
    number = string[:-1]

    try:
        number = int(number)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "duration must be an integer followed by a 's', 'h', 'd', 'm' or 'y' suffix"
        ) from None

    # We can't sensibly do minutes as months has taken 'm', seconds can be used to implement minutes
    # This would be the natural suffix as per https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html
    units = {
        "s": datetime.timedelta(seconds=1),
        "h": datetime.timedelta(hours=1),
        "d": datetime.timedelta(days=1),
        "m": datetime.timedelta(days=30),
        "y": datetime.timedelta(days=365),
    }

    unit = units.get(suffix)
    if unit is None:
        raise argparse.ArgumentTypeError("duration must end in 's', 'h', 'd', 'm' or 'y'")

    try:
        then = datetime.datetime.now() - unit * number
    except OverflowError:
        # Either the timedelta or the resulting date is out of range.
        raise argparse.ArgumentTypeError("duration is too large") from None

    return then


def mark_as_deleted(sqlite_conn, origin, media_id):
    """Flag the cache row for (origin, media_id) as known to be deleted.

    The update runs inside the connection's context manager, so it is
    committed on success and rolled back if the statement raises.
    """
    update_sql = """
            UPDATE media SET known_deleted = ?
            WHERE origin = ? AND media_id = ?
            """
    params = (True, origin, media_id)

    with sqlite_conn:
        sqlite_conn.execute(update_sql, params)


def get_not_deleted_count(sqlite_conn):
    """Return how many cache rows are not yet flagged as deleted."""
    count_sql = """
        SELECT COALESCE(count(*), 0) FROM media
        WHERE NOT known_deleted
        """
    # The aggregate query always yields exactly one single-column row.
    row = sqlite_conn.cursor().execute(count_sql).fetchone()
    (remaining,) = row
    return remaining


def get_not_deleted(sqlite_conn):
    """Return a cursor over the cache rows not yet flagged as deleted.

    Each row is an (origin, media_id, filesystem_id, type) tuple; the cursor
    is returned unconsumed so callers can stream through it.
    """
    select_sql = """
        SELECT origin, media_id, filesystem_id, type FROM media
        WHERE NOT known_deleted
        """
    pending = sqlite_conn.cursor()
    pending.execute(select_sql)
    return pending


def to_path(origin, filesystem_id, m_type):
    """Get a relative path to the given media.

    Args:
        origin: server name the media came from; only used for remote media.
        filesystem_id: ID the file is stored under. Its first two characters,
            next two characters and remainder become three path components.
        m_type: either "local" or "remote".

    Returns:
        The path of the media file relative to the media store base path.

    Raises:
        ValueError: if m_type is neither "local" nor "remote".
    """
    if m_type == "local":
        prefix = ("local_content",)
    elif m_type == "remote":
        prefix = ("remote_content", origin)
    else:
        # Interpolate explicitly: exceptions do not apply %-formatting to
        # their arguments the way logging calls do.
        raise ValueError("Unexpected media type %r" % (m_type,))

    return os.path.join(
        *prefix, filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:]
    )


def to_thumbnail_dir(origin, filesystem_id, m_type):
    """Get a relative path to the given media's thumbnail directory.

    Args:
        origin: server name the media came from; only used for remote media.
        filesystem_id: ID the file is stored under. Its first two characters,
            next two characters and remainder become three path components.
        m_type: either "local" or "remote".

    Returns:
        The path of the thumbnail directory relative to the media store base
        path.

    Raises:
        ValueError: if m_type is neither "local" nor "remote".
    """
    if m_type == "local":
        prefix = ("local_thumbnails",)
    elif m_type == "remote":
        # NOTE: singular "remote_thumbnail", unlike "local_thumbnails".
        prefix = ("remote_thumbnail", origin)
    else:
        # Interpolate explicitly: exceptions do not apply %-formatting to
        # their arguments the way logging calls do.
        raise ValueError("Unexpected media type %r" % (m_type,))

    return os.path.join(
        *prefix, filesystem_id[:2], filesystem_id[2:4], filesystem_id[4:]
    )


def get_local_files(base_path, origin, filesystem_id, m_type):
    """List the files of one media item that still exist under base_path.

    The result holds paths relative to base_path: the original media file (if
    present) followed by every regular file found directly inside the media's
    thumbnail directory.
    """
    media_rel = to_path(origin, filesystem_id, m_type)
    media_abs = os.path.join(base_path, media_rel)
    found = [media_rel] if os.path.exists(media_abs) else []

    thumbs_rel = to_thumbnail_dir(origin, filesystem_id, m_type)
    thumbs_abs = os.path.join(base_path, thumbs_rel)
    try:
        with os.scandir(thumbs_abs) as entries:
            found.extend(
                os.path.join(thumbs_rel, entry.name)
                for entry in entries
                if entry.is_file()
            )
    except (FileNotFoundError, NotADirectoryError):
        # No thumbnail directory (or something else lives at that path), so
        # there are simply no thumbnails to report.
        pass

    return found


def check_file_in_s3(s3, bucket, key, extra_args):
    """Check the file exists in S3 (though it could be different).

    Args:
        s3: client object exposing ``head_object``.
        bucket: name of the bucket to look in.
        key: object key to look for.
        extra_args: upload arguments; when both "SSECustomerKey" and
            "SSECustomerAlgorithm" are present they are forwarded to the HEAD
            request.

    Returns:
        True if the HEAD request succeeds, False if S3 reports the object as
        not found.

    Raises:
        botocore.exceptions.ClientError: for any other S3 error (for example
            access denied); the original error is re-raised untouched.
    """
    head_kwargs = {"Bucket": bucket, "Key": key}
    if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
        head_kwargs["SSECustomerKey"] = extra_args["SSECustomerKey"]
        head_kwargs["SSECustomerAlgorithm"] = extra_args["SSECustomerAlgorithm"]

    try:
        s3.head_object(**head_kwargs)
    except botocore.exceptions.ClientError as e:
        # Error codes are strings and usually not numeric ("AccessDenied",
        # "NoSuchKey", ...), so compare as text instead of calling int(),
        # which would replace the real error with a ValueError.
        error_code = str(e.response.get("Error", {}).get("Code", ""))
        if error_code in ("404", "NoSuchKey", "NotFound"):
            return False
        raise

    return True


def run_write(sqlite_conn, output_file):
    """Entry point for write command.

    For every cache row not flagged as deleted, writes two lines to
    output_file: the media file's relative path, then its thumbnail directory.
    """
    for origin, _media_id, filesystem_id, m_type in get_not_deleted(sqlite_conn):
        media_path = to_path(origin, filesystem_id, m_type)

        # Joining with "" appends the trailing separator that marks the entry
        # as a directory.
        thumbnail_dir = os.path.join(
            to_thumbnail_dir(origin, filesystem_id, m_type), ""
        )

        print(media_path, thumbnail_dir, sep="\n", file=output_file)


def run_update_db(synapse_db_conn, sqlite_conn, before_date):
    """Entry point for update-db command.

    Copies into the sqlite cache every local and remote media row from
    Synapse's database whose last access (or creation, if never accessed) is
    older than before_date. Rows already in the cache are left untouched.

    Args:
        synapse_db_conn: connection to Synapse's database, either a
            ``sqlite3.Connection`` or a connection using the ``%s`` parameter
            style. It is always closed before this function returns or raises.
        sqlite_conn: connection to the local sqlite cache.
        before_date: ``datetime.datetime`` cut-off; only media not accessed
            since then is synced.
    """

    local_sql = """
        SELECT '', media_id, media_id
        FROM local_media_repository
        WHERE
            COALESCE(last_access_ts, created_ts) < %s
            AND url_cache IS NULL
    """

    remote_sql = """
        SELECT media_origin, media_id, filesystem_id
        FROM remote_media_cache
        WHERE
            COALESCE(last_access_ts, created_ts) < %s
    """

    # The queries compare against timestamps in milliseconds.
    last_access_ts = int(before_date.timestamp() * 1000)

    print(
        "Syncing files that haven't been accessed since:", before_date.isoformat(" "),
    )

    # sqlite3 uses "?" placeholders where the queries above use "%s".
    uses_qmark = isinstance(synapse_db_conn, sqlite3.Connection)
    update_count = 0

    try:
        with sqlite_conn:
            sqlite_cur = sqlite_conn.cursor()
            synapse_db_curs = synapse_db_conn.cursor()
            try:
                for sql, mtype in ((local_sql, "local"), (remote_sql, "remote")):
                    if uses_qmark:
                        sql = sql.replace("%s", "?")
                    synapse_db_curs.execute(sql, (last_access_ts,))
                    update_count += update_db_process_rows(
                        mtype, sqlite_cur, synapse_db_curs
                    )
            finally:
                synapse_db_curs.close()

        print("Synced", update_count, "new rows")
    finally:
        # Release the Synapse connection even if a query or insert failed.
        synapse_db_conn.close()


def update_db_process_rows(mtype, sqlite_cur, synapse_db_curs):
    """Insert rows read from Synapse's database into the cache.

    Every (origin, media_id, filesystem_id) row yielded by synapse_db_curs is
    inserted with the given media type and known_deleted set to False; rows
    that already exist are ignored. Returns the sum of the cursor's rowcount
    over all inserts.
    """
    insert_sql = """
            INSERT OR IGNORE INTO media
            (origin, media_id, filesystem_id, type, known_deleted)
            VALUES (?, ?, ?, ?, ?)
            """
    inserted = 0

    for row in synapse_db_curs:
        origin, media_id, filesystem_id = row
        sqlite_cur.execute(
            insert_sql, (origin, media_id, filesystem_id, mtype, False)
        )
        inserted += sqlite_cur.rowcount

    return inserted


def run_check_delete(sqlite_conn, base_path):
    """Entry point for check-deleted command.

    Looks under base_path for the files of every cache row not yet flagged as
    deleted, and flags the rows for which no file at all is left.
    """
    if progress:
        rows = tqdm.tqdm(
            get_not_deleted(sqlite_conn),
            unit="files",
            total=get_not_deleted_count(sqlite_conn),
        )
    else:
        rows = get_not_deleted(sqlite_conn)
        print("Checking on ", get_not_deleted_count(sqlite_conn), " undeleted files")

    # Collect first and write afterwards, so the cache is updated in a single
    # transaction once the scan is complete.
    missing = [
        (origin, media_id)
        for origin, media_id, filesystem_id, m_type in rows
        if not get_local_files(base_path, origin, filesystem_id, m_type)
    ]

    update_sql = """
            UPDATE media SET known_deleted = ?
            WHERE origin = ? AND media_id = ?
            """
    with sqlite_conn:
        sqlite_conn.executemany(
            update_sql,
            ((True, origin, media_id) for origin, media_id in missing),
        )

    print("Updated", len(missing), "as deleted")


def run_upload(s3, bucket, sqlite_conn, base_path, s3_prefix, extra_args, should_delete):
    """Entry point for upload command.

    For every cache row not flagged as deleted, uploads each of its local
    files that is not already in the bucket, optionally deletes the local
    copies, and prints summary counters at the end.

    Args:
        s3: client object exposing ``head_object`` and ``upload_file``.
        bucket: name of the bucket to upload to.
        sqlite_conn: connection to the local sqlite cache.
        base_path: base path of the media store directory.
        s3_prefix: string prepended to each relative file path to form the
            object key.
        extra_args: passed to ``upload_file`` as ``ExtraArgs`` and consulted
            for SSE-C settings when checking for existing objects.
        should_delete: if truthy, remove each local file once it is in the
            bucket (whether it was uploaded now or was already there).
    """
    total = get_not_deleted_count(sqlite_conn)

    uploaded_media = 0
    uploaded_files = 0
    uploaded_bytes = 0
    deleted_media = 0
    deleted_files = 0
    deleted_bytes = 0

    # This is a progress bar
    if progress:
        it = tqdm.tqdm(get_not_deleted(sqlite_conn), unit="files", total=total)
    else:
        print("Uploading ", total, " files")
        it = get_not_deleted(sqlite_conn)

    for origin, media_id, filesystem_id, m_type in it:
        local_files = get_local_files(base_path, origin, filesystem_id, m_type)

        if not local_files:
            # Nothing left on disk for this media, so record that and move on.
            mark_as_deleted(sqlite_conn, origin, media_id)
            continue

        # Counters of uploaded and deleted files for this media only
        media_uploaded_files = 0
        media_deleted_files = 0

        for rel_file_path in local_files:
            local_path = os.path.join(base_path, rel_file_path)

            key = s3_prefix + rel_file_path
            if not check_file_in_s3(s3, bucket, key, extra_args):
                try:
                    s3.upload_file(
                        local_path, bucket, key, ExtraArgs=extra_args,
                    )
                except Exception as e:
                    # Best effort: report the failure and carry on with the
                    # next file. The `continue` also skips the deletion below,
                    # so a file that failed to upload is never removed.
                    print("Failed to upload file %s: %s" % (local_path, e))
                    continue

                media_uploaded_files += 1
                uploaded_files += 1
                uploaded_bytes += os.path.getsize(local_path)

            if should_delete:
                size = os.path.getsize(local_path)
                os.remove(local_path)

                try:
                    # This may have lead to an empty directory, so lets remove all
                    # that are empty
                    os.removedirs(os.path.dirname(local_path))
                except Exception:
                    # The directory might not be empty, or maybe we don't have
                    # permission. Either way doesn't really matter.
                    pass

                media_deleted_files += 1
                deleted_files += 1
                deleted_bytes += size

        if media_uploaded_files:
            uploaded_media += 1

        if media_deleted_files:
            deleted_media += 1

        if media_deleted_files == len(local_files):
            # Mark as deleted only if *all* the local files have been deleted
            mark_as_deleted(sqlite_conn, origin, media_id)

    print("Uploaded", uploaded_media, "media out of", total)
    print("Uploaded", uploaded_files, "files")
    print("Uploaded", humanize.naturalsize(uploaded_bytes, gnu=True))
    print("Deleted", deleted_media, "media")
    print("Deleted", deleted_files, "files")
    print("Deleted", humanize.naturalsize(deleted_bytes, gnu=True))


def get_sqlite_conn(parser):
    """Open (creating if needed) the cache.db sqlite cache, or exit.

    The cache schema is applied on every open. Any sqlite error is reported
    through parser.error().
    """
    cache_file = "cache.db"

    try:
        cache_conn = sqlite3.connect(cache_file)
        cache_conn.executescript(SCHEMA)
    except sqlite3.Error as err:
        parser.error("Could not open 'cache.db' as sqlite DB: %s" % (err,))

    return cache_conn

def get_homeserver_db_conn(parser, homeserver_config_path):
    """Attempt to get a connection based on the provided YAML path to Synapse's
    database, or exit.

    Args:
        parser: argument parser whose ``error()`` is used to report problems.
        homeserver_config_path: path to the YAML file holding a ``database``
            section with ``name`` and ``args`` entries.

    Returns:
        A ``sqlite3`` connection when ``database.name`` is "sqlite3",
        otherwise a ``psycopg2`` connection.
    """

    try:
        with open(homeserver_config_path) as f:
            homeserver_yaml = yaml.safe_load(f)
    except FileNotFoundError:
        parser.error("Could not find %s" % (homeserver_config_path,))
    except yaml.YAMLError as e:
        parser.error("%s is not valid yaml: %s" % (homeserver_config_path, e,))

    try:
        database_engine_name = homeserver_yaml["database"]["name"]
        database_args = homeserver_yaml["database"]["args"]
        if database_engine_name == "sqlite3":
            database_path = database_args["database"]
            synapse_db_conn = sqlite3.connect(database=database_path)
        else:
            # Determine the database name. "database" is a deprecated form of
            # the option name. See https://www.psycopg.org/docs/module.html
            # Only fall back to "database" when "dbname" is absent: passing
            # database_args["database"] as the .get() default would evaluate
            # it eagerly and fail for configs that only set "dbname".
            if "dbname" in database_args:
                database_name = database_args["dbname"]
            else:
                database_name = database_args["database"]
            synapse_db_conn = psycopg2.connect(
                user=database_args["user"],
                password=database_args["password"],
                database=database_name,
                host=database_args["host"],
                port=database_args["port"],
            )
    except KeyError as e:
        parser.error(
            "Missing database setting %s in %s" % (e, homeserver_config_path)
        )
    except sqlite3.OperationalError as e:
        parser.error("Could not connect to sqlite3 database: %s" % (e,))
    except psycopg2.Error as e:
        parser.error("Could not connect to postgres database: %s" % (e,))

    return synapse_db_conn

def get_database_db_conn(parser):
    """Connect to Synapse's database as described by database.yaml, or exit.

    A top-level "sqlite" section is passed to sqlite3.connect(), a top-level
    "postgres" section to psycopg2.connect(); with neither present the whole
    document is passed to psycopg2.connect(). Problems are reported through
    parser.error().
    """

    try:
        with open("database.yaml") as config_file:
            settings = yaml.safe_load(config_file)
    except FileNotFoundError:
        parser.error("Could not find database.yaml")
    except yaml.YAMLError as err:
        parser.error("database.yaml is not valid yaml: %s" % (err,))

    try:
        # Pick the driver and its keyword arguments, then connect once.
        if "sqlite" in settings:
            connect, connect_kwargs = sqlite3.connect, settings["sqlite"]
        elif "postgres" in settings:
            connect, connect_kwargs = psycopg2.connect, settings["postgres"]
        else:
            connect, connect_kwargs = psycopg2.connect, settings
        db_conn = connect(**connect_kwargs)
    except sqlite3.OperationalError as err:
        parser.error("Could not connect to sqlite3 database: %s" % (err,))
    except psycopg2.Error as err:
        parser.error("Could not connect to postgres database: %s" % (err,))

    return db_conn

def get_synapse_db_conn(parser, homeserver_config_path):
    """Connect to Synapse's database, or exit.

    A database.yaml file in the current directory takes precedence; without
    one the settings are read from the given homeserver config file.
    """
    if not os.path.isfile("database.yaml"):
        return get_homeserver_db_conn(parser, homeserver_config_path)

    return get_database_db_conn(parser)

def main():
    """Parse the command line and run the selected subcommand.

    Subcommands: update-db, check-deleted, update, write and upload. Exits via
    parser.error() when no valid subcommand is given.
    """
    global progress

    duration_help = (
        "Fetch rows that haven't been accessed in the duration given,"
        " accepts duration of the form of e.g. 1m (one month). Valid suffixes"
        " are s, h, d, m or y. NOTE: Currently does not remove entries from the cache"
    )

    def add_base_path(subparser):
        # Positional shared by check-deleted, update and upload.
        subparser.add_argument(
            "base_path", help="Base path of the media store directory"
        )

    def add_db_sync_args(subparser):
        # Arguments shared by update-db and update.
        subparser.add_argument("duration", type=parse_duration, help=duration_help)
        subparser.add_argument(
            "--homeserver-config-path",
            help="Path to the yaml file containing Synapse's DB settings",
            default="homeserver.yaml",
        )

    parser = argparse.ArgumentParser(prog="s3_media_upload")
    parser.add_argument(
        "--no-progress",
        help="do not show progress bars",
        action="store_true",
        dest="no_progress",
    )
    subparsers = parser.add_subparsers(help="command to run", dest="cmd")

    # update-db
    update_db_parser = subparsers.add_parser(
        "update-db", help="Syncs rows from database to local cache"
    )
    add_db_sync_args(update_db_parser)

    # check-deleted
    deleted_parser = subparsers.add_parser(
        "check-deleted",
        help="Check whether files in the local cache still exist under given path",
    )
    add_base_path(deleted_parser)

    # update
    update_parser = subparsers.add_parser(
        "update",
        help="Updates local cache. Equivalent to running update-db and"
        " check-deleted",
    )
    add_base_path(update_parser)
    add_db_sync_args(update_parser)

    # write
    write_parser = subparsers.add_parser(
        "write",
        help="Outputs all file and directory paths in the local cache that we may not"
        " have deleted. check-deleted should be run first to update cache.",
    )
    write_parser.add_argument(
        "out",
        type=argparse.FileType("w", encoding="UTF-8"),
        default="-",
        nargs="?",
        help="File to output list to, or '-' for stdout",
    )

    # upload
    upload_parser = subparsers.add_parser(
        "upload", help="Uploads media to s3 based on local cache"
    )
    add_base_path(upload_parser)
    upload_parser.add_argument("bucket", help="S3 bucket to upload to")
    upload_parser.add_argument(
        "--prefix", help="Prefix of files in the bucket", default=""
    )
    upload_parser.add_argument(
        "--storage-class",
        help="S3 storage class to use",
        nargs="?",
        choices=[
            "STANDARD",
            "REDUCED_REDUNDANCY",
            "STANDARD_IA",
            "ONEZONE_IA",
            "INTELLIGENT_TIERING",
        ],
        default="STANDARD",
    )
    upload_parser.add_argument("--sse-customer-key", help="SSE-C key to use")
    upload_parser.add_argument(
        "--sse-customer-algo",
        help="Algorithm for SSE-C, only used if sse-customer-key is also specified",
        default="AES256",
    )
    upload_parser.add_argument(
        "--delete",
        action="store_const",
        const=True,
        help="Deletes local copy from media store on succesful upload",
    )
    upload_parser.add_argument(
        "--endpoint-url", help="S3 endpoint url to use", default=None
    )

    args = parser.parse_args()
    if args.no_progress:
        progress = False

    command = args.cmd
    if command not in ("write", "update-db", "check-deleted", "update", "upload"):
        parser.error("Valid subcommand must be specified")

    # Every subcommand works against the local sqlite cache.
    sqlite_conn = get_sqlite_conn(parser)

    if command == "write":
        run_write(sqlite_conn, args.out)
    elif command == "update-db":
        synapse_db_conn = get_synapse_db_conn(parser, args.homeserver_config_path)
        run_update_db(synapse_db_conn, sqlite_conn, args.duration)
    elif command == "check-deleted":
        run_check_delete(sqlite_conn, args.base_path)
    elif command == "update":
        synapse_db_conn = get_synapse_db_conn(parser, args.homeserver_config_path)
        run_update_db(synapse_db_conn, sqlite_conn, args.duration)
        run_check_delete(sqlite_conn, args.base_path)
    else:
        # command == "upload"
        s3 = boto3.client("s3", endpoint_url=args.endpoint_url)

        extra_args = {"StorageClass": args.storage_class}
        if args.sse_customer_key:
            extra_args["SSECustomerKey"] = args.sse_customer_key
            # Fall back to AES256 if the algorithm was given as an empty value.
            extra_args["SSECustomerAlgorithm"] = args.sse_customer_algo or "AES256"

        run_upload(
            s3,
            args.bucket,
            sqlite_conn,
            args.base_path,
            args.prefix,
            extra_args,
            should_delete=args.delete,
        )


# Run the command-line interface only when executed as a script, not when the
# module is imported.
if __name__ == "__main__":
    main()
