# -*- coding: utf-8 -*-
""" Templates Recommendation service encoder. """
from typing import Any, Dict, List, Tuple
import torch
from overrides import overrides
from transformers import BertTokenizer
from orkgnlp.common.service.base import ORKGNLPBaseEncoder
from orkgnlp.common.util import text
class TemplatesRecommenderEncoder(ORKGNLPBaseEncoder):
    """
    The TemplatesRecommenderEncoder encodes the given input to the arguments
    needed to execute a BertForSequenceClassification model.

    Every template premise is paired with the normalized raw input as
    ``[CLS] premise [SEP] raw_input [SEP]`` and turned into ``input_ids``,
    ``token_type_ids`` and ``attention_mask`` tensors with a batch dimension of 1.
    """

    # Upper bound on the encoded sequence length. BERT-base style models
    # (SciBERT included) are built with 512 position embeddings, so longer
    # sequences cannot be fed to the model.
    MAX_INPUT_SIZE: int = 512

    def __init__(self, templates: List[Dict[str, str]]):
        """
        :param templates: templates used for training the service models as premises.
            Each entry is read through its ``"premise"`` key.
        """
        super().__init__()
        self.templates: List[Dict[str, str]] = templates
        self.tokenizer: BertTokenizer = BertTokenizer.from_pretrained(
            "allenai/scibert_scivocab_uncased"
        )
        # When a tokenizer does not declare ``model_max_length``, transformers
        # substitutes a huge sentinel value, which would silently disable the
        # truncation in ``_encode_pair``; cap it at what the model accepts.
        self.max_input_sizes: int = min(
            self.tokenizer.model_max_length or self.MAX_INPUT_SIZE,
            self.MAX_INPUT_SIZE,
        )
        self.device: str = "cpu"

    @overrides
    def encode(self, raw_input: Any, **kwargs: Any) -> Tuple[Any, Dict[str, Any]]:
        """
        Encode ``raw_input`` against every template premise.

        :param raw_input: text used as the second segment of every pair.
        :param kwargs: returned unchanged.
        :return: a lazy generator yielding one model-input dict per template
            (in ``self.templates`` order), together with the untouched ``kwargs``.
        """

        def batch_generator():
            # The second segment is the same for every template: normalize it once.
            hypothesis = self._post_process(raw_input)
            for template in self.templates:
                premise = self._post_process(template["premise"])
                yield self._encode_pair(premise, hypothesis)

        return batch_generator(), kwargs

    def _encode_pair(self, premise: str, hypothesis: str) -> Dict[str, torch.Tensor]:
        """
        Build the model inputs for a single ``premise``/``hypothesis`` pair.

        :return: ``input_ids``, ``token_type_ids`` and ``attention_mask`` tensors,
            each of shape ``(1, seq_len)`` with ``seq_len <= self.max_input_sizes``,
            placed on ``self.device``.
        """
        tokens = self.tokenizer.tokenize(
            "[CLS] {} [SEP] {} [SEP]".format(premise, hypothesis)
        )
        limit = self.max_input_sizes
        # Segment ids and the mask are computed on the full token list (the
        # segment split needs the first [SEP]) and then truncated like the ids.
        # NOTE(review): truncating from the end drops the closing [SEP] of
        # over-long pairs; kept as-is to match the existing encoding.
        features = {
            "input_ids": self.tokenizer.convert_tokens_to_ids(tokens)[:limit],
            "token_type_ids": self._get_token_type(tokens)[:limit],
            "attention_mask": self._get_attention_mask(tokens)[:limit],
        }
        # unsqueeze(0) adds the leading batch dimension of size 1.
        return {
            name: torch.tensor(values).unsqueeze(0).to(self.device)
            for name, values in features.items()
        }

    @staticmethod
    def _get_attention_mask(tokens):
        """Return an all-ones mask: no padding is ever added, so every token is real."""
        return [1] * len(tokens)

    @staticmethod
    def _get_token_type(tokens):
        """
        Return segment ids: 0 up to and including the first ``[SEP]``, 1 after it.

        :raises ValueError: if ``tokens`` contains no ``[SEP]``.
        """
        sep_index = tokens.index("[SEP]") + 1
        return [0] * sep_index + [1] * (len(tokens) - sep_index)

    @staticmethod
    def _post_process(string):
        """
        Normalize text before tokenization: replace the listed separators
        (spaced hyphens, hyphens, underscores, dots) with spaces, trim, and
        lower-case the result.
        """
        # presumably text.replace treats each entry as a regex (hence the
        # escaped dot) - confirm in orkgnlp.common.util.text
        string = text.replace(string, [r"\s+-\s+", "-", "_", r"\."], " ")
        return text.trim(string).lower()