Source code for orkgnlp.annotation.rfclf.encoder

# -*- coding: utf-8 -*-
""" ResearchFieldClassifier encoder. """
import re
from typing import Any, Dict, Tuple

from overrides import overrides
from transformers import AutoTokenizer

from orkgnlp.common.service.base import ORKGNLPBaseEncoder


class ResearchFieldClassifierEncoder(ORKGNLPBaseEncoder):
    """
    Encoder for the research field classifier.

    Turns a raw text input into the tokenized argument list that the
    classification model is executed with.
    """

    def __init__(self):
        """
        Load the tokenizer and set the encoding defaults.

        The tokenizer is fetched via ``AutoTokenizer.from_pretrained`` using the
        ``"malteos/scincl"`` identifier; inputs are padded/truncated to
        ``max_input_sizes`` tokens.
        """
        super().__init__()
        self.tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("malteos/scincl")
        # Fixed sequence length used for both padding and truncation in encode().
        self.max_input_sizes: int = 512
        # Stored on the instance; not read by encode() in this class.
        self.device: str = "cpu"

    @overrides
    def encode(self, raw_input: Any, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        """
        Clean and tokenize ``raw_input`` for the classification model.

        The text is cleaned in three steps: every ``<...>`` span (non-greedy,
        within a single line) is replaced by a space, runs of whitespace are
        collapsed to single spaces, and the result is lower-cased. The cleaned
        text is then tokenized, padded and truncated to ``self.max_input_sizes``.

        :param raw_input: Text to encode; passed to ``re.sub``, so it is
            expected to be a string.
        :param kwargs: Extra keyword arguments, handed back to the caller untouched.
        :return: A pair of (single-element list holding the tokenizer output,
            the received ``kwargs`` dict).
        """
        # Replace markup-like "<...>" fragments with a space so that the
        # surrounding words do not get glued together.
        without_tags = re.sub("<.*?>", " ", raw_input)

        # split()/join collapses any whitespace run (including the spaces just
        # inserted) into a single blank and trims both ends.
        cleaned_text = " ".join(without_tags.split())
        cleaned_text = cleaned_text.lower()

        tokenizer_options = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.max_input_sizes,
            "return_tensors": "pt",
        }
        token_ids = self.tokenizer.encode(cleaned_text, **tokenizer_options)

        model_args = [token_ids]
        return model_args, kwargs