Source code for orkgnlp.annotation.tdm.extractor
# -*- coding: utf-8 -*-
""" TDM-Extraction service. """
from typing import Any
from pandas import DataFrame
from transformers import XLNetForSequenceClassification
from orkgnlp.annotation.tdm.decoder import TdmExtractorDecoder
from orkgnlp.annotation.tdm.encoder import TdmExtractorEncoder
from orkgnlp.common.config import orkgnlp_context
from orkgnlp.common.service.base import (
ORKGNLPBaseDecoder,
ORKGNLPBaseEncoder,
ORKGNLPBaseRunner,
ORKGNLPBaseService,
)
from orkgnlp.common.service.runners import ORKGNLPTorchRunner
from orkgnlp.common.util import io
[docs]
class TdmExtractor(ORKGNLPBaseService):
    """
    Service that extracts Task-Dataset-Metric (TDM) entities from a text.

    It relies on a pretrained transformers.XLNetForSequenceClassification model and a TDM gold labels file.
    The required files are downloaded during initialization, unless this has already happened before.
    Pass ``force_download=True`` to remove and re-download the previously downloaded service files.
    """

    SERVICE_NAME = "tdm-extraction"

    def __init__(self, *args: Any, **kwargs: Any):
        """
        Loads the gold labels and the pretrained model, then registers the ``main`` pipeline.

        :param args: Positional arguments forwarded to ``ORKGNLPBaseService`` after the service name.
        :param kwargs: Keyword arguments forwarded to ``ORKGNLPBaseService``.
        """
        super().__init__(self.SERVICE_NAME, *args, **kwargs)

        # The gold labels file is tab-separated; its path is listed in the service requirements.
        gold_labels: DataFrame = io.read_csv(
            self._config.requirements["labels"], sep="\t"
        )
        # In unit-test mode only the leading rows (``DataFrame.head()``) are kept.
        gold_labels = gold_labels.head() if self._unittest else gold_labels

        tdm_encoder: ORKGNLPBaseEncoder = TdmExtractorEncoder(
            gold_labels, self._batch_size
        )

        # The pretrained model is loaded from the service directory and wrapped in a torch runner.
        pretrained_model = io.load_transformers_pretrained(
            self._config.service_dir, XLNetForSequenceClassification
        )
        tdm_runner: ORKGNLPBaseRunner = ORKGNLPTorchRunner(pretrained_model)

        tdm_decoder: ORKGNLPBaseDecoder = TdmExtractorDecoder(gold_labels)

        self._register_pipeline("main", tdm_encoder, tdm_runner, tdm_decoder)

    def __call__(self, text: str, top_n: int = 5) -> Any:
        """
        Extracts Task-Dataset-Metric (TDM) entities from a given
        DocTAET (Title, Abstract, ExperimentalSetup, TableInfo) ``text``.

        :param text: `DocTAET <https://doi.org/10.1007/978-3-030-91669-5_35>`_ represented text.
        :param top_n: Top n results to be extracted. Defaults to 5.
        :return: A list of TDMs.
        """
        return self._run(raw_input=text, top_n=top_n, multiple_batches=True)
# Register the service class at import time: the "SERVICE_MAP" entry of the shared
# orkgnlp context maps TdmExtractor.SERVICE_NAME to the TdmExtractor class itself.
orkgnlp_context.get("SERVICE_MAP")[TdmExtractor.SERVICE_NAME] = TdmExtractor