<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;"># -*- coding: utf-8 -*-

from pmdarima.arima.auto import StepwiseContext, auto_arima
from pmdarima.arima._context import ContextStore, ContextType
from pmdarima.arima import _context as context_lib
from pmdarima.datasets import load_lynx, load_wineind

from unittest import mock
import threading
import collections
import pytest
import warnings

# Shared datasets used as fit inputs by the tests below (loaded once at
# import time; lynx first, then wineind).
lynx, wineind = load_lynx(), load_wineind()


# test StepwiseContext parameter validation
@pytest.mark.parametrize(
    'max_steps,max_dur', [
        pytest.param(-1, None),
        pytest.param(0, None),
        pytest.param(1001, None),
        pytest.param(1100, None),
        pytest.param(None, -1),
        pytest.param(None, 0),
    ])
def test_stepwise_context_args(max_steps, max_dur):
    with pytest.raises(ValueError):
        StepwiseContext(max_steps=max_steps, max_dur=max_dur)


# test auto_arima stepwise run with  StepwiseContext
def test_auto_arima_with_stepwise_context():
    samp = lynx[:8]
    with StepwiseContext(max_steps=3, max_dur=30):
        with pytest.warns(UserWarning) as uw:
            auto_arima(samp, suppress_warnings=False, stepwise=True,
                       error_action='ignore')

            # assert that max_steps were taken
            assert any(str(w.message)
                       .startswith('stepwise search has reached the '
                                   'maximum number of tries') for w in uw)


# test effective context info in nested context scenario
def test_nested_context():
    ctx1_data = {'max_dur': 30}
    ctx2_data = {'max_steps': 5}
    ctx1 = StepwiseContext(**ctx1_data)
    ctx2 = StepwiseContext(**ctx2_data)

    with ctx1, ctx2:
        effective_ctx_data = ContextStore.get_or_empty(
            ContextType.STEPWISE)
        expected_ctx_data = ctx1_data.copy()
        expected_ctx_data.update(ctx2_data)

        assert all(effective_ctx_data[key] == expected_ctx_data[key]
                   for key in expected_ctx_data.keys())

        assert all(effective_ctx_data[key] == expected_ctx_data[key]
                   for key in effective_ctx_data.keys())


# Test a context honors the max duration
def test_max_dur():
    # set arbitrarily low to guarantee will always pass after one iter
    with StepwiseContext(max_dur=.5), \
            pytest.warns(UserWarning) as uw:

        auto_arima(lynx, stepwise=True)
        # assert that max_dur was reached
        assert any(str(w.message)
                   .startswith('early termination') for w in uw)


# Test that a context after the first will not inherit the first's attrs
def test_subsequent_contexts():
    # Force a very fast fit
    with StepwiseContext(max_dur=.5), \
            pytest.warns(UserWarning):
        auto_arima(lynx, stepwise=True)

    # Out of scope, should be EMPTY
    ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
    assert ctx.get_type() is ContextType.EMPTY

    # Now show that we DON'T hit early termination by time here
    with StepwiseContext(max_steps=100), \
            warnings.catch_warnings(record=True) as uw:

        ctx = ContextStore.get_or_empty(ContextType.STEPWISE)
        assert ctx.get_type() is ContextType.STEPWISE
        assert ctx.max_dur is None

        auto_arima(lynx, stepwise=True)
        # assert that max_dur was NOT reached
        if uw:
            assert not any(str(w.message)
                           .startswith('early termination') for w in uw)


# test param validation of ContextStore's add, get and remove members
def test_add_get_remove_context_args():
    with pytest.raises(ValueError):
        ContextStore._add_context(None)

    with pytest.raises(ValueError):
        ContextStore._remove_context(None)

    with pytest.raises(ValueError):
        ContextStore.get_context(None)


def test_context_store_accessible_across_threads():
    # Make sure it's completely empty by patching it
    d = {}
    with mock.patch('pmdarima.arima._context._ctx.store', d):

        # pushes onto the Context Store
        def push(n):
            # n is the number of times this has been executed before. If &gt; 0,
            # assert there is a context there
            if n &gt; 0:
                assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n
            else:
                context_lib._ctx.store[ContextType.STEPWISE] = \
                    collections.deque()

            new_ctx = StepwiseContext()
            context_lib._ctx.store[ContextType.STEPWISE].append(new_ctx)
            assert len(context_lib._ctx.store[ContextType.STEPWISE]) == n + 1

        for i in range(5):
            t = threading.Thread(target=push, args=(i,))
            t.start()
            t.join(1)  # it shouldn't take even close to this time

    # Assert the mock has lifted
    assert context_lib._ctx.store is not d
</pre></body></html>