#!/usr/bin/env python3
# Copyright 2023 Northern.tech AS
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

# Used to generate changelogs from the repository.

import argparse
import os
import os.path
import re
import subprocess
import sys

from abc import ABC, abstractmethod
from typing import List

# Holds each changelog entry indexed by SHA.
ENTRIES = {}
# Links SHAs together: if we have an "X cherry-picked from Y" situation, those
# two commits are linked in both directions, and the link is followed when one
# of them is later cancelled or reverted.
LINKED_SHAS = {}
# A map of SHAs to the set of bugtracker references extracted from the commit
# messages.
SHA_TO_TRACKER = {}

# Alternation of the Jira project keys that are recognized as ticket prefixes.
JIRA_PROJECTS = "|".join(["ARCHIVE", "INF", "MEN", "QA", "MC", "ME", "ALV", "SEC"])
# Matches one ticket reference such as "MEN-1234" (captured in group 1),
# optionally wrapped in square brackets, prefixed with "Jira:" and/or written
# as a full browse URL. The dots of the host name are escaped so that only the
# literal host is accepted as part of the reference.
JIRA_REGEX = rf"\[?(?:Jira:? *)?(?:https?://northerntech\.atlassian\.net/browse/)?((?:{JIRA_PROJECTS})(?:-| )[0-9]+)\]?"
# JIRA_REGEX with an optional "Ref:"/"Ticket:" prefix, optional surrounding
# parentheses and a trailing colon and spaces, so that the whole reference can
# be cut out of a line.
TRACKER_REGEX = r"\(?(?:(?:Ref|Ticket):? *)?%s\)?:? *" % (JIRA_REGEX)
# Conventional-commit prefix at the start of a message: type, optional
# "(scope)", optional "!" and the colon. Group 1 is the type.
CONVENTIONAL_COMMIT_TYPE_REGEX = (
    r"^(fix|feat|build|chore|ci|docs|perf|refac|revert|style|test)(?:\(\S+\))?!?: ?"
)
# Heuristic used to recognize entries that come from Dependabot bump commits.
DEPENDABOT_REGEX = r"[Bb]umps? .* from .* to .*"
ORGANIZATION = "mendersoftware"


def aggregate_dependabot_changelogs(
    entry_list: List[str], possible_problems: List[str]
) -> List[str]:
    """Fold all Dependabot bump entries into one aggregated entry.

    Entries matching DEPENDABOT_REGEX are taken out of the list. Those that do
    not contain ORGANIZATION are combined into a single
    "Aggregated Dependabot Changelogs:" entry which is appended last; those
    that do contain it are left out of the result. Every other entry is
    passed through unchanged and in its original order.

    Args:
        entry_list: The changelog entries to process.
        possible_problems: Receives a warning for each Dependabot-looking
            entry that consists of a single line and therefore has no body
            that could be aggregated.

    Returns:
        A new list of entries; ``entry_list`` itself is not modified.
    """
    new_entry_list = []
    aggregated_dependabot_changelogs = ""
    for entry in entry_list:
        if not re.search(DEPENDABOT_REGEX, entry):
            new_entry_list.append(entry)
            continue
        if ORGANIZATION in entry:
            continue

        # Get rid of first line, which is a duplicate of the next one, then
        # strip start and end whitespace, but keep the rest of the whitespace
        # intact for indentation and line separation purposes.
        try:
            body = entry.split("\n", 1)[1]
        except IndexError:
            # Only a single line: there is nothing after the title to keep.
            possible_problems.append(
                (
                    "*** A changelog entry which was assumed by the script to be "
                    + "derived from a dependabot commit could not be truncated "
                    + "properly. Either a commit was wrongly assumed to be committed "
                    + "by dependabot or a dependabot commit does not, for unknown "
                    + "reasons, follow its standard format. The changelog entry "
                    + " that produced this fault was:\n\n%s"
                )
                % entry
            )
        else:
            aggregated_dependabot_changelogs += "\n* " + "\n    ".join(
                body.strip().split("\n")
            )
        # Compress lines with only whitespace. Not really necessary, but looks
        # better in editors.
        aggregated_dependabot_changelogs = re.sub(
            "\n +\n", "\n\n", aggregated_dependabot_changelogs
        )
    if aggregated_dependabot_changelogs != "":
        new_entry_list.append(
            "Aggregated Dependabot Changelogs:" + aggregated_dependabot_changelogs
        )
    return new_entry_list


