# Copyright 2024 Northern.tech AS
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
import os

import pytest

from common import (
    Device,
    DevAuthorizer,
    device_auth_req,
    make_devices,
    devices,
    clean_migrated_db,
    clean_db,
    mongo,
    cli,
    management_api,
    internal_api,
    device_api,
)
import management_api as ma

from cryptutil import compare_keys

import orchestrator


class TestDevice:
    """Device-level deviceauth API tests: auth requests, accept/reject
    transitions, listing, counting and decommissioning.

    NOTE(review): the generated client is assumed to raise ma.ApiException
    only for non-2xx replies. Successful calls are therefore made directly,
    and any API error propagates and fails the test. Expected errors are
    pinned with pytest.raises so that a missing error is also a failure.
    """

    def test_device_new(self, device_api, management_api, clean_migrated_db):
        """A first auth request from an unknown device is refused, and it
        creates exactly one device holding the submitted public key."""
        d = Device()
        da = DevAuthorizer()

        with orchestrator.run_fake_for_device_id(1):
            rsp = device_auth_req(device_api.auth_requests_url, da, d)
        # the device is new (not accepted), so no token may be issued
        assert rsp.status_code == 401

        devs = management_api.list_devices()

        assert len(devs) == 1
        dev = devs[0]

        assert len(dev.auth_sets) == 1
        aset = dev.auth_sets[0]

        assert compare_keys(aset.pubkey, d.public_key)

    def test_auth_req_bad_key(self, device_api, management_api, clean_migrated_db):
        """An auth request carrying an undecodable public key yields 400."""
        d = Device()
        da = DevAuthorizer()

        # corrupt the autogenerated public key
        d.public_key = "invalid"

        rsp = device_auth_req(device_api.auth_requests_url, da, d)
        assert rsp.status_code == 400
        assert rsp.json()["error"] == "invalid auth request: cannot decode public key"

    def test_device_accept_nonexistent(self, management_api):
        """Accepting an unknown device/auth set must fail with 404."""
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.accept_device("funnyid", "funnyid")
        assert excinfo.value.status == 404

    def test_device_reject_nonexistent(self, management_api):
        """Rejecting an unknown device/auth set must fail with 404."""
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.reject_device("funnyid", "funnyid")
        assert excinfo.value.status == 404

    def test_device_accept_reject_cycle(self, devices, device_api, management_api):
        """Accepting a device lets it obtain a token; rejecting it afterwards
        makes further auth requests fail with 401."""
        d, da = devices[0]
        url = device_api.auth_requests_url

        dev = management_api.find_device_by_identity(d.identity)

        assert dev
        devid = dev.id

        print("found matching device with ID:", devid)
        aid = dev.auth_sets[0].id

        with orchestrator.run_fake_for_device_id(devid):
            management_api.accept_device(devid, aid)

        with orchestrator.run_fake_for_device_id(devid):
            # device is accepted, we should get a token now
            rsp = device_auth_req(url, da, d)
            assert rsp.status_code == 200

            da.parse_rsp_payload(d, rsp.text)

            assert len(d.token) > 0

            # reject it now
            management_api.reject_device(devid, aid)

            # device is rejected, should get unauthorized
            rsp = device_auth_req(url, da, d)
            assert rsp.status_code == 401

    @pytest.mark.parametrize("devices", ["50"], indirect=True)
    def test_get_devices(self, management_api, devices):
        """Listing respects the requested page size."""
        devcount = 50

        # try to get a maximum number of devices
        devs = management_api.list_devices(page=1, per_page=500)
        print("got", len(devs), "devices")
        assert 500 >= len(devs) >= devcount

        # we have added at least `devcount` devices, so listing some lower
        # number of device should return exactly that number of entries
        plimit = devcount // 2
        devs = management_api.list_devices(page=1, per_page=plimit)
        assert len(devs) == plimit

    def test_get_device_limit(self, management_api):
        """The device limit reported by the API is 0 in this setup."""
        limit = management_api.get_device_limit()
        print("limit:", limit)
        assert limit.limit == 0

    def test_get_single_device_none(self, management_api):
        """Fetching an unknown device ID must fail with 404."""
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.get_device(id="some-devid-foo")
        assert excinfo.value.status == 404

    def test_get_device_single(self, management_api, devices):
        """A device fetched by ID equals the one found in the full listing."""
        dev, _ = devices[0]

        # try to find our devices in all devices listing
        ourdev = management_api.find_device_by_identity(dev.identity)
        assert ourdev

        authdev = management_api.get_device(id=ourdev.id)
        assert authdev == ourdev

    def test_delete_device_nonexistent(self, management_api):
        """Decommissioning an unknown device must fail with 404."""
        # try delete a nonexistent device
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.decommission_device("some-devid-foo")
        assert excinfo.value.status == 404

    def test_delete_device(self, management_api, internal_api, devices):
        """Decommission an existing device, then delete it through the
        internal API; afterwards it no longer appears in the listing."""
        # setup single device and poke devauth
        dev, _ = devices[0]
        ourdev = management_api.find_device_by_identity(dev.identity)
        assert ourdev

        with orchestrator.run_fake_for_device_id(ourdev.id):
            management_api.decommission_device(
                ourdev.id,
                x_men_request_id="delete_device",
            )

        with orchestrator.run_fake_for_device_id(ourdev.id):
            internal_api.delete_device(
                ourdev.id,
                headers={
                    "X-MEN-RequestID": "delete_device",
                    "Authorization": "Bearer foobar",
                },
            )

        found = management_api.find_device_by_identity(dev.identity)
        assert not found

    @pytest.mark.parametrize("devices", ["15"], indirect=True)
    def test_device_count_simple(self, devices, management_api):
        """We have 15 devices, each with a single auth set, verify that
        accepting/rejecting affects the count"""
        assert management_api.count_devices() == 15
        assert management_api.count_devices(status="pending") == 15

        # accept device[0] and reject device[1]
        status_changes = (management_api.accept_device, management_api.reject_device)
        for (d, _), change_status in zip(devices[0:2], status_changes):
            dev = management_api.find_device_by_identity(d.identity)

            assert dev
            devid = dev.id

            print("found matching device with ID:", devid)
            aid = dev.auth_sets[0].id

            with orchestrator.run_fake_for_device_id(devid):
                change_status(devid, aid)

        TestDevice.verify_device_count(management_api, "pending", 13)
        TestDevice.verify_device_count(management_api, "accepted", 1)
        TestDevice.verify_device_count(management_api, "rejected", 1)
        TestDevice.verify_device_count(management_api, "noauth", 0)

    @staticmethod
    def verify_device_count(management_api, status, expected_count):
        """Assert that `status` devices number exactly `expected_count`."""
        count = management_api.count_devices(status=status)
        assert count == expected_count

    @pytest.mark.parametrize("devices", ["5"], indirect=True)
    def test_device_count_multiple_auth_sets(self, devices, management_api, device_api):
        """Verify that auth sets are properly counted. Take a device, make sure it has
        2 auth sets, switch each auth sets between accepted/rejected/pending/noauth
        states
        """

        dev, dauth = devices[0]
        # pretend device rotates its keys
        dev.rotate_key()

        device_auth_req(device_api.auth_requests_url, dauth, dev)

        # should have 2 auth sets now
        found_dev = management_api.find_device_by_identity(dev.identity)
        assert len(found_dev.auth_sets) == 2

        first_aid, second_aid = found_dev.auth_sets[0].id, found_dev.auth_sets[1].id

        # device [0] has 2 auth sets, but still counts as 1 device
        TestDevice.verify_device_count(management_api, "pending", 5)

        devid = found_dev.id
        with orchestrator.run_fake_for_device_id(orchestrator.ANY_DEVICE):
            # accept first auth set
            management_api.accept_device(devid, first_aid)

            TestDevice.verify_device_count(management_api, "pending", 4)
            TestDevice.verify_device_count(management_api, "accepted", 1)
            TestDevice.verify_device_count(management_api, "rejected", 0)
            TestDevice.verify_device_count(management_api, "noauth", 0)

            # reject the other
            management_api.reject_device(devid, second_aid)
            TestDevice.verify_device_count(management_api, "pending", 4)
            TestDevice.verify_device_count(management_api, "accepted", 1)
            TestDevice.verify_device_count(management_api, "rejected", 0)
            TestDevice.verify_device_count(management_api, "noauth", 0)

            # reject both
            management_api.reject_device(devid, first_aid)
            TestDevice.verify_device_count(management_api, "pending", 4)
            TestDevice.verify_device_count(management_api, "accepted", 0)
            TestDevice.verify_device_count(management_api, "rejected", 1)
            TestDevice.verify_device_count(management_api, "noauth", 0)

            # switch the first back to pending, 2nd remains rejected
            management_api.put_device_status(devid, first_aid, "pending")
            TestDevice.verify_device_count(management_api, "pending", 5)
            TestDevice.verify_device_count(management_api, "accepted", 0)
            TestDevice.verify_device_count(management_api, "rejected", 0)
            TestDevice.verify_device_count(management_api, "noauth", 0)

            # delete device authsets, becomes 'noauth'
            for a in found_dev.auth_sets:
                management_api.delete_authset(devid, a.id)

            TestDevice.verify_device_count(management_api, "pending", 4)
            TestDevice.verify_device_count(management_api, "accepted", 0)
            TestDevice.verify_device_count(management_api, "rejected", 0)
            TestDevice.verify_device_count(management_api, "noauth", 1)

            # device can come back from noauth
            device_auth_req(device_api.auth_requests_url, dauth, dev)
            TestDevice.verify_device_count(management_api, "pending", 5)
            TestDevice.verify_device_count(management_api, "accepted", 0)
            TestDevice.verify_device_count(management_api, "rejected", 0)
            TestDevice.verify_device_count(management_api, "noauth", 0)


