diff --git a/octis/preprocessing/preprocessing.py b/octis/preprocessing/preprocessing.py
index 64ede3ee..1c7c6caf 100644
--- a/octis/preprocessing/preprocessing.py
+++ b/octis/preprocessing/preprocessing.py
@@ -1,6 +1,6 @@
+import re
 import string
 from typing import List, Union
-
 import spacy
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
@@ -9,7 +9,6 @@
 from pathlib import Path
 from octis.dataset.dataset import Dataset
 from collections import Counter
-
 """
 Maps the language to its corresponding spacy model
 """
@@ -160,7 +159,7 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
             # with Pool(self.num_processes) as p:
             #     docs = p.map(self.simple_preprocessing_steps, docs)
             chunksize = max(1, len(docs) // (self.num_processes * 20))
-            docs_list = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=chunksize)
+            docs = process_map(self.simple_preprocessing_steps, docs, max_workers=self.num_processes, chunksize=chunksize)
         else:
             docs = list(map(self.simple_preprocessing_steps, tqdm(docs)))
         if self.lowercase:
@@ -174,6 +173,12 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
         print("created vocab")
         print(len(vocabulary))
         final_docs, final_labels, document_indexes = [], [], []
+
+        def valid_word_or_punc(word):
+            valid_word = len([rw for rw in re.findall(r"(?u)\b[\w\-]{" + str(self.min_chars) + r",}\b", word) if rw in vocab]) > 0
+            all_punc = len(word) == len(re.findall(r'[^\w]', word))
+            return valid_word or all_punc
+
         if labels_path is not None:
             if multilabel:
                 labels = [
@@ -186,7 +191,8 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
 
             vocab = set(vocabulary)
             for i, doc, label in zip(range(len(docs)), docs, labels):
-                new_doc = [w for w in doc.split() if w in vocab]
+                new_doc = [w for w in doc.split() if valid_word_or_punc(w)]
+
                 if len(new_doc) > self.min_doc_words:
                     final_docs.append(new_doc)
                     final_labels.append(label)
@@ -206,7 +212,7 @@ def preprocess_dataset(self, documents_path, labels_path=None, multilabel=False)
         else:
             vocab = set(vocabulary)
             for i, doc in enumerate(docs):
-                new_doc = [w for w in doc.split() if w in vocab]
+                new_doc = [w for w in doc.split() if valid_word_or_punc(w)]
                 if len(new_doc) > self.min_doc_words:
                     final_docs.append(new_doc)
                     document_indexes.append(i)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e451f8f6..569cfab0 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -54,7 +54,7 @@ def test_preprocessing_english_stops_split(data_dir):
 def test_preprocessing_multiprocess(data_dir):
     texts_path = data_dir+"/sample_texts/unprepr_docs.txt"
     p = Preprocessing(vocabulary=None, max_features=None, remove_punctuation=True,
-                      lemmatize=False, num_processes=10, split=False,
+                      lemmatize=False, num_processes=10, split=False, min_chars=2,
                       min_words_docs=1)
     dataset = p.preprocess_dataset(
         documents_path=texts_path,
@@ -64,6 +64,31 @@
     dataset.load_custom_dataset_from_folder(data_dir + "/sample_texts")
 
 
+def test_preprocessing_minimal(data_dir):
+    """
+    Check that preprocessing does not remove tokens that the user has not
+    explicitly asked to be removed.
+ """ + texts_path = data_dir+"/sample_texts/unprepr_docs.txt" + p = Preprocessing(vocabulary=None, max_features=None, remove_punctuation=False, + remove_numbers = False, + lemmatize=False, split=False, + min_chars=1, min_words_docs=0) + + unprocessed = [d.strip() for d in open(texts_path, "r").readlines() if len(d.strip()) > 0] + raw_word_lens = [len(d.split()) for d in unprocessed] + + dataset = p.preprocess_dataset( + documents_path=texts_path, + ) + print(dataset.get_corpus()) + preprocessed_word_lens = [len(d) for d in dataset.get_corpus()] + print(list(zip(raw_word_lens,preprocessed_word_lens))) + assert len(raw_word_lens) == len(preprocessed_word_lens) + for i in range(len(preprocessed_word_lens)): + assert raw_word_lens[i] == preprocessed_word_lens[i] + + def test_load_20ng(): data_home = get_data_home(data_home=None) cache_path = _pkl_filepath(data_home, "20NewsGroup" + ".pkz")