def add_entry(entries, sha, msg):
    """Record ``msg`` as a changelog entry for ``sha`` in ``entries``.

    A message that is just the word "none" (in any letter case, ignoring
    surrounding whitespace) is not recorded.
    """
    is_none_marker = msg.strip().lower() == "none"
    if not is_none_marker:
        entries.setdefault(sha, []).append(msg)


# Command line definition. Option destinations are derived by argparse from
# the long option names (e.g. --base-dir -> args.base_dir).
parser = argparse.ArgumentParser(
    description="Generates a changelog for Mender repositories."
)
parser.add_argument(
    "--repo",
    action="store_true",
    help="Includes only the current repository. Mutually exclusive with --all",
)
parser.add_argument(
    "--all",
    action="store_true",
    help=(
        "(Default) Includes all versioned Mender repositories. "
        "Commit range should be given relative to the integration "
        "repository, and the generator will use release_tool.py "
        "to figure out versions of other repositories. "
        "Mutually exclusive with --repo"
    ),
)
parser.add_argument(
    "--base-dir",
    help=(
        "Base directory containing all the Mender repositories. "
        "Ignored if using --repo"
    ),
)
parser.add_argument(
    "--query-github",
    action="store_true",
    help=(
        "Query Github for commit messages instead of local repositories. "
        "Set GITHUB_TOKEN if you want authorized access."
    ),
)
parser.add_argument(
    "--github-repo",
    metavar="NAME",
    help=(
        "Name the Github repository to use, instead of deducing it from "
        "the current directory. Requires --query-github."
    ),
)
# Hidden from --help: not used except in the test.
parser.add_argument("--sort-changelog", action="store_true", help=argparse.SUPPRESS)
parser.add_argument(
    "--append-commit-sha",
    action="store_true",
    help=(
        "Append to every entry the SHA that it came from. "
        "This can help when tidying up history with Cancel-changelog "
        "after a long release cycle."
    ),
)
parser.add_argument(
    "range", metavar="<commit-range> [--]", help="Range of commits to generate log for"
)
parser.add_argument(
    "gitargs",
    metavar="git-argument",
    nargs="*",
    help=(
        "Additional git arguments to tailor the commit range. "
        "Note that for technical reasons these must come last."
    ),
)


class GitQueryInterface(ABC):
    """Contract for the backends that supply commit data to the generator."""

    @abstractmethod
    def get_commits_for_range(self, repo, range):
        """Return the SHAs of the commits of ``repo`` selected by ``range``."""

    @abstractmethod
    def get_raw_commit_message(self, repo, sha):
        """Return the commit message of commit ``sha`` in ``repo``."""

    @abstractmethod
    def show_commit_without_diff(self, repo, sha):
        """Return a printable description of commit ``sha`` without its diff."""


