# -*- coding: utf-8 -*-
""" Common encoders for the clustering services. """
from typing import Any, Dict, Tuple
from onnx import ModelProto
from overrides import overrides
from sentence_transformers import SentenceTransformer
from orkgnlp.common.service.base import ORKGNLPBaseEncoder
from orkgnlp.common.service.runners import ORKGNLPONNXRunner
from orkgnlp.common.util import text
class TfidfKmeansEncoder(ORKGNLPBaseEncoder):
    """
    The TfidfKmeansEncoder encodes the given input to a TF-IDF vector
    needed to execute a Kmeans onnx model.
    """

    def __init__(self, vectorizer: ModelProto):
        """
        :param vectorizer: The TF-IDF vectorizer needed for the encoding.
            It is wrapped in an ``ORKGNLPONNXRunner`` for execution.
        """
        super().__init__()
        self._vectorizer: ORKGNLPONNXRunner = ORKGNLPONNXRunner(vectorizer)

    @overrides
    def encode(self, raw_input: Any, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        """
        Preprocess ``raw_input`` and run it through the TF-IDF vectorizer.

        :param raw_input: The text to encode. It is handed to the ``text``
            preprocessing helpers and then ``.lower()``-ed, so it is
            presumably expected to be a ``str`` -- TODO confirm with callers.
        :param kwargs: Extra keyword arguments; returned unchanged.
        :return: A 1-tuple holding the vectorized input, together with
            ``kwargs``.
        """
        preprocessed_text = self._text_process(raw_input)
        # The vectorizer receives a batch containing a single document and
        # only the "variable" output is requested.
        output, _ = self._vectorizer.run(
            inputs=([preprocessed_text],),
            output_names=["variable"],
        )
        # NOTE(review): presumably output[0] is the "variable" tensor and the
        # second [0] selects the row of our single document -- verify against
        # ORKGNLPONNXRunner.run's return shape.
        return (output[0][0],), kwargs

    @staticmethod
    def _text_process(q: str) -> str:
        """
        Normalize ``q`` before vectorization.

        Steps, in order: remove punctuation, remove stopwords, lemmatize,
        lowercase.

        :param q: The text to normalize.
        :return: The normalized, lowercased text.
        """
        q = text.remove_punctuation(q)
        # NOTE(review): lowercasing happens last, so if remove_stopwords is
        # case-sensitive, capitalised stopwords survive. Keep this order unless
        # it is confirmed to differ from the preprocessing the vectorizer was
        # fitted with.
        q = text.remove_stopwords(q)
        q = text.lemmatize(q)
        return q.lower()