class TestDeleteAuthsetBase:
    """Shared scenarios for deleting a device's auth set.

    Every helper forwards ``**kwargs`` to each management API call it makes,
    so subclasses can inject extra request arguments uniformly.
    NOTE(review): presumably used for per-tenant auth in other test modules;
    confirm against the subclasses.
    """

    def _test_delete_authset_OK(self, management_api, devices, **kwargs):
        """Deleting a device's only auth set leaves the device with none."""
        d, _ = devices[0]

        dev = management_api.find_device_by_identity(d.identity, **kwargs)
        assert dev

        print("found matching device with ID:", dev.id)
        aid = dev.auth_sets[0].id

        with orchestrator.run_fake_for_device_id(dev.id):
            # a successful (2xx) reply does not raise; any API error fails here
            management_api.delete_authset(dev.id, aid, **kwargs)

        found = management_api.get_device(id=dev.id, **kwargs)
        assert found

        assert len(found.auth_sets) == 0

    def _test_delete_authset_error_device_not_found(
        self, management_api, devices, **kwargs
    ):
        """Deleting an auth set of an unknown device must fail with 404."""
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.delete_authset("foo", "bar", **kwargs)
        assert excinfo.value.status == 404

    def _test_delete_authset_error_authset_not_found(
        self, management_api, devices, **kwargs
    ):
        """Deleting an unknown auth set of an existing device must fail
        with 404."""
        d, _ = devices[0]

        dev = management_api.find_device_by_identity(d.identity, **kwargs)

        assert dev
        devid = dev.id

        print("found matching device with ID:", devid)

        # use the same request kwargs as the lookup so the delete targets
        # the same device scope
        with pytest.raises(ma.ApiException) as excinfo:
            management_api.delete_authset(devid, "foobar", **kwargs)
        assert excinfo.value.status == 404


class TestDeleteAuthset(TestDeleteAuthsetBase):
    """Runs the shared auth-set deletion scenarios with no extra request
    arguments."""

    def test_delete_authset_OK(self, management_api, devices):
        """Happy path: the auth set is removed from the device."""
        self._test_delete_authset_OK(management_api=management_api, devices=devices)

    def test_delete_authset_error_device_not_found(self, management_api, devices):
        """Error path: the target device does not exist."""
        self._test_delete_authset_error_device_not_found(
            management_api=management_api, devices=devices
        )

    def test_delete_authset_error_authset_not_found(self, management_api, devices):
        """Error path: the device exists but the auth set does not."""
        self._test_delete_authset_error_authset_not_found(
            management_api=management_api, devices=devices
        )
