# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the MIT License.  See the LICENSE file in the root of this
# repository for complete details.

from __future__ import absolute_import, division, print_function

import collections
import logging
import logging.config
import os

import pytest

from pretend import call_recorder

from structlog import ReturnLogger, configure, get_logger, reset_defaults
from structlog.dev import ConsoleRenderer
from structlog.exceptions import DropEvent
from structlog.processors import JSONRenderer
from structlog.stdlib import (
    CRITICAL, WARN, BoundLogger, LoggerFactory, PositionalArgumentsFormatter,
    ProcessorFormatter, _FixedFindCallerLogger, add_log_level, add_logger_name,
    filter_by_level, render_to_log_kwargs
)

from .additional_frame import additional_frame
from .utils import py3_only


def build_bl(logger=None, processors=None, context=None):
    """
    Convenience function to build BoundLogger with sane defaults.

    :param logger: Wrapped logger; a fresh ``ReturnLogger`` when falsy.
    :param processors: Processor chain, passed through unchanged.
    :param context: Initial bound context; an empty dict when ``None``.
        (Previously this argument was accepted but silently ignored.)
    """
    return BoundLogger(
        logger or ReturnLogger(),
        processors,
        # ``is not None`` so an explicitly passed empty dict is kept as-is.
        context if context is not None else {},
    )


def return_method_name(_, method_name, __):
    """
    Terminal processor that ignores the logger and the event dict and
    hands back only the name of the logging method that was invoked.
    """
    invoked = method_name
    return invoked


class TestLoggerFactory(object):
    """
    Tests for the stdlib ``LoggerFactory`` and ``_FixedFindCallerLogger``.
    """
    def setup_method(self, method):
        """
        Remember the current logger class before each test.

        The stdlib logger factory modifies global state (it installs
        ``_FixedFindCallerLogger`` via the logger class) to fix caller
        identification, so it has to be restored afterwards.
        """
        self.original_logger = logging.getLoggerClass()

    def teardown_method(self, method):
        """
        Restore the logger class captured in ``setup_method``.
        """
        logging.setLoggerClass(self.original_logger)

    def test_deduces_correct_name(self):
        """
        The factory isn't called directly but from structlog._config so
        deducing has to be slightly smarter.
        """
        assert 'tests.additional_frame' == (
            additional_frame(LoggerFactory()).name
        )
        assert 'tests.test_stdlib' == LoggerFactory()().name

    def test_ignores_frames(self):
        """
        The name guesser walks up the frames until it reaches a frame whose
        name is not from structlog or one of the configurable other names.
        """
        assert '__main__' == additional_frame(LoggerFactory(
            ignore_frame_names=["tests.", "_pytest.", "pluggy"])
        ).name

    def test_deduces_correct_caller(self):
        """
        ``_FixedFindCallerLogger.findCaller`` reports this test's file and
        function name.
        """
        logger = _FixedFindCallerLogger('test')
        # The line number is unpacked but deliberately not asserted on.
        file_name, line_number, func_name = logger.findCaller()[:3]
        assert file_name == os.path.realpath(__file__)
        assert func_name == 'test_deduces_correct_caller'

    @py3_only
    def test_stack_info(self):
        """
        With ``stack_info=True`` the returned stack text contains the
        calling source line.
        """
        logger = _FixedFindCallerLogger('test')
        # The assertion below matches the literal source text of the next
        # line, so renaming these variables would break this test.
        testing, is_, fun, stack_info = logger.findCaller(stack_info=True)
        assert 'testing, is_, fun' in stack_info

    @py3_only
    def test_no_stack_info_by_default(self):
        """
        Without ``stack_info`` the fourth element returned is ``None``.
        """
        logger = _FixedFindCallerLogger('test')
        testing, is_, fun, stack_info = logger.findCaller()
        assert None is stack_info

    def test_find_caller(self, monkeypatch):
        """
        Records emitted by a factory-built logger carry this test's function
        name, module name, and file name.
        """
        logger = LoggerFactory()()
        # Intercept ``handle`` so the emitted LogRecord can be inspected.
        log_handle = call_recorder(lambda x: None)
        monkeypatch.setattr(logger, 'handle', log_handle)
        logger.error('Test')
        log_record = log_handle.calls[0].args[0]
        assert log_record.funcName == 'test_find_caller'
        assert log_record.name == __name__
        assert log_record.filename == os.path.basename(__file__)

    def test_sets_correct_logger(self):
        """
        Instantiating ``LoggerFactory`` installs ``_FixedFindCallerLogger``
        as the global logger class.

        The first assertion relies on setup/teardown having restored the
        default class between tests.
        """
        assert logging.getLoggerClass() is logging.Logger
        LoggerFactory()
        assert logging.getLoggerClass() is _FixedFindCallerLogger

    def test_positional_argument_avoids_guessing(self):
        """
        If a positional argument is passed to the factory, it's used as the
        name instead of guessing.
        """
        lf = LoggerFactory()("foo")

        assert "foo" == lf.name