class GitQuerier(GitQueryInterface):
    """Answers commit queries by running ``git`` in a local checkout.

    The working directory of every command is ``base_dir`` joined with the
    repository name, and ``gitargs`` is appended to every git command line.
    """

    def __init__(self, gitargs, base_dir):
        self.base_dir = base_dir
        self.gitargs = gitargs

    def _run_git(self, repo, *command):
        """Run ``git <command> <gitargs>`` for ``repo`` and return its decoded output."""
        workdir = os.path.join(self.base_dir, repo)
        raw = subprocess.check_output(["git", *command] + self.gitargs, cwd=workdir)
        return raw.decode()

    def get_commits_for_range(self, repo, range):
        """Return the SHAs printed by ``git rev-list --reverse`` for ``range``."""
        return self._run_git(repo, "rev-list", "--reverse", range).split()

    def get_raw_commit_message(self, repo, sha):
        """Return the message part of the commit object printed by ``git cat-file -p``."""
        commit_object = self._run_git(repo, "cat-file", "-p", sha)
        # The headers end at the first empty line; everything after it is the
        # message.
        return commit_object.split("\n\n", 1)[1]

    def show_commit_without_diff(self, repo, sha):
        """Return the output of ``git show --no-patch`` for ``sha``."""
        return self._run_git(repo, "show", "--no-patch", sha)


class GitHubQuerier(GitQueryInterface):
    """Answers commit queries through the GitHub API (PyGithub).

    The comparison fetched by ``get_commits_for_range`` is cached, and the
    message lookups are answered from that cache. Only the most recent
    comparison is kept, so messages can only be looked up for SHAs returned
    by the latest ``get_commits_for_range`` call.
    """

    def __init__(self):
        # Avoid importing this unless we have to, which is only when using this
        # class.
        import github

        token = os.getenv("GITHUB_TOKEN")
        auth = github.Auth.Token(token) if token else None
        self.github = github.Github(auth=auth)
        # Result of the latest comparison, and its commit messages by SHA.
        # Initialised here so that a lookup before any range query fails with
        # the regular "not in the cache" error instead of an AttributeError.
        self.cached_result = None
        self._messages_by_sha = {}

    def get_commits_for_range(self, repo, range):
        """Return the SHAs of the comparison ``base..head`` in ``repo``.

        Raises:
            ValueError: if ``range`` is not of the form ``<base>..<head>``.
        """
        base, separator, head = range.partition("..")
        if not separator:
            raise ValueError(
                f"Expected a commit range of the form <base>..<head>, got {range!r}"
            )
        # Cache the result, which has a lot of information, otherwise we need
        # to query every commit later.
        self.cached_result = self.github.get_repo(f"mendersoftware/{repo}").compare(
            base, head
        )
        commits = list(self.cached_result.commits)
        self._messages_by_sha = {
            commit.sha: commit.commit.message for commit in commits
        }
        return [commit.sha for commit in commits]

    def get_raw_commit_message(self, repo, sha):
        """Return the cached message of ``sha``; ``repo`` is not consulted.

        Raises:
            Exception: if ``sha`` was not part of the latest comparison.
        """
        try:
            return self._messages_by_sha[sha]
        except KeyError:
            raise Exception(f"Could not find {sha} in the cache") from None

    def show_commit_without_diff(self, repo, sha):
        """Return the commit message as a stand-in for ``git show --no-patch``."""
        # No way to return it exactly as Git does without manual piecing
        # together. This is only used for information, so return this instead.
        return self.get_raw_commit_message(repo, sha)


args = parser.parse_args()

# --repo and --all select opposite modes; --all is assumed when neither one
# was given.
if args.repo and args.all:
    print("--repo and --all are mutually exclusive!")
    sys.exit(1)
if not (args.all or args.repo):
    args.all = True

if not (args.base_dir or args.repo or args.query_github):
    print("Need at least one of --base-dir, --repo or --query-github")
    sys.exit(1)

if args.github_repo and not args.query_github:
    print("--github-repo requires --query-github")
    sys.exit(1)

# Without --base-dir the parent of the current directory is used.
base_dir = args.base_dir if args.base_dir else os.path.realpath("..")

# Decide which repositories to generate changelogs for.
if args.all:
    # Let release_tool.py list the repositories for the requested version.
    list_command = [
        os.path.join(base_dir, "integration/extra/release_tool.py"),
        "--list",
        "--in-integration-version",
        args.range,
        "--only-backend",
    ]
    repo_list = subprocess.check_output(list_command).decode().strip()
    repos = repo_list.split()
