Source code for orkgnlp.annotation.rfclf.decoder

# -*- coding: utf-8 -*-
""" ResearchFieldClassifier decoder. """
from typing import Any, Generator, Union

import torch
from overrides import overrides

from orkgnlp.common.service.base import ORKGNLPBaseDecoder


[docs] class ResearchFieldClassifierDecoder(ORKGNLPBaseDecoder): """ The ResearchFieldClassifierDecoder decodes the ResearchFieldClassifier service model's output to a user-friendly one. """ def __init__(self, label_dict): """ :param label_dict: Classifier label mapping. :type label_dict: Dict[str, int] """ super().__init__() self.label_dict = {index: label for label, index in label_dict.items()}
[docs] @overrides(check_signature=False) def decode( self, model_output: Union[Any, Generator[Any, None, None]], top_n: int, **kwargs: Any ) -> Any: logits = model_output["logits"] top_n_scores, top_n_indices = torch.topk(logits, k=top_n) top_n_predicts = [ {"research_field": self.label_dict[indices.item()], "score": score.item()} for score, indices in zip(top_n_scores[0], top_n_indices[0]) ] return top_n_predicts