class TestFilterByLevel(object):
    def test_filters_lower_levels(self):
        """
        An event below the logger's effective level raises DropEvent.
        """
        critical_only = logging.Logger(__name__)
        critical_only.setLevel(CRITICAL)

        with pytest.raises(DropEvent):
            filter_by_level(critical_only, 'warn', {})

    def test_passes_higher_levels(self):
        """
        Events at or above the logger's level come back as the very same
        event dict object.
        """
        warn_logger = logging.Logger(__name__)
        warn_logger.setLevel(WARN)
        event_dict = {'event': 'test'}

        for level_method in ('warn', 'error', 'exception'):
            result = filter_by_level(warn_logger, level_method, event_dict)
            assert result is event_dict


class TestBoundLogger(object):
    """
    Tests for the stdlib-flavored ``BoundLogger`` wrapper.
    """
    @pytest.mark.parametrize('method_name', [
        'debug', 'info', 'warning', 'error', 'critical',
    ])
    def test_proxies_to_correct_method(self, method_name):
        """
        The basic proxied methods are proxied to the correct counterparts.
        """
        bl = BoundLogger(ReturnLogger(), [return_method_name], {})
        assert method_name == getattr(bl, method_name)('event')

    def test_proxies_exception(self):
        """
        BoundLogger.exception is proxied to Logger.error.
        """
        bl = BoundLogger(ReturnLogger(), [return_method_name], {})
        assert "error" == bl.exception("event")

    def test_proxies_log(self):
        """
        BoundLogger.log() is proxied to the appropriate method for the given
        numeric level.
        """
        bl = BoundLogger(ReturnLogger(), [return_method_name], {})
        assert "critical" == bl.log(50, "event")
        assert "debug" == bl.log(10, "event")

    def test_positional_args_proxied(self):
        """
        Positional arguments supplied are proxied as the `positional_args`
        kwarg.
        """
        bl = BoundLogger(ReturnLogger(), [], {})
        args, kwargs = bl.debug('event', 'foo', bar='baz')
        assert 'baz' == kwargs.get('bar')
        assert ('foo',) == kwargs.get('positional_args')

    @pytest.mark.parametrize('method_name,method_args', [
        ('addHandler', [None]),
        ('removeHandler', [None]),
        ('hasHandlers', None),
        ('callHandlers', [None]),
        ('handle', [None]),
        ('setLevel', [None]),
        ('getEffectiveLevel', None),
        ('isEnabledFor', [None]),
        ('findCaller', None),
        ('makeRecord', ['name', 'debug', 'test_func', '1',
                        'test msg', ['foo'], False]),
        ('getChild', [None]),
        ])
    def test_stdlib_passthrough_methods(self, method_name, method_args):
        """
        stdlib logger methods are also available in stdlib BoundLogger.

        A private, unregistered ``logging.Logger`` is used so that replacing
        its methods cannot leak into the global logger registry (and thus
        into other tests that call ``logging.getLogger('Test')``). Methods
        the running stdlib lacks (e.g. ``hasHandlers`` on Python 2) are
        reported as skipped instead of passing vacuously.
        """
        called_stdlib_method = [False]

        def validate(*args, **kw):
            called_stdlib_method[0] = True

        stdlib_logger = logging.Logger('Test')
        if getattr(stdlib_logger, method_name, None) is None:
            pytest.skip(
                "logging.Logger has no %r on this Python." % (method_name,)
            )

        setattr(stdlib_logger, method_name, validate)
        bl = BoundLogger(stdlib_logger, [], {})
        bound_logger_method = getattr(bl, method_name)
        assert bound_logger_method is not None
        if method_args:
            bound_logger_method(*method_args)
        else:
            bound_logger_method()
        assert called_stdlib_method[0] is True

    def test_exception_exc_info(self):
        """
        BoundLogger.exception sets exc_info=True.
        """
        bl = BoundLogger(ReturnLogger(), [], {})

        assert (
            (),
            {"exc_info": True, "event": "event"}
        ) == bl.exception("event")

    def test_exception_exc_info_override(self):
        """
        If *exc_info* is passed to exception, it's used.
        """
        bl = BoundLogger(ReturnLogger(), [], {})

        assert (
            (),
            {"exc_info": 42, "event": "event"}
        ) == bl.exception("event", exc_info=42)


