/*
 * This file is part of LibEuFin.
 * Copyright (C) 2024, 2025, 2026 Taler Systems S.A.

 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.

 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.

 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

package tech.libeufin.common.api

import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.server.application.*
import io.ktor.server.engine.*
import io.ktor.server.cio.*
import io.ktor.server.plugins.*
import io.ktor.server.plugins.calllogging.*
import io.ktor.server.plugins.contentnegotiation.*
import io.ktor.server.plugins.forwardedheaders.*
import io.ktor.server.plugins.statuspages.*
import io.ktor.server.plugins.callid.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.utils.io.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.http.content.*
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.json.Json
import org.postgresql.util.PSQLState
import org.slf4j.Logger
import org.slf4j.event.Level
import tech.libeufin.common.*
import tech.libeufin.common.db.SERIALIZATION_ERROR
import java.net.InetAddress
import java.sql.SQLException
import java.util.zip.DataFormatException
import java.util.zip.Inflater

/**
 * Call attribute key under which the request body bytes captured by the
 * Taler plugin are stored. Private: read it through [rawBody].
 */
private val RAW_BODY = AttributeKey<ByteArray>("RAW_BODY")

/**
 * Call attribute key used to set a custom body size limit for a call.
 * When absent, the plugin falls back to MAX_BODY_LENGTH.
 */
val BODY_LIMIT = AttributeKey<Int>("BODY_LIMIT")

/**
 * Raw body captured for this call, or an empty array when no body has
 * been captured (the attribute is not set).
 */
val ApplicationCall.rawBody: ByteArray get() = attributes.getOrNull(RAW_BODY) ?: ByteArray(0)

/**
 * Build the application plugin applying Taler specific request handling.
 *
 * On every call it:
 * - adds a permissive `Access-Control-Allow-Origin` header,
 * - answers CORS preflight (`OPTIONS`) requests with `204 No Content`,
 * - logs the method, path and query string of any other request.
 *
 * When a request body is received it:
 * - enforces the body size limit ([BODY_LIMIT] attribute if set, else
 *   `MAX_BODY_LENGTH`), both on the declared `Content-Length` and on the
 *   bytes actually read,
 * - inflates bodies sent with `Content-Encoding: deflate` and applies the
 *   limit to the decompressed size,
 * - rejects any other content encoding,
 * - stores the exact received (decompressed) bytes so that they are
 *   available through [rawBody].
 *
 * @param logger logger used for request, error and trace messages
 * @return the plugin to install in the application
 */
