# Source code for orkgnlp.annotation.agriner.annotator

# -*- coding: utf-8 -*-
""" Agri-NER service. """

from typing import Any

from transformers import pipeline

from orkgnlp.annotation.agriner.decoder import AgriNerDecoder
from orkgnlp.common.config import orkgnlp_context
from orkgnlp.common.service.base import (
    ORKGNLPBaseDecoder,
    ORKGNLPBaseEncoder,
    ORKGNLPBaseRunner,
    ORKGNLPBaseService,
)
from orkgnlp.common.service.runners import ORKGNLPTorchRunner


class AgriNer(ORKGNLPBaseService):
    """
    Agri-NER service: Named Entity Recognition over paper titles.

    The service needs a token-classification (NER) model trained on paper
    titles, together with the configuration obtained during its training.
    The required files are downloaded during initialization if that has not
    happened before. Pass ``force_download=True`` to remove and re-download
    previously downloaded service files.
    """

    SERVICE_NAME = "agri-ner"

    def __init__(self, *args: Any, **kwargs: Any):
        # The base service must run first: the model below is loaded from
        # self._config.service_dir, which presumably the base initializer
        # populates (it is not set anywhere in this class).
        super().__init__(self.SERVICE_NAME, *args, **kwargs)

        text_encoder: ORKGNLPBaseEncoder = ORKGNLPBaseEncoder()

        # Hugging Face token-classification pipeline over the service's model
        # directory, paired with the "bert-base-cased" tokenizer. The "simple"
        # aggregation strategy groups tokens belonging to the same entity into
        # one span (see the transformers TokenClassificationPipeline docs).
        ner_pipeline = pipeline(
            task="token-classification",
            model=self._config.service_dir,
            tokenizer="bert-base-cased",
            aggregation_strategy="simple",
        )

        torch_runner: ORKGNLPBaseRunner = ORKGNLPTorchRunner(ner_pipeline)
        agri_decoder: ORKGNLPBaseDecoder = AgriNerDecoder()

        # Register the encoder -> runner -> decoder chain as this service's
        # "main" pipeline; __call__ runs through self._run.
        self._register_pipeline("main", text_encoder, torch_runner, agri_decoder)

    def __call__(self, title: str) -> Any:
        """
        Apply Named Entity Recognition to the given paper title.

        :param title: Paper's title.
        :return: A list of the annotated parts of the given text.
        """
        annotations = self._run(raw_input=title)
        return annotations
# Register the AgriNer class in the shared SERVICE_MAP under its SERVICE_NAME
# ("agri-ner"). This runs at import time, as a side effect of importing this
# module. Presumably it lets other parts of orkgnlp look the service up by
# name — confirm against the SERVICE_MAP consumers.
orkgnlp_context.get("SERVICE_MAP")[AgriNer.SERVICE_NAME] = AgriNer