class TestPositionalArgumentsFormatter(object):
    def test_formats_tuple(self):
        """
        A tuple of simple values is %-interpolated into the event and the
        ``positional_args`` key is dropped afterwards.
        """
        rendered = PositionalArgumentsFormatter()(
            None, None,
            {'event': '%d %d %s', 'positional_args': (1, 2, 'test')},
        )
        assert rendered['event'] == '1 2 test'
        assert 'positional_args' not in rendered

    def test_formats_dict(self):
        """
        A single mapping argument is used for named %-interpolation.
        """
        rendered = PositionalArgumentsFormatter()(
            None, None,
            {'event': '%(foo)s bar', 'positional_args': ({'foo': 'bar'},)},
        )
        assert rendered['event'] == 'bar bar'
        assert 'positional_args' not in rendered

    def test_positional_args_retained(self):
        """
        With ``remove_positional_args=False`` the arguments stay in the
        event dict unchanged.
        """
        args = (1, 2, 'test')
        rendered = PositionalArgumentsFormatter(remove_positional_args=False)(
            None, None, {'event': '%d %d %s', 'positional_args': args},
        )
        assert 'positional_args' in rendered
        assert rendered['positional_args'] == args

    def test_nop_no_args(self):
        """
        An event dict without positional args passes through unchanged.
        """
        assert PositionalArgumentsFormatter()(None, None, {}) == {}

    def test_args_removed_if_empty(self):
        """
        An empty ``positional_args`` tuple is still removed when
        ``remove_positional_args`` is left at its default.

        Regression test for https://github.com/hynek/structlog/issues/82.
        """
        assert PositionalArgumentsFormatter()(
            None, None, {"positional_args": ()}
        ) == {}


class TestAddLogLevel(object):
    def test_log_level_added(self):
        """
        The method name is stored under the ``level`` key.
        """
        assert add_log_level(None, 'error', {})['level'] == 'error'

    def test_log_level_alias_normalized(self):
        """
        The ``warn`` alias is stored under its canonical name ``warning``.
        """
        assert add_log_level(None, 'warn', {})['level'] == 'warning'


@pytest.fixture
def log_record():
    """
    Provide a factory that builds ``logging.LogRecord`` instances.

    Any keyword argument passed to the factory overrides the corresponding
    default below.
    """
    def _make_record(**overrides):
        params = dict(
            name="sample-name",
            level=logging.INFO,
            pathname=None,
            lineno=None,
            msg="sample-message",
            args=[],
            exc_info=None,
        )
        params.update(overrides)
        return logging.LogRecord(**params)

    return _make_record


class TestAddLoggerName(object):
    def test_logger_name_added(self):
        """
        The wrapped logger's name is stored under the ``logger`` key.
        """
        expected = "sample-name"
        event_dict = add_logger_name(logging.getLogger(expected), None, {})
        assert event_dict["logger"] == expected

    def test_logger_name_added_with_record(self, log_record):
        """
        When the event dict carries a ``_record`` LogRecord, the logger
        name is taken from that record.
        """
        expected = "sample-name"
        event_dict = add_logger_name(
            None, None, {"_record": log_record(name=expected)}
        )
        assert event_dict["logger"] == expected


class TestRenderToLogKW(object):
    def test_default(self):
        """
        Translates `event` to `msg` and handles otherwise empty `event_dict`s.
        """
        d = render_to_log_kwargs(None, None, {"event": "message"})

        assert {"msg": "message", "extra": {}} == d

    def test_add_extra_event_dict(self, event_dict):
        """
        Adds all remaining data from `event_dict` into `extra`.

        The expected `extra` is snapshotted *before* rendering, so the
        comparison can't pass trivially if the renderer returns (and
        mutates) the very dict it was given.
        """
        event_dict["event"] = "message"
        expected_extra = {
            k: v for k, v in event_dict.items() if k != "event"
        }
        d = render_to_log_kwargs(None, None, event_dict)

        assert {"msg": "message", "extra": expected_extra} == d