fun talerPlugin(logger: Logger): ApplicationPlugin<Unit> {
    return createApplicationPlugin("TalerPlugin") {
        onCall { call ->
            // Handle CORS
            call.response.header(HttpHeaders.AccessControlAllowOrigin, "*")
            // Handle CORS preflight
            if (call.request.httpMethod == HttpMethod.Options) {
                call.response.header(HttpHeaders.AccessControlAllowHeaders, "*")
                call.response.header(HttpHeaders.AccessControlAllowMethods, "*")
                call.respond(HttpStatusCode.NoContent)
                return@onCall
            }

            // Log incoming transaction as "<METHOD> <path>[?<query>]"
            val requestCall = buildString {
                append(call.request.httpMethod.value)
                append(' ')
                append(call.request.path())
                val query = call.request.queryString()
                if (query.isNotEmpty()) {
                    append('?')
                    append(query)
                }
            }
            logger.info(requestCall)
        }
        onCallReceive { call ->
            val bodyLimit = call.attributes.getOrNull(BODY_LIMIT) ?: MAX_BODY_LENGTH
            // Check content length if present and wellformed
            val contentLength = call.request.headers[HttpHeaders.ContentLength]?.toIntOrNull()
            if (contentLength != null && contentLength > bodyLimit)
                throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")

            // Else check while reading and decompressing the body
            transformBody { body ->
                // One extra byte of capacity makes an overflow observable as read > bodyLimit
                val bytes = ByteArray(bodyLimit + 1)
                var read = 0
                when (val encoding = call.request.headers[HttpHeaders.ContentEncoding]) {
                    "deflate" -> {
                        // Decompress and check decompressed length
                        val inflater = Inflater()
                        try {
                            while (!body.isClosedForRead) {
                                body.read { buf ->
                                    inflater.setInput(buf)
                                    try {
                                        read += inflater.inflate(bytes, read, bytes.size - read)
                                    } catch (e: DataFormatException) {
                                        logger.error("Deflated request failed to inflate: ${e.message}")
                                        throw badRequest(
                                            "Could not inflate request",
                                            TalerErrorCode.GENERIC_COMPRESSION_INVALID
                                        )
                                    }
                                }
                                if (read > bodyLimit)
                                    throw bodyOverflow("Decompressed body is suspiciously big > ${bodyLimit}B")
                            }
                        } finally {
                            // Release the decompressor resources on success and on failure
                            inflater.end()
                        }
                    }
                    null -> {
                        // Check body length
                        while (true) {
                            val new = body.readAvailable(bytes, read, bytes.size - read)
                            if (new == -1) break // Channel is closed
                            read += new
                            if (read > bodyLimit)
                                throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")
                        }
                    }
                    else -> throw unsupportedMediaType(
                        "Content encoding '$encoding' not supported, expected plain or deflate",
                        TalerErrorCode.GENERIC_COMPRESSION_INVALID
                    )
                }
                // Keep only the bytes actually received: the scratch buffer is
                // bodyLimit + 1 long and its tail is zero padding
                val raw = bytes.copyOf(read)
                logger.trace {
                    "request ${raw.asUtf8()}"
                }
                call.attributes.put(RAW_BODY, raw)
                ByteReadChannel(raw, 0, read)
            }
        }
    }
}

/**
 * Set up web server handlers for a Taler API.
 *
 * Installs call id generation, call logging, forwarded header support, the
 * Taler plugin, trailing slash tolerance, JSON content negotiation and the
 * status pages translating HTTP statuses and exceptions into Taler error
 * responses. It also traces outgoing byte array responses and finally
 * registers [routes].
 *
 * @param logger logger used by the installed plugins and handlers
 * @param routes routing definition of the API
 */
