#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import AsyncMock

import pytest
from google.api_core.client_options import ClientOptions
from google.cloud.run_v2 import (
    CreateJobRequest,
    CreateServiceRequest,
    DeleteJobRequest,
    DeleteServiceRequest,
    GetJobRequest,
    GetServiceRequest,
    Job,
    JobsClient,
    ListJobsRequest,
    RunJobRequest,
    Service,
    ServicesAsyncClient,
    UpdateJobRequest,
)
from google.longrunning import operations_pb2

from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.google.cloud.hooks.cloud_run import (
    CloudRunAsyncHook,
    CloudRunHook,
    CloudRunServiceAsyncHook,
    CloudRunServiceHook,
)

from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id

# Fake identifiers used to build the resource names / requests that the tests
# assert were sent to the (mocked) Cloud Run clients.
PROJECT_ID = "projectid"
REGION = "region1"
JOB_NAME = "job1"
SERVICE_NAME = "service1"
OPERATION_NAME = "operationname"
# Value passed as ``use_regional_endpoint`` to the hook calls in these tests.
USE_REGIONAL_ENDPOINT = True
# Template for building ``mock.patch`` targets inside the ``base_google`` hook
# module, e.g. ``BASE_STRING.format("GoogleBaseHook.__init__")``.
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"


