# Source code for orkgnlp.deepresearch.textsplitter

# -*- coding: utf-8 -*-
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Iterable, List, Optional


class TextSplitter(ABC):
    """Base class for splitters that cut text into size-bounded, overlapping chunks.

    All sizes are measured in characters (``len``).

    Args:
        chunk_size: Maximum length of a chunk produced by :meth:`merge_splits`,
            provided every input piece is itself no longer than ``chunk_size``.
        chunk_overlap: Maximum number of characters carried over from the end
            of one chunk into the start of the next.

    Raises:
        ValueError: If ``chunk_overlap`` is negative or not smaller than
            ``chunk_size``.
    """

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        if self.chunk_overlap < 0:
            # A negative overlap would make the trimming loop in merge_splits
            # try to drop pieces from an already empty chunk.
            raise ValueError("chunk_overlap must be non-negative")
        if self.chunk_overlap >= self.chunk_size:
            raise ValueError("Cannot have chunk_overlap >= chunk_size")

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split a single piece of text into chunks"""

    def create_documents(self, texts: List[str]) -> List[str]:
        """Split every text in *texts* and return all chunks, in order, as one flat list."""
        return [chunk for text in texts for chunk in self.split_text(text)]

    def split_documents(self, documents: List[str]) -> List[str]:
        """Alias of :meth:`create_documents`."""
        return self.create_documents(documents)

    def _join_docs(self, docs: Iterable[str], separator: str) -> Optional[str]:
        """Join *docs* with *separator* and strip surrounding whitespace.

        Returns ``None`` when nothing but whitespace remains, so callers can
        skip empty chunks.
        """
        text = separator.join(docs).strip()
        return text if text else None

    def merge_splits(self, splits: List[str], separator: str) -> List[str]:
        """Greedily pack *splits* into chunks joined by *separator*.

        Pieces are added to the current chunk until appending the next one
        (plus the separator in front of it) would make the joined chunk longer
        than ``chunk_size``. The chunk is then emitted, and pieces are dropped
        from its front until at most ``chunk_overlap`` characters remain and
        the next piece fits. This makes consecutive chunks overlap. Chunks that
        are empty after stripping are skipped.

        Args:
            splits: Pieces to pack, in order.
            separator: String placed between pieces when joining.

        Returns:
            The packed chunks. A chunk can only exceed ``chunk_size`` if a
            single input piece already does (a warning is issued then).
        """
        separator_len = len(separator)
        docs: List[str] = []
        # deque: trimming from the front is O(1) instead of list.pop(0)'s O(n).
        current_doc: "deque[str]" = deque()
        # Invariant: total == len(separator.join(current_doc)), i.e. separators
        # are counted, so the emitted chunk really respects chunk_size.
        total = 0
        for piece in splits:
            piece_len = len(piece)
            joiner_len = separator_len if current_doc else 0
            if total + joiner_len + piece_len > self.chunk_size:
                if total > self.chunk_size:
                    warnings.warn(
                        f"Created a chunk of size {total}, which is longer "
                        f"than {self.chunk_size}",
                        stacklevel=2,
                    )
                if current_doc:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Drop leading pieces while the remainder is larger than
                    # the overlap budget, or still leaves no room for `piece`.
                    while current_doc and (
                        total > self.chunk_overlap
                        or (
                            total > 0
                            and total
                            + (separator_len if current_doc else 0)
                            + piece_len
                            > self.chunk_size
                        )
                    ):
                        removed = current_doc.popleft()
                        # The separator that followed `removed` disappears too,
                        # unless it was the last piece in the chunk.
                        total -= len(removed) + (separator_len if current_doc else 0)
            total += piece_len + (separator_len if current_doc else 0)
            current_doc.append(piece)
        final_doc = self._join_docs(current_doc, separator)
        if final_doc is not None:
            docs.append(final_doc)
        return docs
class RecursiveCharacterTextSplitter(TextSplitter):
    r"""Split text on the coarsest separator it contains, recursing on oversized pieces.

    ``separators`` is tried in order, and the first one that occurs in the text
    is used to cut it. The empty separator ``""`` cuts into single characters.
    Pieces shorter than ``chunk_size`` are packed back together with
    :meth:`merge_splits`. Longer pieces are split again using only the
    separators that come after the one just used. When no such separator is
    left, the oversized piece is emitted unchanged.

    Args:
        chunk_size: Target maximum chunk length in characters.
        chunk_overlap: Maximum characters shared by consecutive chunks.
        separators: Separators in priority order. When ``None`` or empty,
            ``["\n\n", "\n", ".", ",", ">", "<", " ", ""]`` is used.
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        separators: Optional[List[str]] = None,
    ):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self.separators = separators or ["\n\n", "\n", ".", ",", ">", "<", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split *text* into chunks using the configured separators."""
        return self._split_text(text, self.separators)

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Recursive worker for :meth:`split_text`, restricted to *separators*."""
        # Use the first separator that occurs in the text ("" always matches).
        # Only the separators after it can help when recursing. The earlier
        # ones do not occur in `text`. The chosen one cannot occur inside a
        # piece produced by str.split.
        separator = separators[-1]
        finer: List[str] = []
        for index, candidate in enumerate(separators):
            if candidate == "":
                separator = candidate
                break
            if candidate in text:
                separator = candidate
                finer = separators[index + 1:]
                break

        splits = text.split(separator) if separator else list(text)

        final_chunks: List[str] = []
        good_splits: List[str] = []
        for part in splits:
            if len(part) < self.chunk_size:
                good_splits.append(part)
                continue
            # Flush the small pieces gathered so far so chunk order is kept.
            if good_splits:
                final_chunks.extend(self.merge_splits(good_splits, separator))
                good_splits = []
            if finer:
                final_chunks.extend(self._split_text(part, finer))
            else:
                # Nothing finer to split on. Examples: single characters with
                # chunk_size <= 1, or a custom separator list without "".
                # Re-splitting with the same separators would recurse forever,
                # so keep the oversized piece as one chunk.
                final_chunks.append(part)
        if good_splits:
            final_chunks.extend(self.merge_splits(good_splits, separator))
        return final_chunks