# Patch summary (preamble text before the first "diff --git" header):
#
# 1. deeppavlov/dataset_readers/basic_classification_reader.py, in `read`:
#    'dtype' is added to the tuples of keyword names that are copied from
#    **kwargs into `options`, so a caller-supplied `dtype` is now passed on
#    to pd.read_csv (format == 'csv') and pd.read_json (format == 'json').
#    When 'dtype' is absent from kwargs nothing extra is passed.
#
# 2. deeppavlov/dataset_readers/docred_reader.py, in `print_statistics`:
#    the columns of the transposed statistics frame get the fixed labels
#    'rel_id', 'train', 'valid', 'test' instead of the generated labels
#    'd1', 'd2', ... produced by enumerating the frame's columns.
#
# NOTE(review): this patch was received with its line breaks collapsed.
# The line structure below is restored from the diff markers and hunk
# line counts; the leading indentation of context/added/removed lines is
# reconstructed from the Python block structure -- confirm against the
# target files before applying.
diff --git a/deeppavlov/dataset_readers/basic_classification_reader.py b/deeppavlov/dataset_readers/basic_classification_reader.py
index 8ef767b368..4ccd3eebdc 100644
--- a/deeppavlov/dataset_readers/basic_classification_reader.py
+++ b/deeppavlov/dataset_readers/basic_classification_reader.py
@@ -78,11 +78,11 @@ def read(self, data_path: str, url: str = None,
             file = Path(data_path).joinpath(file_name)
             if file.exists():
                 if format == 'csv':
-                    keys = ('sep', 'header', 'names')
+                    keys = ('sep', 'header', 'names', 'dtype')
                     options = {k: kwargs[k] for k in keys if k in kwargs}
                     df = pd.read_csv(file, **options)
                 elif format == 'json':
-                    keys = ('orient', 'lines')
+                    keys = ('orient', 'lines', 'dtype')
                     options = {k: kwargs[k] for k in keys if k in kwargs}
                     df = pd.read_json(file, **options)
                 else:
diff --git a/deeppavlov/dataset_readers/docred_reader.py b/deeppavlov/dataset_readers/docred_reader.py
index 479854d041..d73efac937 100644
--- a/deeppavlov/dataset_readers/docred_reader.py
+++ b/deeppavlov/dataset_readers/docred_reader.py
@@ -409,6 +409,6 @@ def label_to_one_hot(self, labels: List[int]) -> List:
     def print_statistics(self, train_stat: Dict, valid_stat: Dict, test_stat: Dict) -> None:
         """ Print out the relation statistics as a markdown table """
         df = pd.DataFrame([self.rel2relid, train_stat, valid_stat, test_stat]).T
-        df.columns = ['d{}'.format(i) for i, col in enumerate(df, 1)]
+        df.columns = ['rel_id', 'train', 'valid', 'test']
         logger.info("\n")
         logger.info(df)