Source code for orkgnlp.annotation.tdm.decoder
# -*- coding: utf-8 -*-
""" TDM-Extraction service decoder. """
from typing import Any, Dict, Generator, List
import numpy as np
import torch
from overrides import overrides
from pandas import DataFrame
from orkgnlp.common.service.base import ORKGNLPBaseDecoder
[docs]
class TdmExtractorDecoder(ORKGNLPBaseDecoder):
    """
    The TdmExtractorDecoder decodes the TDM-Extraction service model's output
    to a user-friendly one.
    """

    def __init__(self, labels: DataFrame):
        """
        :param labels: TDM gold labels given as a one-columned dataframe. The
            column is read as ``labels[0]`` and each entry is split on ``"#"``
            into exactly three parts: task, dataset and metric.
        """
        super().__init__()
        self.labels: DataFrame = labels

    @overrides(check_signature=False)
    def decode(
        self, model_output: Generator[Any, None, None], top_n: int, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        """
        Turns the model's batched logits into the ``top_n`` most probable TDM triples.

        Every prediction row is matched with the gold label at the same position
        in ``self.labels``. Batches are consumed in order, so the number of rows
        already seen is the position of the first label of the current batch.

        Side effect: ``self.labels`` gets (or has overwritten) a ``"prob"`` column
        holding the true-probability of each positively predicted label and NaN
        for all others.

        :param model_output: generator of model outputs, each exposing a ``logits``
            tensor whose rows hold two values: (true, false).
            NOTE(review): batches are assumed to cover ``self.labels`` in order.
        :param top_n: maximum number of triples to return.
        :param kwargs: ignored.
        :return: list of ``{"task", "dataset", "metric", "score"}`` dicts sorted by
            descending score. Only labels whose sigmoid true-probability exceeds
            the false-probability are included.
        :raises ValueError: if the model yields more rows than there are labels.
        """
        # One slot per gold label; NaN means "not predicted as true".
        probs = np.full(len(self.labels), np.nan)
        offset = 0
        for batch in model_output:
            # detach/float/cpu so the values can be read via numpy regardless of
            # autograd state, half precision, or device.
            predictions = torch.sigmoid(batch.logits).detach().float().cpu()
            n_rows = predictions.shape[0]
            if offset + n_rows > len(probs):
                raise ValueError(
                    f"Model produced more predictions ({offset + n_rows}) "
                    f"than there are labels ({len(probs)})."
                )
            true_probs = predictions[:, 0].numpy()
            false_probs = predictions[:, 1].numpy()
            positive = true_probs > false_probs
            # Position = rows seen so far + row within batch. Using
            # batch_idx * current_batch_size is wrong for a shorter last batch.
            probs[offset + np.flatnonzero(positive)] = true_probs[positive]
            offset += n_rows

        self.labels["prob"] = probs
        candidates = self.labels[self.labels["prob"].notnull()]
        candidates = candidates.sort_values(by="prob", ascending=False)
        # .iloc makes the top-n slice explicitly positional (the index holds
        # the original, now unsorted, label positions).
        return self._prepare_service_output(
            candidates[0].iloc[:top_n].tolist(), candidates["prob"].iloc[:top_n].tolist()
        )

    @staticmethod
    def _prepare_service_output(tdms: List[str], scores: List[float]) -> List[Dict[str, Any]]:
        """
        Builds the service response from ``"task#dataset#metric"`` strings and scores.

        :param tdms: TDM strings, each containing exactly two ``"#"`` separators.
        :param scores: score for each entry of ``tdms``, in the same order.
        :return: list of ``{"task", "dataset", "metric", "score"}`` dicts.
        """
        service_output = []
        for tdm, score in zip(tdms, scores):
            task, dataset, metric = tdm.split("#")
            service_output.append(
                {"task": task, "dataset": dataset, "metric": metric, "score": score}
            )
        return service_output