# -*- coding: utf-8 -*-
""" Predicates recommendation service decoder. """
from typing import Any, Dict, Generator, List, Union
import numpy as np
from overrides import overrides
from pandas import DataFrame
from orkgnlp.common.service.base import ORKGNLPBaseDecoder
[docs]
class PredicatesRecommenderDecoder(ORKGNLPBaseDecoder):
"""
The PredicatesRecommenderDecoder decodes the Predicates' recommendation service model's output
to a user-friendly one.
"""
def __init__(self, train_df: DataFrame, predicates: Dict[str, List[Dict[str, str]]]):
"""
:param train_df: The training dataframe of the service.
:param predicates: Dict object representing the mapping from comparisons to predicates.
"""
super().__init__()
self._train_df = train_df
self._predicates = predicates
[docs]
@overrides
def decode(self, model_output: Union[Any, Generator[Any, None, None]], **kwargs: Any) -> Any:
cluster_label, model_labels_ = model_output[0], model_output[1]
cluster_instances_indices = np.argwhere(cluster_label == model_labels_).squeeze(1)
cluster_instances = self._train_df.iloc[cluster_instances_indices]
comparison_ids = cluster_instances["comparison_id"].unique()
return self._map_to_predicates(comparison_ids)
def _map_to_predicates(self, comparison_ids: List[str]) -> List[Dict[str, str]]:
predicates = []
for comparison_id in comparison_ids:
predicates.extend(self._predicates[comparison_id])
return predicates