elif args.github_repo:
    repos = [args.github_repo]
else:
    # Name of the current directory.
    repos = [os.path.basename(os.path.realpath("."))]

# Pick the backend that answers the commit queries.
git_query = (
    GitHubQuerier() if args.query_github else GitQuerier(args.gitargs, base_dir)
)


def get_range_for_repo(repo, range):
    """Translate an integration version or range into the one for ``repo``.

    Runs release_tool.py from the integration checkout below ``base_dir``
    with ``--version-of repo --in-integration-version range`` and returns
    its output with surrounding whitespace stripped.
    """
    release_tool = os.path.join(base_dir, "integration/extra/release_tool.py")
    command = [release_tool, "--version-of", repo, "--in-integration-version", range]
    tool_output = subprocess.check_output(command)
    return tool_output.decode().strip()


def extract_message_and_bucket(message, fixes, feats, others, dep_bumps):
    """Strip the conventional-commit prefix and choose the bucket for an entry.

    Returns a ``(stripped_message, bucket)`` tuple, where ``bucket`` is one of
    the four lists passed in: ``fixes`` for type "fix", ``feats`` for type
    "feat", otherwise ``dep_bumps`` if the stripped message matches
    DEPENDABOT_REGEX, and ``others`` for everything else.
    """
    prefix = re.match(CONVENTIONAL_COMMIT_TYPE_REGEX, message, re.IGNORECASE)
    if prefix is None:
        commit_type = ""
        body = message
    else:
        commit_type = prefix.group(1).lower()
        # The pattern is anchored at the start, so dropping the prefix is the
        # same as cutting off the matched part.
        body = message[prefix.end() :]

    if commit_type == "fix":
        return body, fixes
    if commit_type == "feat":
        return body, feats
    if re.search(DEPENDABOT_REGEX, body):
        return body, dep_bumps
    return body, others


def print_category_entries(
    entry_list, category, possible_problems, print_category=True
):
    """Print one changelog category as a Markdown bullet list.

    Nothing is printed for an empty ``entry_list``. Otherwise Dependabot
    entries are aggregated first, the "##### <category>" heading is printed
    unless ``print_category`` is false, and each entry follows as a "* "
    bullet, ending with one empty line.
    """
    if not entry_list:
        return

    aggregated = aggregate_dependabot_changelogs(entry_list, possible_problems)
    if print_category:
        print(f"##### {category}\n")
    for item in aggregated:
        # Indent all non-empty continuation lines so that they stay part of
        # the bullet.
        print(re.sub(r"\n([^\n])", r"\n  \1", "* " + item))
    print()