@pytest.fixture
def configure_for_pf():
    """
    Configure structlog to use ProcessorFormatter.

    structlog is set up to add the log level and then hand the event dict
    to ``ProcessorFormatter.wrap_for_formatter``, using ``LoggerFactory``
    loggers wrapped in the stdlib ``BoundLogger``.

    Reset both structlog and logging setting after the test.

    NOTE(review): ``logging.basicConfig()`` does nothing when the root logger
    already has handlers, which is the case once ``configure_logging`` has
    run. The stdlib logging side is therefore probably *not* reset by this
    teardown; only ``reset_defaults()`` reliably undoes the structlog part.
    Confirm whether a real logging reset is needed.
    """
    configure(
        processors=[
            add_log_level,
            ProcessorFormatter.wrap_for_formatter,
        ],
        logger_factory=LoggerFactory(),
        wrapper_class=BoundLogger,
    )

    # Everything after the yield runs as the fixture's teardown.
    yield

    logging.basicConfig()
    reset_defaults()


def configure_logging(pre_chain):
    """
    Route the root logger through a ProcessorFormatter via ``dictConfig``.

    The root logger is set to DEBUG and gets a single StreamHandler whose
    formatter renders with ``ConsoleRenderer(colors=False)`` and appends
    the originating function name.

    :param pre_chain: Passed through as ``foreign_pre_chain``; ``None`` or
        an iterable of processors for non-structlog records.
    """
    plain_formatter = {
        "()": ProcessorFormatter,
        "processor": ConsoleRenderer(colors=False),
        "foreign_pre_chain": pre_chain,
        "format": "%(message)s [in %(funcName)s]"
    }
    default_handler = {
        "level": "DEBUG",
        "class": "logging.StreamHandler",
        "formatter": "plain",
    }
    root_logger = {
        "handlers": ["default"],
        "level": "DEBUG",
        "propagate": True,
    }

    return logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {"plain": plain_formatter},
        "handlers": {"default": default_handler},
        "loggers": {"": root_logger},
    })


