diff --git a/fast_bert/data.py b/fast_bert/data.py index 4734a8a..054f2f5 100644 --- a/fast_bert/data.py +++ b/fast_bert/data.py @@ -181,8 +181,12 @@ def get_test_examples(self, filename='test.txt', size=-1): def get_labels(self, filename='labels.csv'): """See base class.""" if self.labels == None: - self.labels = list(pd.read_csv(os.path.join( - self.label_dir, filename), header=None)[0].astype('str').values) + if ".xlsx" in filename: + self.labels = list(pd.read_excel(os.path.join( + self.label_dir, filename), header=None)[0].astype('str').values) + else: + self.labels = list(pd.read_csv(os.path.join( + self.label_dir, filename), header=None)[0].astype('str').values) return self.labels def _create_examples(self, lines, set_type): @@ -200,7 +204,7 @@ def read_col_file(self, filename): ''' read file return format : - [ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'], + [ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'], ['British', 'B-MISC'], ['lamb', 'O'], ['.', 'O'] ] ''' f = open(filename) @@ -235,25 +239,40 @@ def __init__(self, data_dir, label_dir): def get_train_examples(self, filename='train.csv', text_col='text', label_col='label', size=-1): if size == -1: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples(data_df, "train", text_col=text_col, label_col=label_col) else: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) # data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml) return self._create_examples(data_df.sample(size), "train", text_col=text_col, label_col=label_col) def get_dev_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1): if size == -1: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples(data_df, "dev", text_col=text_col, label_col=label_col) else: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples(data_df.sample(size), "dev", text_col=text_col, label_col=label_col) def get_test_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1): - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) # data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml) if size == -1: return self._create_examples(data_df, "test", text_col=text_col, label_col=None) @@ -263,8 +282,12 @@ def get_test_examples(self, filename='val.csv', text_col='text', label_col='labe def get_labels(self, filename='labels.csv'): """See base class.""" if self.labels == None: - self.labels = list(pd.read_csv(os.path.join( - self.label_dir, filename), header=None)[0].astype('str').values) + if ".xlsx" in filename: + self.labels = list(pd.read_excel(os.path.join( + self.label_dir, filename), header=None)[0].astype('str').values) + else: + self.labels = list(pd.read_csv(os.path.join( + self.label_dir, filename), header=None)[0].astype('str').values) return self.labels def _create_examples(self, df, set_type, text_col, label_col): @@ -350,13 +373,13 @@ def load(data_dir, backend='nccl', filename="databunch.pkl"): def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_file='val.csv', test_data=None, label_file='labels.csv', text_col='text', label_col='label', bs=32, maxlen=512, multi_gpu=True, multi_label=False, backend="nccl", model_type='bert', custom_sampler=None): - + if isinstance(tokenizer, str): _,_,tokenizer_class = MODEL_CLASSES[model_type] # instantiate the new tokeniser object using the tokeniser name tokenizer = tokenizer_class.from_pretrained(tokenizer, do_lower_case=('uncased' in tokenizer)) - self.tokenizer = tokenizer + self.tokenizer = tokenizer self.data_dir = data_dir self.maxlen = maxlen self.bs = bs @@ -451,7 +474,7 @@ def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_f torch.distributed.init_process_group(backend=backend, init_method="tcp://localhost:23459", rank=0, world_size=1) - + except: pass diff --git a/fast_bert/data_cls.py b/fast_bert/data_cls.py index 6fb26b8..7cef2a5 100644 --- a/fast_bert/data_cls.py +++ b/fast_bert/data_cls.py @@ -228,13 +228,19 @@ def get_train_examples( ): if size == -1: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples( data_df, "train", text_col=text_col, label_col=label_col ) else: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) # data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml) return self._create_examples( data_df.sample(size), "train", text_col=text_col, label_col=label_col @@ -245,12 +251,18 @@ def get_dev_examples( ): if size == -1: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples( data_df, "dev", text_col=text_col, label_col=label_col ) else: - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) return self._create_examples( data_df.sample(size), "dev", text_col=text_col, label_col=label_col ) @@ -258,7 +270,10 @@ def get_dev_examples( def get_test_examples( self, filename="val.csv", text_col="text", label_col="label", size=-1 ): - data_df = pd.read_csv(os.path.join(self.data_dir, filename)) + if ".xlsx" in filename: + data_df = pd.read_excel(os.path.join(self.data_dir, filename)) + else: + data_df = pd.read_csv(os.path.join(self.data_dir, filename)) # data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml) if size == -1: return self._create_examples( @@ -272,11 +287,18 @@ def get_test_examples( def get_labels(self, filename="labels.csv"): """See base class.""" if self.labels is None: - self.labels = list( - pd.read_csv(os.path.join(self.label_dir, filename), header=None)[0] - .astype("str") - .values - ) + if ".xlsx" in filename: + self.labels = list( + pd.read_excel(os.path.join(self.label_dir, filename), header=None)[0] + .astype("str") + .values + ) + else: + self.labels = list( + pd.read_csv(os.path.join(self.label_dir, filename), header=None)[0] + .astype("str") + .values + ) return self.labels def _create_examples(self, df, set_type, text_col, label_col): @@ -345,7 +367,10 @@ def __init__(self, data_dir, filename, text_col, label_col): self.text_col = text_col self.label_col = label_col - self.data = pd.read_csv(os.path.join(data_dir, filename)) + if ".xlsx" in filename: + self.data = pd.read_excel(os.path.join(data_dir, filename)) + else: + self.data = pd.read_csv(os.path.join(data_dir, filename)) def __getitem__(self, idx): return self.data.loc[idx, self.text_col], self.data.loc[idx, self.label_col]