fun Application.talerApi(logger: Logger, routes: Routing.() -> Unit) {
    install(CallId) {
        generate(10, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
        verify { true }
    }
    install(CallLogging) {
        callIdMdc("call-id")
        level = Level.INFO
        this.logger = logger
        format { call ->
            val status = call.response.status()
            val msg = call.logMsg()
            if (msg != null) {
                "${status?.value} ${call.processingTimeMillis()}ms: $msg"
            } else {
                "${status?.value} ${call.processingTimeMillis()}ms"
            }
        }
    }
    install(XForwardedHeaders)
    install(talerPlugin(logger))
    install(IgnoreTrailingSlash)
    install(ContentNegotiation) {
        json(Json {
            @OptIn(ExperimentalSerializationApi::class)
            explicitNulls = false
            encodeDefaults = true
            ignoreUnknownKeys = true
        })
    }
    install(StatusPages) {
        status(HttpStatusCode.NotFound) { call, status ->
            call.err(
                status,
                "There is no endpoint defined for the URL provided by the client. Check if you used the correct URL and/or file a report with the developers of the client software.",
                TalerErrorCode.GENERIC_ENDPOINT_UNKNOWN,
                null
            )
        }
        status(HttpStatusCode.MethodNotAllowed) { call, status ->
            call.err(
                status,
                "The HTTP method used is invalid for this endpoint. This is likely a bug in the client implementation. Check if you are using the latest available version and/or file a report with the developers.",
                TalerErrorCode.GENERIC_METHOD_INVALID,
                null
            )
        }
        exception<Exception> { call, cause ->
            logger.debug("", cause)
            when (cause) {
                is ApiException -> call.err(cause, null)
                is SQLException -> {
                    if (SERIALIZATION_ERROR.contains(cause.sqlState)) {
                        call.err(
                            HttpStatusCode.InternalServerError,
                            "Transaction serialization failure",
                            TalerErrorCode.BANK_SOFT_EXCEPTION,
                            cause
                        )
                    } else {
                        call.err(
                            HttpStatusCode.InternalServerError,
                            "Unexpected sql error with state ${cause.sqlState}",
                            TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
                            cause
                        )
                    }
                }
                is BadRequestException -> {
                    /**
                     * NOTE: extracting the root cause helps with JSON error messages,
                     * because they mention the particular way they are invalid, but OTOH
                     * it loses (by getting null) other error messages, like for example
                     * the one from MissingRequestParameterException.  Therefore, in order
                     * to get the most detailed message, we must consider BOTH sides:
                     * the 'cause' AND its root cause!
                     */
                    var rootCause: Throwable? = cause.cause
                    while (rootCause?.cause != null)
                        rootCause = rootCause.cause
                    // Telling apart invalid JSON vs missing parameter vs invalid parameter.
                    val errorCode = when {
                        cause is MissingRequestParameterException ->
                            TalerErrorCode.GENERIC_PARAMETER_MISSING
                        cause is ParameterConversionException ->
                            TalerErrorCode.GENERIC_PARAMETER_MALFORMED
                        rootCause is CommonError -> rootCause.toTalerErrorCode()
                        else -> TalerErrorCode.GENERIC_JSON_INVALID
                    }
                    call.err(
                        HttpStatusCode.BadRequest,
                        // Prefer the root cause message and fall back to the exception's
                        // own message when there is no root cause (or it has no message)
                        rootCause?.message ?: cause.message,
                        errorCode,
                        null
                    )
                }
                is CommonError -> {
                    call.err(
                        HttpStatusCode.BadRequest,
                        cause.message,
                        cause.toTalerErrorCode(),
                        null
                    )
                }
                else -> {
                    call.err(
                        HttpStatusCode.InternalServerError,
                        cause.message,
                        TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
                        cause
                    )
                }
            }
        }
    }
    // Trace outgoing byte array responses right before they reach the engine
    val phase = PipelinePhase("phase")
    sendPipeline.insertPhaseBefore(ApplicationSendPipeline.Engine, phase)
    sendPipeline.intercept(phase) { response ->
        if (logger.isTraceEnabled && response is OutgoingContent.ByteArrayContent) {
            logger.trace("response ${String(response.bytes())}")
        }
    }
    routing { routes() }
}

/** Taler error code reported to clients for this [CommonError] */
private fun CommonError.toTalerErrorCode() = when (this) {
    is CommonError.AmountFormat -> TalerErrorCode.BANK_BAD_FORMAT_AMOUNT
    is CommonError.AmountNumberTooBig -> TalerErrorCode.BANK_NUMBER_TOO_BIG
    is CommonError.Payto -> TalerErrorCode.GENERIC_JSON_INVALID
}

// Dirty hack: file-level handle on the engine of the server created by the
// latest serve() call, kept only so that tests can stop the server.
// TODO remove this ugly hack
var engine: ApplicationEngine? = null

/**
 * Create and start an embedded CIO server running [api].
 *
 * For a TCP configuration one connector is registered per address the
 * configured host name resolves to; for a Unix configuration a single Unix
 * domain socket connector is registered on the configured path. The engine
 * is published in [engine] before the server is started with `wait = true`.
 *
 * @param cfg where the server listens
 * @param logger logger used to report the listening endpoints
 * @param api application module to serve
 */
fun serve(cfg: tech.libeufin.common.ServerConfig, logger: Logger, api: Application.() -> Unit) {
    val embedded = embeddedServer(
        CIO,
        configure = {
            when (cfg) {
                is ServerConfig.Unix -> {
                    logger.info("Listening on ${cfg.path}")
                    unixConnector(cfg.path.toString())
                }
                is ServerConfig.Tcp -> InetAddress.getAllByName(cfg.addr).forEach { resolved ->
                    val ip = resolved.hostAddress
                    logger.info("Listening on $ip:${cfg.port}")
                    connector {
                        host = ip
                        port = cfg.port
                    }
                }
            }
        },
        module = api
    )
    // Expose the engine before starting the server
    engine = embedded.engine
    embedded.start(wait = true)
}