<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;">from __future__ import annotations

from datetime import datetime
import json
import threading
from typing import TYPE_CHECKING

from optuna import logging
from optuna._imports import _LazyImport
from optuna.distributions import distribution_to_json
from optuna.distributions import json_to_distribution
from optuna.exceptions import DuplicatedStudyError
from optuna.exceptions import UpdateFinishedTrialError
from optuna.storages import BaseStorage
from optuna.study._study_direction import StudyDirection
from optuna.trial._frozen import FrozenTrial
from optuna.trial._state import TrialState


if TYPE_CHECKING:
    import grpc

    from optuna.storages._grpc.auto_generated import api_pb2
    from optuna.storages._grpc.auto_generated import api_pb2_grpc
else:
    api_pb2 = _LazyImport("optuna.storages._grpc.auto_generated.api_pb2")
    api_pb2_grpc = _LazyImport("optuna.storages._grpc.auto_generated.api_pb2_grpc")
    grpc = _LazyImport("grpc")


_logger = logging.get_logger(__name__)
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"


class OptunaStorageProxyService(api_pb2_grpc.StorageServiceServicer):
    def __init__(self, storage: BaseStorage) -&gt; None:
        self._backend = storage
        self._lock = threading.Lock()

    def CreateNewStudy(
        self,
        request: api_pb2.CreateNewStudyRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.CreateNewStudyReply:
        directions = [
            StudyDirection.MINIMIZE if d == api_pb2.MINIMIZE else StudyDirection.MAXIMIZE
            for d in request.directions
        ]
        study_name = request.study_name

        try:
            study_id = self._backend.create_new_study(directions=directions, study_name=study_name)
        except DuplicatedStudyError as e:
            context.abort(code=grpc.StatusCode.ALREADY_EXISTS, details=str(e))
        return api_pb2.CreateNewStudyReply(study_id=study_id)

    def DeleteStudy(
        self,
        request: api_pb2.DeleteStudyRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.DeleteStudyReply:
        study_id = request.study_id
        try:
            self._backend.delete_study(study_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.DeleteStudyReply()

    def SetStudyUserAttribute(
        self,
        request: api_pb2.SetStudyUserAttributeRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetStudyUserAttributeReply:
        try:
            self._backend.set_study_user_attr(
                request.study_id, request.key, json.loads(request.value)
            )
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.SetStudyUserAttributeReply()

    def SetStudySystemAttribute(
        self,
        request: api_pb2.SetStudySystemAttributeRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetStudySystemAttributeReply:
        try:
            self._backend.set_study_system_attr(
                request.study_id, request.key, json.loads(request.value)
            )
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.SetStudySystemAttributeReply()

    def GetStudyIdFromName(
        self,
        request: api_pb2.GetStudyIdFromNameRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetStudyIdFromNameReply:
        try:
            study_id = self._backend.get_study_id_from_name(request.study_name)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.GetStudyIdFromNameReply(study_id=study_id)

    def GetStudyNameFromId(
        self,
        request: api_pb2.GetStudyNameFromIdRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetStudyNameFromIdReply:
        study_id = request.study_id

        try:
            name = self._backend.get_study_name_from_id(study_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        assert name is not None
        return api_pb2.GetStudyNameFromIdReply(study_name=name)

    def GetStudyDirections(
        self,
        request: api_pb2.GetStudyDirectionsRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetStudyDirectionsReply:
        study_id = request.study_id

        try:
            directions = self._backend.get_study_directions(study_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))

        assert directions is not None
        return api_pb2.GetStudyDirectionsReply(
            directions=[
                api_pb2.MINIMIZE if d == StudyDirection.MINIMIZE else api_pb2.MAXIMIZE
                for d in directions
            ]
        )

    def GetStudyUserAttributes(
        self,
        request: api_pb2.GetStudyUserAttributesRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetStudyUserAttributesReply:
        try:
            attributes = self._backend.get_study_user_attrs(request.study_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.GetStudyUserAttributesReply(
            user_attributes={key: json.dumps(value) for key, value in attributes.items()}
        )

    def GetStudySystemAttributes(
        self,
        request: api_pb2.GetStudySystemAttributesRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetStudySystemAttributesReply:
        try:
            attributes = self._backend.get_study_system_attrs(request.study_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.GetStudySystemAttributesReply(
            system_attributes={key: json.dumps(value) for key, value in attributes.items()}
        )

    def GetAllStudies(
        self,
        request: api_pb2.GetAllStudiesRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetAllStudiesReply:
        studies = self._backend.get_all_studies()
        return api_pb2.GetAllStudiesReply(
            studies=[
                api_pb2.Study(
                    study_id=study._study_id,
                    study_name=study.study_name,
                    directions=[
                        api_pb2.MINIMIZE if d == StudyDirection.MINIMIZE else api_pb2.MAXIMIZE
                        for d in study.directions
                    ],
                    user_attributes={
                        key: json.dumps(value) for key, value in study.user_attrs.items()
                    },
                    system_attributes={
                        key: json.dumps(value) for key, value in study.system_attrs.items()
                    },
                )
                for study in studies
            ]
        )

    def CreateNewTrial(
        self,
        request: api_pb2.CreateNewTrialRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.CreateNewTrialReply:
        study_id = request.study_id

        template_trial = None
        if not request.template_trial_is_none:
            template_trial = _from_proto_trial(request.template_trial)

        try:
            trial_id = self._backend.create_new_trial(study_id, template_trial)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))

        return api_pb2.CreateNewTrialReply(trial_id=trial_id)

    def SetTrialParameter(
        self,
        request: api_pb2.SetTrialParameterRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetTrialParameterReply:
        trial_id = request.trial_id
        param_name = request.param_name
        param_value_internal = request.param_value_internal
        distribution = json_to_distribution(request.distribution)
        try:
            self._backend.set_trial_param(trial_id, param_name, param_value_internal, distribution)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        except UpdateFinishedTrialError as e:
            context.abort(code=grpc.StatusCode.FAILED_PRECONDITION, details=str(e))
        except ValueError as e:
            context.abort(code=grpc.StatusCode.INVALID_ARGUMENT, details=str(e))
        return api_pb2.SetTrialParameterReply()

    def GetTrialIdFromStudyIdTrialNumber(
        self,
        request: api_pb2.GetTrialIdFromStudyIdTrialNumberRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetTrialIdFromStudyIdTrialNumberReply:
        study_id = request.study_id
        trial_number = request.trial_number

        try:
            trial_id = self._backend.get_trial_id_from_study_id_trial_number(
                study_id, trial_number
            )
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        return api_pb2.GetTrialIdFromStudyIdTrialNumberReply(trial_id=trial_id)

    def SetTrialStateValues(
        self,
        request: api_pb2.SetTrialStateValuesRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetTrialStateValuesReply:
        trial_id = request.trial_id
        state = request.state
        values = list(request.values)
        try:
            trial_updated = self._backend.set_trial_state_values(
                trial_id, _from_proto_trial_state(state), values
            )
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        except UpdateFinishedTrialError as e:
            context.abort(code=grpc.StatusCode.FAILED_PRECONDITION, details=str(e))
        return api_pb2.SetTrialStateValuesReply(trial_updated=trial_updated)

    def SetTrialIntermediateValue(
        self,
        request: api_pb2.SetTrialIntermediateValueRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetTrialIntermediateValueReply:
        trial_id = request.trial_id
        step = request.step
        intermediate_value = request.intermediate_value
        try:
            self._backend.set_trial_intermediate_value(trial_id, step, intermediate_value)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        except UpdateFinishedTrialError as e:
            context.abort(code=grpc.StatusCode.FAILED_PRECONDITION, details=str(e))
        return api_pb2.SetTrialIntermediateValueReply()

    def SetTrialUserAttribute(
        self,
        request: api_pb2.SetTrialUserAttributeRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetTrialUserAttributeReply:
        trial_id = request.trial_id
        key = request.key
        value = json.loads(request.value)
        try:
            self._backend.set_trial_user_attr(trial_id, key, value)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        except UpdateFinishedTrialError as e:
            context.abort(code=grpc.StatusCode.FAILED_PRECONDITION, details=str(e))
        return api_pb2.SetTrialUserAttributeReply()

    def SetTrialSystemAttribute(
        self,
        request: api_pb2.SetTrialSystemAttributeRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.SetTrialSystemAttributeReply:
        trial_id = request.trial_id
        key = request.key
        value = json.loads(request.value)
        try:
            self._backend.set_trial_system_attr(trial_id, key, value)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))
        except UpdateFinishedTrialError as e:
            context.abort(code=grpc.StatusCode.FAILED_PRECONDITION, details=str(e))
        return api_pb2.SetTrialSystemAttributeReply()

    def GetTrial(
        self,
        request: api_pb2.GetTrialRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetTrialReply:
        trial_id = request.trial_id
        try:
            trial = self._backend.get_trial(trial_id)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))

        return api_pb2.GetTrialReply(trial=_to_proto_trial(trial))

    def GetTrials(
        self,
        request: api_pb2.GetTrialsRequest,
        context: grpc.ServicerContext,
    ) -&gt; api_pb2.GetTrialsReply:
        study_id = request.study_id
        included_trial_ids = set(request.included_trial_ids)
        trial_id_greater_than = request.trial_id_greater_than
        try:
            trials = self._backend.get_all_trials(study_id, deepcopy=False)
        except KeyError as e:
            context.abort(code=grpc.StatusCode.NOT_FOUND, details=str(e))

        filtered_trials = [
            _to_proto_trial(t)
            for t in trials
            if t._trial_id &gt; trial_id_greater_than or t._trial_id in included_trial_ids
        ]
        return api_pb2.GetTrialsReply(trials=filtered_trials)


def _to_proto_trial_state(state: TrialState) -&gt; api_pb2.TrialState.ValueType:
    if state == TrialState.RUNNING:
        return api_pb2.RUNNING
    if state == TrialState.COMPLETE:
        return api_pb2.COMPLETE
    if state == TrialState.PRUNED:
        return api_pb2.PRUNED
    if state == TrialState.FAIL:
        return api_pb2.FAIL
    if state == TrialState.WAITING:
        return api_pb2.WAITING
    raise ValueError(f"Unknown TrialState: {state}")


def _from_proto_trial_state(state: api_pb2.TrialState.ValueType) -&gt; TrialState:
    if state == api_pb2.RUNNING:
        return TrialState.RUNNING
    if state == api_pb2.COMPLETE:
        return TrialState.COMPLETE
    if state == api_pb2.PRUNED:
        return TrialState.PRUNED
    if state == api_pb2.FAIL:
        return TrialState.FAIL
    if state == api_pb2.WAITING:
        return TrialState.WAITING
    raise ValueError(f"Unknown api_pb2.TrialState: {state}")


def _to_proto_trial(trial: FrozenTrial) -&gt; api_pb2.Trial:
    params = {}
    for key, value in trial.params.items():
        params[key] = trial.distributions[key].to_internal_repr(value)

    return api_pb2.Trial(
        trial_id=trial._trial_id,
        number=trial.number,
        state=_to_proto_trial_state(trial.state),
        values=trial.values,
        datetime_start=(
            trial.datetime_start.strftime(DATETIME_FORMAT) if trial.datetime_start else ""
        ),
        datetime_complete=(
            trial.datetime_complete.strftime(DATETIME_FORMAT) if trial.datetime_complete else ""
        ),
        distributions={
            key: distribution_to_json(distribution)
            for key, distribution in trial.distributions.items()
        },
        params=params,
        user_attributes={key: json.dumps(value) for key, value in trial.user_attrs.items()},
        system_attributes={key: json.dumps(value) for key, value in trial.system_attrs.items()},
        intermediate_values={step: value for step, value in trial.intermediate_values.items()},
    )


def _from_proto_trial(trial: api_pb2.Trial) -&gt; FrozenTrial:
    datetime_start = (
        datetime.strptime(trial.datetime_start, DATETIME_FORMAT) if trial.datetime_start else None
    )
    datetime_complete = (
        datetime.strptime(trial.datetime_complete, DATETIME_FORMAT)
        if trial.datetime_complete
        else None
    )
    distributions = {
        key: json_to_distribution(value) for key, value in trial.distributions.items()
    }
    params = {}
    for key, value in trial.params.items():
        params[key] = distributions[key].to_external_repr(value)

    return FrozenTrial(
        trial_id=trial.trial_id,
        number=trial.number,
        state=_from_proto_trial_state(trial.state),
        value=None,
        values=trial.values if trial.values else None,
        datetime_start=datetime_start,
        datetime_complete=datetime_complete,
        params=params,
        distributions=distributions,
        user_attrs={key: json.loads(value) for key, value in trial.user_attributes.items()},
        system_attrs={key: json.loads(value) for key, value in trial.system_attributes.items()},
        intermediate_values={step: value for step, value in trial.intermediate_values.items()},
    )
</pre></body></html>