class TestProcessorFormatter(object):
    """
    These are all integration tests because they're all about integration.
    """
    def test_foreign_delegate(self, configure_for_pf, capsys):
        """
        If foreign_pre_chain is None, non-structlog log entries are delegated
        to logging.
        """
        configure_logging(None)
        configure(
            processors=[
                ProcessorFormatter.wrap_for_formatter,
            ],
            logger_factory=LoggerFactory(),
            wrapper_class=BoundLogger,
        )

        logging.getLogger().warning("foo")

        assert (
            "",
            "foo [in test_foreign_delegate]\n",
        ) == capsys.readouterr()

    def test_clears_args(self, capsys, configure_for_pf):
        """
        We render our log records before sending them back to logging.
        Therefore we must clear `LogRecord.args`; otherwise the user gets a
        `TypeError: not all arguments converted during string formatting.` if
        they use positional formatting in stdlib logging.
        """
        configure_logging(None)

        logging.getLogger().warning("hello %s.", "world")

        assert (
            "",
            "hello world. [in test_clears_args]\n",
        ) == capsys.readouterr()

    def test_foreign_pre_chain(self, configure_for_pf, capsys):
        """
        If foreign_pre_chain is an iterable, it's used to pre-process
        non-structlog log entries.
        """
        configure_logging((add_log_level,))
        configure(
            processors=[
                ProcessorFormatter.wrap_for_formatter,
            ],
            logger_factory=LoggerFactory(),
            wrapper_class=BoundLogger,
        )

        logging.getLogger().warning("foo")

        assert (
            "",
            "[warning  ] foo [in test_foreign_pre_chain]\n",
        ) == capsys.readouterr()

    def test_foreign_pre_chain_add_logger_name(self, configure_for_pf, capsys):
        """
        foreign_pre_chain works with add_logger_name processor.
        """
        configure_logging((add_logger_name,))
        configure(
            processors=[
                ProcessorFormatter.wrap_for_formatter,
            ],
            logger_factory=LoggerFactory(),
            wrapper_class=BoundLogger,
        )

        logging.getLogger("sample-name").warning("foo")

        assert (
            "",
            "foo                            [sample-name]  [in test_foreign_pr"
            "e_chain_add_logger_name]\n",
        ) == capsys.readouterr()

    def test_foreign_pre_chain_gets_exc_info(self, configure_for_pf, capsys):
        """
        If a non-structlog record contains exc_info, foreign_pre_chain
        functions have access to it.
        """
        test_processor = call_recorder(lambda l, m, event_dict: event_dict)
        configure_logging((test_processor,))
        configure(
            processors=[
                ProcessorFormatter.wrap_for_formatter,
            ],
            logger_factory=LoggerFactory(),
            wrapper_class=BoundLogger,
        )

        try:
            raise RuntimeError("oh noo")
        except Exception:
            logging.getLogger().exception("okay")

        event_dict = test_processor.calls[0].args[2]
        assert "exc_info" in event_dict
        assert isinstance(event_dict["exc_info"], tuple)

    def test_other_handlers_get_original_record(self, configure_for_pf,
                                                capsys):
        """
        Logging handlers that come after the handler with ProcessorFormatter
        should receive the original, unmodified record.

        The extra handlers are removed again afterwards so they don't leak
        into later tests. handler1's stream is capsys's captured stderr,
        which would otherwise outlive this test.
        """
        configure_logging(None)

        handler1 = logging.StreamHandler()
        handler1.setFormatter(ProcessorFormatter(JSONRenderer()))
        # Minimal duck-typed handler: the root logger only needs `level`
        # and `handle` to dispatch a record to it.
        handler2 = type("", (), {})()
        handler2.handle = call_recorder(lambda record: None)
        handler2.level = logging.INFO
        logger = logging.getLogger()
        logger.addHandler(handler1)
        logger.addHandler(handler2)

        try:
            logger.info("meh")
        finally:
            logger.removeHandler(handler1)
            logger.removeHandler(handler2)

        assert 1 == len(handler2.handle.calls)
        handler2_record = handler2.handle.calls[0].args[0]
        assert "meh" == handler2_record.msg

    @pytest.mark.parametrize("keep", [True, False])
    def test_formatter_unsets_exc_info(self, configure_for_pf, capsys, keep):
        """
        Stack traces don't get printed outside of the JSON document when
        keep_exc_info is set to False, but are preserved if it is True.
        """
        configure_logging(None)
        logger = logging.getLogger()

        def format_exc_info_fake(logger, name, event_dict):
            event_dict = collections.OrderedDict(event_dict)
            del event_dict["exc_info"]
            event_dict["exception"] = "Exception!"
            return event_dict

        formatter = ProcessorFormatter(
            processor=JSONRenderer(),
            keep_stack_info=keep,
            keep_exc_info=keep,
            foreign_pre_chain=[format_exc_info_fake],
        )
        logger.handlers[0].setFormatter(formatter)

        try:
            raise RuntimeError("oh noo")
        except Exception:
            logging.getLogger().exception("seen worse")

        out, err = capsys.readouterr()
        assert "" == out
        if keep is False:
            assert (
                '{"event": "seen worse", "exception": "Exception!"}\n'
            ) == err
        else:
            assert "Traceback (most recent call last):" in err

    @pytest.mark.parametrize("keep", [True, False])
    @py3_only
    def test_formatter_unsets_stack_info(self, configure_for_pf, capsys, keep):
        """
        Stack traces don't get printed outside of the JSON document when
        keep_stack_info is set to False, but are preserved if it is True.
        """
        configure_logging(None)
        logger = logging.getLogger()

        formatter = ProcessorFormatter(
            processor=JSONRenderer(),
            keep_stack_info=keep,
            keep_exc_info=keep,
            foreign_pre_chain=[],
        )
        logger.handlers[0].setFormatter(formatter)

        logging.getLogger().warning("have a stack trace", stack_info=True)

        out, err = capsys.readouterr()
        assert "" == out
        if keep is False:
            assert 1 == err.count("Stack (most recent call last):")
        else:
            assert 2 == err.count("Stack (most recent call last):")

    def test_native(self, configure_for_pf, capsys):
        """
        If the log entry comes from structlog, it's unpackaged and processed.
        """
        configure_logging(None)

        get_logger().warning("foo")

        assert (
            "",
            "[warning  ] foo [in test_native]\n",
        ) == capsys.readouterr()