@pytest.mark.db_test
class TestCloudRunHook:
    """Unit tests for the synchronous ``CloudRunHook`` (Cloud Run Jobs API).

    ``JobsClient`` is patched in every test that reaches the API, so no real
    Google Cloud calls are made.

    NOTE: the ``GoogleBaseHook.__init__`` patch decorator only affects hooks
    constructed inside the test body. The ``cloud_run_hook`` fixture is set up
    by pytest before the decorator patches become active.
    """

    def dummy_get_credentials(self):
        """Stand-in for ``get_credentials`` that returns ``None``."""
        pass

    @pytest.fixture
    def cloud_run_hook(self):
        """Return a ``CloudRunHook`` whose ``get_credentials`` is stubbed out."""
        cloud_run_hook = CloudRunHook()
        cloud_run_hook.get_credentials = self.dummy_get_credentials
        return cloud_run_hook

    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_get_conn_regional_endpoint(self, mock_jobs_client_cls):
        """A regional endpoint yields ``<location>-run.googleapis.com:443`` client options."""
        hook = CloudRunHook(gcp_conn_id="google_cloud_default")
        hook.get_credentials = mock.MagicMock(return_value=mock.Mock())

        location = "us-central1"
        hook.get_conn(location=location, use_regional_endpoint=USE_REGIONAL_ENDPOINT)
        assert mock_jobs_client_cls.call_count == 1

        _, kwargs = mock_jobs_client_cls.call_args
        client_options = kwargs.get("client_options")
        assert isinstance(client_options, ClientOptions)
        assert client_options.api_endpoint == f"{location}-run.googleapis.com:443"

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_get_job(self, mock_jobs_client, cloud_run_hook):
        """``get_job`` sends a ``GetJobRequest`` with the fully-qualified job name."""
        get_job_request = GetJobRequest(name=f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}")

        cloud_run_hook.get_job(
            job_name=JOB_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_hook._client.get_job.assert_called_once_with(get_job_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_update_job(self, mock_jobs_client, cloud_run_hook):
        """``update_job`` wraps the given job dict in an ``UpdateJobRequest``."""
        job = Job()
        job.name = f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}"

        update_request = UpdateJobRequest()
        update_request.job = job

        cloud_run_hook.update_job(
            job=Job.to_dict(job),
            job_name=JOB_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        cloud_run_hook._client.update_job.assert_called_once_with(update_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_create_job(self, mock_jobs_client, cloud_run_hook):
        """``create_job`` sends a ``CreateJobRequest`` with job id and parent location."""
        job = Job()

        create_request = CreateJobRequest()
        create_request.job = job
        create_request.job_id = JOB_NAME
        create_request.parent = f"projects/{PROJECT_ID}/locations/{REGION}"

        cloud_run_hook.create_job(
            job=Job.to_dict(job),
            job_name=JOB_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        cloud_run_hook._client.create_job.assert_called_once_with(create_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_execute_job(self, mock_jobs_client, cloud_run_hook):
        """``execute_job`` forwards the overrides inside a ``RunJobRequest``."""
        overrides = {
            "container_overrides": [{"args": ["python", "main.py"]}],
            "task_count": 1,
            "timeout": "60s",
        }
        run_job_request = RunJobRequest(
            name=f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}", overrides=overrides
        )

        cloud_run_hook.execute_job(
            job_name=JOB_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            overrides=overrides,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs(self, mock_jobs_client, cloud_run_hook):
        """``list_jobs`` returns every job from the pager and queries the parent location."""
        number_of_jobs = 3
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        jobs_list = cloud_run_hook.list_jobs(
            region=region, project_id=project_id, use_regional_endpoint=USE_REGIONAL_ENDPOINT
        )

        for i in range(number_of_jobs):
            assert jobs_list[i].name == f"name{i}"

        expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
            parent=f"projects/{project_id}/locations/{region}"
        )
        mock_jobs_client.return_value.list_jobs.assert_called_once_with(request=expected_list_jobs_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs_show_deleted(self, mock_jobs_client, cloud_run_hook):
        """``show_deleted=True`` is propagated into the ``ListJobsRequest``."""
        number_of_jobs = 3
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        jobs_list = cloud_run_hook.list_jobs(
            region=region,
            project_id=project_id,
            show_deleted=True,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        for i in range(number_of_jobs):
            assert jobs_list[i].name == f"name{i}"

        expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
            parent=f"projects/{project_id}/locations/{region}", show_deleted=True
        )
        mock_jobs_client.return_value.list_jobs.assert_called_once_with(request=expected_list_jobs_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs_with_limit(self, mock_jobs_client, cloud_run_hook):
        """A positive ``limit`` truncates the result to the first ``limit`` jobs."""
        number_of_jobs = 3
        limit = 2
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        jobs_list = cloud_run_hook.list_jobs(
            region=region, project_id=project_id, limit=limit, use_regional_endpoint=USE_REGIONAL_ENDPOINT
        )

        assert len(jobs_list) == limit
        for i in range(limit):
            assert jobs_list[i].name == f"name{i}"

    # Patched like the sibling list_jobs tests for consistency.
    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs_with_limit_zero(self, mock_jobs_client, cloud_run_hook):
        """``limit=0`` yields an empty list."""
        number_of_jobs = 3
        limit = 0
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        jobs_list = cloud_run_hook.list_jobs(
            region=region, project_id=project_id, limit=limit, use_regional_endpoint=USE_REGIONAL_ENDPOINT
        )

        assert len(jobs_list) == 0

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs_with_limit_greater_then_range(self, mock_jobs_client, cloud_run_hook):
        """A ``limit`` larger than the number of jobs returns all jobs."""
        number_of_jobs = 3
        limit = 5
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        jobs_list = cloud_run_hook.list_jobs(
            region=region, project_id=project_id, limit=limit, use_regional_endpoint=USE_REGIONAL_ENDPOINT
        )

        assert len(jobs_list) == number_of_jobs
        for i in range(number_of_jobs):
            assert jobs_list[i].name == f"name{i}"

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_list_jobs_with_limit_less_than_zero(self, mock_jobs_client, cloud_run_hook):
        """A negative ``limit`` raises ``AirflowException``."""
        number_of_jobs = 3
        limit = -1
        region = "us-central1"
        project_id = "test_project_id"

        page = self._mock_pager(number_of_jobs)
        mock_jobs_client.return_value.list_jobs.return_value = page

        with pytest.raises(expected_exception=AirflowException):
            cloud_run_hook.list_jobs(
                region=region, project_id=project_id, limit=limit, use_regional_endpoint=USE_REGIONAL_ENDPOINT
            )

    # Patched like the other fixture-based API tests for consistency.
    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_delete_job(self, mock_jobs_client, cloud_run_hook):
        """``delete_job`` sends a ``DeleteJobRequest`` with the fully-qualified job name."""
        delete_request = DeleteJobRequest(name=f"projects/{PROJECT_ID}/locations/{REGION}/jobs/{JOB_NAME}")

        cloud_run_hook.delete_job(
            job_name=JOB_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_hook._client.delete_job.assert_called_once_with(delete_request)

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_get_conn_with_transport(self, mock_jobs_client):
        """Test that transport parameter is passed to JobsClient."""
        hook = CloudRunHook(transport="rest")
        hook.get_credentials = self.dummy_get_credentials
        hook.get_conn(location=REGION, use_regional_endpoint=USE_REGIONAL_ENDPOINT)

        mock_jobs_client.assert_called_once()
        call_kwargs = mock_jobs_client.call_args[1]
        assert call_kwargs["transport"] == "rest"

    @mock.patch(
        BASE_STRING.format("GoogleBaseHook.__init__"),
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    def test_get_conn_omits_transport_when_none(self, mock_jobs_client):
        """Test that transport is not passed to JobsClient when None."""
        hook = CloudRunHook(transport=None)
        hook.get_credentials = self.dummy_get_credentials
        hook.get_conn(location=REGION, use_regional_endpoint=USE_REGIONAL_ENDPOINT)

        mock_jobs_client.assert_called_once()
        call_kwargs = mock_jobs_client.call_args[1]
        assert "transport" not in call_kwargs

    def _mock_pager(self, number_of_jobs):
        """Return a plain list of ``number_of_jobs`` jobs named ``name0``, ``name1``, ..."""
        return [Job(name=f"name{i}") for i in range(number_of_jobs)]


class TestCloudRunAsyncHook:
    """Unit tests for ``CloudRunAsyncHook``.

    Connection creation and operation polling are checked for the default
    (gRPC asyncio) transport and for the REST transport.
    """

    @pytest.mark.asyncio
    async def test_get_operation(self):
        """``get_operation`` forwards a ``GetOperationRequest`` with a 120s timeout."""
        hook = CloudRunAsyncHook()
        hook.get_conn = mock.AsyncMock()

        await hook.get_operation(
            operation_name=OPERATION_NAME,
            location="us-central1",
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        connection = hook.get_conn.return_value
        wanted_request = operations_pb2.GetOperationRequest(name=OPERATION_NAME)
        connection.get_operation.assert_called_once_with(wanted_request, timeout=120)

    @pytest.mark.asyncio
    @pytest.mark.parametrize("transport", [None, "grpc"])
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsAsyncClient")
    async def test_get_conn_uses_async_client_by_default(self, patched_async_client, transport):
        """Test that get_conn uses JobsAsyncClient (grpc_asyncio) when transport is None or grpc."""
        fake_sync_hook = mock.MagicMock(spec=CloudRunHook)
        fake_sync_hook.get_credentials.return_value = "credentials"
        hook = CloudRunAsyncHook(transport=transport)
        hook.get_sync_hook = mock.AsyncMock(return_value=fake_sync_hook)

        await hook.get_conn(location=REGION, use_regional_endpoint=USE_REGIONAL_ENDPOINT)

        patched_async_client.assert_called_once()
        _, client_kwargs = patched_async_client.call_args
        assert "transport" not in client_kwargs

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
    async def test_get_conn_uses_sync_client_for_rest(self, patched_sync_client):
        """Test that get_conn uses sync JobsClient with REST transport."""
        fake_sync_hook = mock.MagicMock(spec=CloudRunHook)
        fake_sync_hook.get_credentials.return_value = "credentials"
        hook = CloudRunAsyncHook(transport="rest")
        hook.get_sync_hook = mock.AsyncMock(return_value=fake_sync_hook)

        await hook.get_conn(location=REGION, use_regional_endpoint=USE_REGIONAL_ENDPOINT)

        patched_sync_client.assert_called_once()
        _, client_kwargs = patched_sync_client.call_args
        assert client_kwargs["transport"] == "rest"

    @pytest.mark.asyncio
    @mock.patch("asyncio.to_thread")
    async def test_get_operation_rest_uses_to_thread(self, patched_to_thread):
        """Test that get_operation uses asyncio.to_thread for REST transport."""
        operation = operations_pb2.Operation(name=OPERATION_NAME)
        patched_to_thread.return_value = operation

        hook = CloudRunAsyncHook(transport="rest")
        # With REST the hook holds a synchronous client, hence the JobsClient spec.
        sync_connection = mock.MagicMock(spec=JobsClient)
        hook.get_conn = mock.AsyncMock(return_value=sync_connection)

        returned = await hook.get_operation(
            operation_name=OPERATION_NAME,
            location=REGION,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        wanted_request = operations_pb2.GetOperationRequest(name=OPERATION_NAME)
        patched_to_thread.assert_called_once_with(sync_connection.get_operation, wanted_request, timeout=120)
        assert returned == operation


@pytest.mark.db_test
class TestCloudRunServiceHook:
    """Unit tests for the synchronous ``CloudRunServiceHook`` (Cloud Run Services API).

    ``ServicesClient`` is patched in every test. The hook is called without
    ``await``, so the base-hook constructor patch targets the synchronous
    ``GoogleBaseHook`` rather than ``GoogleBaseAsyncHook``.

    NOTE: that constructor patch only applies to hooks built inside the test
    body. The ``cloud_run_service_hook`` fixture is created before decorator
    patches are active.
    """

    def dummy_get_credentials(self):
        """Stand-in for ``get_credentials`` that returns ``None``."""
        pass

    @pytest.fixture
    def cloud_run_service_hook(self):
        """Return a ``CloudRunServiceHook`` with stubbed credentials and a regional endpoint."""
        region = "us-central1"
        cloud_run_service_hook = CloudRunServiceHook()
        cloud_run_service_hook.get_credentials = self.dummy_get_credentials
        cloud_run_service_hook.client_options = ClientOptions(api_endpoint=f"{region}-run.googleapis.com:443")
        return cloud_run_service_hook

    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.ServicesClient")
    def test_get_service(self, mock_services_client, cloud_run_service_hook):
        """``get_service`` sends a ``GetServiceRequest`` with the fully-qualified service name."""
        get_service_request = GetServiceRequest(
            name=f"projects/{PROJECT_ID}/locations/{REGION}/services/{SERVICE_NAME}"
        )

        cloud_run_service_hook.get_service(
            service_name=SERVICE_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_service_hook._client.get_service.assert_called_once_with(get_service_request)

    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.ServicesClient")
    def test_create_service(self, mock_services_client, cloud_run_service_hook):
        """``create_service`` sends a ``CreateServiceRequest`` with service id and parent."""
        service = Service()

        create_request = CreateServiceRequest()
        create_request.service = service
        create_request.service_id = SERVICE_NAME
        create_request.parent = f"projects/{PROJECT_ID}/locations/{REGION}"

        cloud_run_service_hook.create_service(
            service=service,
            service_name=SERVICE_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_service_hook._client.create_service.assert_called_once_with(create_request)

    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.ServicesClient")
    def test_delete_service(self, mock_services_client, cloud_run_service_hook):
        """``delete_service`` sends a ``DeleteServiceRequest`` with the fully-qualified name."""
        delete_request = DeleteServiceRequest(
            name=f"projects/{PROJECT_ID}/locations/{REGION}/services/{SERVICE_NAME}"
        )

        cloud_run_service_hook.delete_service(
            service_name=SERVICE_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )
        cloud_run_service_hook._client.delete_service.assert_called_once_with(delete_request)


class TestCloudRunServiceAsyncHook:
    """Unit tests for ``CloudRunServiceAsyncHook``.

    ``get_conn`` is patched to hand back an ``AsyncMock`` specced as
    ``ServicesAsyncClient``. Each test asserts on the request object the hook
    passes to that client.
    """

    def dummy_get_credentials(self):
        """Stand-in for ``get_credentials`` that returns ``None``."""
        pass

    def mock_service(self):
        """Return a fresh ``AsyncMock``."""
        return mock.AsyncMock()

    @pytest.fixture
    def cloud_run_service_hook(self):
        """Return an async service hook with stubbed credentials and a regional endpoint."""
        hook = CloudRunServiceAsyncHook()
        hook.get_credentials = self.dummy_get_credentials
        hook.client_options = ClientOptions(api_endpoint="us-central1-run.googleapis.com:443")
        return hook

    @pytest.mark.asyncio
    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.CloudRunServiceAsyncHook.get_conn")
    async def test_create_service(self, patched_get_conn, cloud_run_service_hook):
        """``create_service`` passes a ``CreateServiceRequest`` to the async client."""
        services_client = AsyncMock(ServicesAsyncClient)
        patched_get_conn.return_value = services_client

        wanted_request = CreateServiceRequest(
            service=Service(),
            service_id=SERVICE_NAME,
            parent=f"projects/{PROJECT_ID}/locations/{REGION}",
        )

        await cloud_run_service_hook.create_service(
            service_name=SERVICE_NAME,
            service=Service(),
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        services_client.create_service.assert_called_once_with(wanted_request)

    @pytest.mark.asyncio
    @mock.patch(
        "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
        new=mock_base_gcp_hook_default_project_id,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.CloudRunServiceAsyncHook.get_conn")
    async def test_delete_service(self, patched_get_conn, cloud_run_service_hook):
        """``delete_service`` passes a ``DeleteServiceRequest`` to the async client."""
        services_client = AsyncMock(ServicesAsyncClient)
        patched_get_conn.return_value = services_client

        wanted_request = DeleteServiceRequest(
            name=f"projects/{PROJECT_ID}/locations/{REGION}/services/{SERVICE_NAME}",
        )

        await cloud_run_service_hook.delete_service(
            service_name=SERVICE_NAME,
            region=REGION,
            project_id=PROJECT_ID,
            use_regional_endpoint=USE_REGIONAL_ENDPOINT,
        )

        services_client.delete_service.assert_called_once_with(wanted_request)
