# -*- coding: utf-8 -*-
""" CS-NER service decoder. """
from typing import Any, Dict, Generator, List, Union
from overrides import overrides
from orkgnlp.common.service.base import ORKGNLPBaseDecoder
[docs]
class CSNerDecoder(ORKGNLPBaseDecoder):
"""
The CSNerDecoder decodes the CS-NER service model's output
to a user-friendly one.
"""
def __init__(self, alphabet):
"""
:param alphabet: Dict representing the word, char and label alphabets.
:type alphabet: Dict[str, Dict[str, int]]
"""
super().__init__()
self._alphabet = alphabet
self._UNKNOWN = "</unk>"
[docs]
@overrides(check_signature=False)
def decode(
self,
model_output: Union[Any, Generator[Any, None, None]],
raw_texts,
recover,
**kwargs: Any
) -> Any:
predicted_results = []
for i, batch in enumerate(model_output):
_, nbest_tag_seq = batch
tag_seq = nbest_tag_seq[:, :, 0]
pred_label, _ = self._recover_label(
tag_seq, recover[i][0], recover[i][0], recover[i][2]
)
predicted_results += pred_label
entities = self._get_entities(raw_texts, predicted_results)
return self._prepare_annotations(entities)
def _recover_label(self, pred_variable, gold_variable, mask_variable, word_recover):
"""
input:
pred_variable (batch_size, sent_len): pred tag result
gold_variable (batch_size, sent_len): gold result variable
mask_variable (batch_size, sent_len): mask variable
"""
pred_variable = pred_variable[word_recover]
gold_variable = gold_variable[word_recover]
mask_variable = mask_variable[word_recover]
batch_size = gold_variable.size(0)
seq_len = gold_variable.size(1)
mask = mask_variable.cpu().data.numpy()
pred_tag = pred_variable.cpu().data.numpy()
gold_tag = gold_variable.cpu().data.numpy()
batch_size = mask.shape[0]
pred_label = []
gold_label = []
for idx in range(batch_size):
labels = list(self._alphabet["label"].keys())
pred = [
self._get_instance(labels, pred_tag[idx][idy])
for idy in range(seq_len)
if mask[idx][idy] != 0
]
gold = [
self._get_instance(labels, gold_tag[idx][idy])
for idy in range(seq_len)
if mask[idx][idy] != 0
]
pred_label.append(pred)
gold_label.append(gold)
return pred_label, gold_label
@staticmethod
def _get_entities(raw_texts, predicted_results):
entities = {}
for idx in range(len(predicted_results)):
sent_length = len(predicted_results[idx])
key = ""
entity_txt = ""
for idy in range(sent_length):
if "S-" in predicted_results[idx][idy]:
key = predicted_results[idx][idy].split("-")[1]
if key in entities.keys():
entities[key].append(raw_texts[idx][0][idy])
else:
entities[key] = [raw_texts[idx][0][idy]]
elif "B-" in predicted_results[idx][idy]:
key = predicted_results[idx][idy].split("-")[1]
entity_txt = raw_texts[idx][0][idy].strip()
elif "I-" in predicted_results[idx][idy]:
entity_txt += " " + raw_texts[idx][0][idy].strip()
elif "E-" in predicted_results[idx][idy]:
entity_txt += " " + raw_texts[idx][0][idy].strip()
if key in entities.keys():
entities[key].append(entity_txt.strip())
else:
entities[key] = [entity_txt.strip()]
key = ""
entity_txt = ""
return entities
@staticmethod
def _get_instance(instances, index):
if index == 0:
return instances[0]
try:
return instances[index - 1]
except IndexError:
print("WARNING:Alphabet get_instance, unknown instance, return the first label.")
return instances[0]
@staticmethod
def _prepare_annotations(entities: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
annotations = []
for concept in entities:
annotations.append({"concept": concept, "entities": entities[concept]})
return annotations