def gather_entries(repo, range, remove_cherry_picks=True):
    """Collect the changelog entries of all commits in ``range`` of ``repo``.

    Every commit message is scanned line by line for these markers:

    - ``Changelog: Title`` / ``Commit`` / ``All`` / ``None`` or free text,
      optionally typed as ``Changelog(<type>): ...``
    - ``Cancel-Changelog: <sha>`` and ``This reverts commit <sha>``, which
      drop the entries already collected for that SHA and the SHAs linked to
      it
    - ``(cherry picked from commit <sha>)``, which links the two SHAs
    - tracker references, which are cut out of the text and recorded
    - ``Signed-off-by:`` and similar ``<something>-by:`` lines, which are
      ignored

    Args:
        repo: Repository name, passed on to ``git_query``.
        range: Commit range. If it contains ".." and ``remove_cherry_picks``
            is true, entries that also occur in the inverted range are
            removed from the result.
        remove_cherry_picks: Whether to do that removal.

    Returns:
        A dict mapping each SHA that produced entries to its list of entry
        strings.

    Side effects:
        Updates the module-level LINKED_SHAS and SHA_TO_TRACKER, and appends
        warnings to the module-level POSSIBLE_PROBLEMS.
    """
    # If removing cherry-picks, then invert the range, and gather all entries
    # from there. Use this to remove entries from the results.
    exclude_entries = set()
    if remove_cherry_picks and ".." in range:
        base, head = range.split("..", 1)
        inverted_range = head + ".." + base
        exclude_entries_map = gather_entries(
            repo, inverted_range, remove_cherry_picks=False
        )
        for values in exclude_entries_map.values():
            exclude_entries.update(values)

    entries = {}

    sha_list = git_query.get_commits_for_range(repo, range)

    for sha in sha_list:
        blob = git_query.get_raw_commit_message(repo, sha)

        title_fetched = False
        title = ""
        # All lines of the message that are not part of a free-text entry;
        # becomes the entry for "Changelog: Commit" / "Changelog: All".
        commit_msg = ""
        log_entry_commit = False
        commit_prefix = ""
        # True while the lines after a free-text "Changelog:" tag are being
        # collected into log_entry.
        log_entry_local = False
        log_entry = ""
        exclusive_tag_seen = False

        # Conventional-commit prefix of the message ("fix: ", ...), if any.
        type_match = re.match(CONVENTIONAL_COMMIT_TYPE_REGEX, blob, re.IGNORECASE)
        cc_prefix = type_match.group(0) if type_match else ""

        for line in blob.split("\n"):
            line = line.rstrip("\r")
            if line == "":
                # An empty line ends the entry that is being collected.
                if log_entry:
                    add_entry(entries, sha, log_entry.strip())
                    log_entry = ""
                log_entry_local = False

            # Tracker reference, remove from string.
            for match in re.finditer(TRACKER_REGEX, line, re.IGNORECASE):
                if not SHA_TO_TRACKER.get(sha):
                    SHA_TO_TRACKER[sha] = set()
                SHA_TO_TRACKER[sha].add("".join(match.groups("")))
                tracker_removed = re.sub(TRACKER_REGEX, "", line, flags=re.IGNORECASE)
                line = tracker_removed.strip(" ")

            if re.match(r"^ticket: *none *$", line, re.IGNORECASE):
                continue

            if not title_fetched:
                title = line[len(cc_prefix) :]
                title_fetched = True

            match_changelog = re.match(
                r"^ *Changelog(?:\((fix|feat|build|chore|ci|docs|perf|refac|revert|style|test)\))?: *(.*)",
                line,
                re.IGNORECASE,
            )
            if match_changelog:
                if log_entry:
                    add_entry(entries, sha, log_entry.strip())
                    log_entry = ""
                log_entry_local = False

                # An explicit "Changelog(<type>)" wins over the type prefix of
                # the commit title.
                type_keyword = match_changelog.group(1)
                if type_keyword:
                    log_entry = type_keyword.lower() + ": "
                elif cc_prefix:
                    log_entry = cc_prefix

                # The logic of this regex is:
                # - Shall start with Title|Commit|All|None
                # - From here, if [.:]+ then capture the rest in a new group
                # - Else just expect end line
                match_keyword = re.match(
                    r"^(Title|Commit|All|None) *(?:[.:]+ *(.*)$|$)",
                    match_changelog.group(2),
                    re.IGNORECASE,
                )

                if match_keyword:
                    if exclusive_tag_seen:
                        # This doesn't really mean that the tags are exclusive, but rather that it
                        # is quite uncommon to see any of them together, and might indicate a
                        # squashed commit.
                        POSSIBLE_PROBLEMS.append(
                            "*** Commit %s had conflicting changelog tags. "
                            "This might be a squashed commit which will not work correctly with changelogs. "
                            "Should be manually checked."
                            "\n---\n%s---"
                            % (sha, git_query.show_commit_without_diff(repo, sha),)
                        )
                    exclusive_tag_seen = True

                    keyword = match_keyword.group(1).lower()
                    if keyword == "title":
                        log_entry += title
                    elif keyword == "none":
                        log_entry_commit = False
                        log_entry = ""
                    elif keyword in ["commit", "all"]:
                        log_entry_commit = True
                        commit_prefix = log_entry
                        log_entry = ""
                        if match_keyword.group(2):
                            # Log the rest of the line
                            if commit_msg:
                                commit_msg += "\n"
                            commit_msg += match_keyword.group(2)
                else:
                    log_entry_local = True
                    log_entry += match_changelog.group(2)
                continue

            cancel_match = None
            for cancel_expr in [
                r"^ *Cancel-Changelog: *([0-9a-f]+).*",
                r"^This reverts commit ([0-9a-f]+).*",
            ]:
                cancel_match = re.match(cancel_expr, line, re.IGNORECASE)
                if cancel_match:
                    break
            if cancel_match:
                if log_entry:
                    add_entry(entries, sha, log_entry.strip())
                    log_entry = ""
                log_entry_local = False

                # Drop what was collected for the cancelled commit and for
                # every commit linked to it through cherry-picking.
                cancelled_sha = cancel_match.group(1)
                linked_shas = [cancelled_sha]
                if LINKED_SHAS.get(cancelled_sha):
                    linked_shas.extend(LINKED_SHAS[cancelled_sha])
                for linked_sha in linked_shas:
                    if LINKED_SHAS.get(linked_sha):
                        del LINKED_SHAS[linked_sha]
                    if entries.get(linked_sha):
                        del entries[linked_sha]
                # Like the other markers, the cancel line itself must not end
                # up in the commit message text. This has to happen out here:
                # a `continue` inside the loop over the patterns above would
                # only affect that inner loop.
                continue

            match = re.match(
                r"^\(cherry picked from commit ([0-9a-f]+)\)", line, re.IGNORECASE
            )
            if match:
                if log_entry:
                    add_entry(entries, sha, log_entry.strip())
                    log_entry = ""
                log_entry_local = False

                # Link the two commits in both directions.
                if not LINKED_SHAS.get(sha):
                    LINKED_SHAS[sha] = []
                LINKED_SHAS[sha].append(match.group(1))
                if not LINKED_SHAS.get(match.group(1)):
                    LINKED_SHAS[match.group(1)] = []
                LINKED_SHAS[match.group(1)].append(sha)
                continue

            # Use a slightly stricter filter for other "<something>-by:"
            # messages than for "Signed-off-by:", to avoid false positives.
            match = re.match(r"^(Signed-off-by:|\S+-by:.*>\s*$)", line, re.IGNORECASE)
            if match:
                # Ignore such lines.
                continue

            # Replace lines consisting entirely of `---` and `...` by
            # "```". These are typically added by dependabot, and render really
            # badly in Grav Markdown.
            if line == "---" or line == "...":
                line = "```"

            if log_entry_local:
                log_entry += "\n" + line
            else:
                if commit_msg:
                    commit_msg += "\n"
                commit_msg += line

        if log_entry_commit:
            commit_msg = commit_prefix + commit_msg[len(cc_prefix) :]
            add_entry(entries, sha, commit_msg.strip())
        if log_entry:
            add_entry(entries, sha, log_entry.strip())

    # Remove entries we don't want, and SHAs that have no entries left.
    for sha in list(entries):
        kept = [msg for msg in entries[sha] if msg not in exclude_entries]
        if kept:
            entries[sha] = kept
        else:
            del entries[sha]

    return entries


