# Copyright 2024 Northern.tech AS
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import base64
import json
import pytest
import requests

from contextlib import contextmanager

from common import (
    Device,
    DevAuthorizer,
    device_auth_req,
    explode_jwt,
    clean_migrated_db,
    clean_db,
    mongo,
    cli,
    management_api,
    internal_api,
    device_api,
)

import orchestrator
import mockserver


def request_token(device, dev_auth, url, tenant_addons=None):
    """Send a device auth request and return the resulting device token.

    The request is sent to ``url`` via ``device_auth_req``, the response is
    asserted to be HTTP 200, and its body is handed to
    ``dev_auth.parse_rsp_payload`` (presumably this populates
    ``device.token`` -- confirm in ``common``).

    :param device: the ``Device`` requesting a token.
    :param dev_auth: the ``DevAuthorizer`` used for the request/response.
    :param url: the device auth-requests endpoint.
    :param tenant_addons: not used by this helper; accepted only so existing
        callers passing it keep working. Defaults to ``None`` instead of a
        shared mutable ``[]``.
    :returns: ``device.token`` after the response has been parsed.
    """
    rsp = device_auth_req(url, dev_auth, device)
    assert rsp.status_code == 200

    dev_auth.parse_rsp_payload(device, rsp.text)
    return device.token


@pytest.fixture(scope="function")
def accepted_device(device_api, management_api, clean_migrated_db):
    """Provide a device whose auth set has been accepted.

    ``clean_migrated_db`` is requested only for its setup side effect and is
    not used directly. Yields the 3-tuple produced by ``accept_device``:
    ``(device ID, Device instance, DevAuthorizer instance)``.
    """
    dev_id, device, authorizer = accept_device(device_api, management_api)
    yield dev_id, device, authorizer


def accept_device(device_api, management_api, tenant_token=None):
    """Make a new device known to deviceauth, then accept its auth set.

    First an auth request is sent for a freshly created ``Device`` (expected
    to be rejected with 401, which registers the device). The device is then
    looked up by identity via the management API, and its first auth set is
    accepted.

    :param device_api: client exposing ``auth_requests_url``.
    :param management_api: management API client used to find and accept
        the device.
    :param tenant_token: optional tenant token; when given it is passed to
        ``DevAuthorizer`` and sent as a ``Bearer`` value in the
        ``Authorization`` keyword to the management API calls.
    :returns: tuple ``(device ID, Device, DevAuthorizer)``.
    """
    d = Device()
    da = DevAuthorizer(tenant_token)
    url = device_api.auth_requests_url
    kwargs = {}
    if tenant_token is not None:
        kwargs["Authorization"] = "Bearer " + tenant_token

    # NOTE(review): the real device ID is not known yet, so a placeholder ID
    # of 1 is used for the fake orchestrator -- confirm this is intended.
    with orchestrator.run_fake_for_device_id(1):
        # poke devauth so that device appears
        rsp = device_auth_req(url, da, d)
        assert rsp.status_code == 401

        # try to find our devices in all devices listing
        dev = management_api.find_device_by_identity(d.identity, **kwargs)
        assert dev is not None

        print("found matching device with ID", dev.id)
        devid = dev.id
        # extract authentication data set ID
        aid = dev.auth_sets[0].id

    with orchestrator.run_fake_for_device_id(devid):
        try:
            management_api.accept_device(devid, aid, **kwargs)
        except management_api.ApiException as e:
            # NOTE(review): presumably the client raises on an empty 204
            # response; only a 204 status is treated as success here.
            assert e.status == 204

    return devid, d, da


@pytest.fixture(scope="function")
def device_token(accepted_device, device_api):
    """Yield a non-empty auth token issued to the accepted device.

    The token is requested while a fake orchestrator is running for the
    device's ID, then printed and asserted to be truthy before being yielded.
    """
    dev_id, device, authorizer = accepted_device

    with orchestrator.run_fake_for_device_id(dev_id):
        issued = request_token(device, authorizer, device_api.auth_requests_url)

    print("device token:", issued)
    assert issued
    yield issued


@pytest.fixture(scope="session")
def token_verify_url(internal_api):
    """Yield the internal API URL of the ``/tokens/verify`` endpoint.

    Built once per test session from ``internal_api.make_api_url``.
    """
    endpoint = "/tokens/verify"
    url = internal_api.make_api_url(endpoint)
    print("verify URL:", url)
    yield url


class TestToken:
    """Tests for device token issuance, verification and deletion."""

    def test_token_claims(self, accepted_device, management_api, device_api):
        """An issued device token is a JWT carrying the expected claims."""
        devid, d, da = accepted_device

        with orchestrator.run_fake_for_device_id(devid):
            token = request_token(d, da, device_api.auth_requests_url)

        assert len(token) > 0
        print("device token:", d.token)

        thdr, tclaims, _ = explode_jwt(d.token)
        assert "typ" in thdr and thdr["typ"] == "JWT"

        assert "jti" in tclaims
        assert "exp" in tclaims
        assert "sub" in tclaims and tclaims["sub"] == devid
        assert "iss" in tclaims and tclaims["iss"] == "Mender"
        assert "mender.device" in tclaims and tclaims["mender.device"] == True

    def test_token_verify_ok(self, internal_api, device_token, token_verify_url):
        """A freshly issued token is accepted by the internal verify API."""
        auth_hdr = device_token
        if not auth_hdr.startswith("Bearer "):
            auth_hdr = "Bearer " + auth_hdr
        # Test functions must not return a value (pytest warns about it).
        # NOTE(review): no explicit assert -- presumably verify_jwt raises on
        # a rejected token; confirm against the generated client.
        internal_api.verify_jwt(authorization=auth_hdr)

    def test_token_verify_none(self, token_verify_url):
        """Verification without an Authorization header is rejected."""
        # no auth header should raise an error
        rsp = requests.post(token_verify_url, data="")
        assert rsp.status_code == 401

    def test_token_verify_bad(self, token_verify_url):
        """A header that is not a valid JWT is rejected."""
        # use a bogus token that is not a valid JWT
        rsp = requests.post(
            token_verify_url, data="", headers={"Authorization": "bogus"}
        )
        assert rsp.status_code == 401

    def test_token_verify_corrupted(self, device_token, token_verify_url):
        """A valid token with trailing garbage appended is rejected."""
        auth_hdr = "Bearer {}".format(device_token)

        rsp = requests.post(
            token_verify_url, data="", headers={"Authorization": auth_hdr + "==foo"}
        )
        assert rsp.status_code == 401

    def test_token_delete(self, device_token, token_verify_url, management_api):
        """A token deleted via the management API no longer verifies."""
        _, tclaims, _ = explode_jwt(device_token)

        management_api.delete_token(id=tclaims["jti"])

        auth_hdr = "Bearer {}".format(device_token)
        # unsuccessful verification
        rsp = requests.post(
            token_verify_url, data="", headers={"Authorization": auth_hdr}
        )
        assert rsp.status_code == 401