print("### Changelogs\n")

# Generate and print the changelog of each repository in turn.
for repo in repos:
    # gather_entries() updates LINKED_SHAS, SHA_TO_TRACKER and
    # POSSIBLE_PROBLEMS at module level, so start from a clean state for
    # every repository.
    ENTRIES = {}
    LINKED_SHAS = {}
    SHA_TO_TRACKER = {}
    POSSIBLE_PROBLEMS = []

    # With --all the given range is relative to integration, so it has to be
    # translated for every other repository.
    if args.all and not repo.endswith("integration"):
        repo_range = get_range_for_repo(os.path.basename(repo), args.range)
        if ".." in args.range and ".." not in repo_range:
            POSSIBLE_PROBLEMS.append(
                (
                    "*** The changelog for the %s repository contains the entire history. "
                    + "This can happen for repositories that are new in this release. "
                    + "Please double check that this is not a mistake, and consider a "
                    + "simpler changelog if appropriate."
                )
                % repo
            )
    else:
        repo_range = args.range

    ENTRIES = gather_entries(repo, repo_range)

    entry_list = []
    for sha_entry in ENTRIES:
        # Markdown links to all tickets referenced by the commit, if any.
        tracker = ""
        tickets = SHA_TO_TRACKER.get(sha_entry)
        if tickets:
            jiras = [
                "[%s](https://northerntech.atlassian.net/browse/%s)"
                % (ticket.upper(), ticket.upper())
                for ticket in tickets
            ]
            tracker = "(" + ", ".join(sorted(jiras)) + ")"
        for entry in ENTRIES[sha_entry]:
            # Safety check. See if there are still numbers at least four digits
            # long not preceded by '#' in the output and if so, warn about it.
            # This may be ticket references that we missed. The lookbehind
            # also rejects a preceding digit; otherwise the tail of a longer
            # number such as "#12345" would still match.
            match = re.search(r"(?<![#0-9])[0-9]{4,}", entry)
            if match:
                POSSIBLE_PROBLEMS.append(
                    "*** Commit %s had a number %s which may be a ticket reference we missed. Should be manually checked."
                    "\n---\n%s---"
                    % (
                        sha_entry,
                        match.group(0),
                        git_query.show_commit_without_diff(repo, sha_entry),
                    )
                )
            entry = entry.strip("\n")
            if tracker:
                # Put the ticket links on a line of their own after a code
                # fence, or when appending them would make the last line 70
                # characters or longer.
                if (
                    entry.endswith("```")
                    or len(entry) - entry.rfind("\n") + len(tracker) >= 70
                ):
                    entry += "\n"
                else:
                    entry += " "
                entry += tracker

            if args.append_commit_sha:
                entry += " (Git SHA: %s)" % sha_entry

            entry_list.append(entry)

    if args.sort_changelog:
        entry_list.sort()
    if entry_list:
        service_header = "#### %s" % os.path.basename(repo)
        if ".." in repo_range:
            v_from, v_to = repo_range.split("..", 1)
            service_header += " (%s)\n\nNew changes in %s since %s:\n" % (
                v_to,
                os.path.basename(os.path.realpath(repo)),
                v_from,
            )
        else:
            service_header += " (%s)\n" % repo_range
        print(service_header)

    # Sort the entries into categories by their conventional-commit type.
    fixes = []
    feats = []
    others = []
    dep_bumps = []
    for entry in entry_list:
        (entry, bucket) = extract_message_and_bucket(
            entry, fixes, feats, others, dep_bumps
        )
        bucket.append(entry)

    print_category_entries(fixes, "Bug Fixes", POSSIBLE_PROBLEMS)
    print_category_entries(feats, "Features", POSSIBLE_PROBLEMS)
    # The "Other" heading is only needed when there is another category to
    # tell it apart from.
    print_category_entries(others, "Other", POSSIBLE_PROBLEMS, fixes or feats)
    print_category_entries(dep_bumps, "Dependabot bumps", POSSIBLE_PROBLEMS)

    # Warnings go to stderr so that they do not end up in the changelog.
    for problem in POSSIBLE_PROBLEMS:
        if sys.stderr.isatty():
            # Use red color.
            sys.stderr.write("\033[31;1m%s\033[0m\n\n" % (problem))
        else:
            sys.stderr.write("%s\n\n" % (problem))

sys.exit(0)
