Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated data.py and data_cls.py to work with xlsx data files #314

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions fast_bert/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,12 @@ def get_test_examples(self, filename='test.txt', size=-1):
def get_labels(self, filename='labels.csv'):
"""See base class."""
if self.labels == None:
self.labels = list(pd.read_csv(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
if ".xlsx" in filename:
self.labels = list(pd.read_excel(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
else:
self.labels = list(pd.read_csv(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
return self.labels

def _create_examples(self, lines, set_type):
Expand All @@ -200,7 +204,7 @@ def read_col_file(self, filename):
'''
read file
return format :
[ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'],
[ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'], ['call', 'O'], ['to', 'O'], ['boycott', 'O'],
['British', 'B-MISC'], ['lamb', 'O'], ['.', 'O'] ]
'''
f = open(filename)
Expand Down Expand Up @@ -235,25 +239,40 @@ def __init__(self, data_dir, label_dir):
def get_train_examples(self, filename='train.csv', text_col='text', label_col='label', size=-1):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))

return self._create_examples(data_df, "train", text_col=text_col, label_col=label_col)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
return self._create_examples(data_df.sample(size), "train", text_col=text_col, label_col=label_col)

def get_dev_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(data_df, "dev", text_col=text_col, label_col=label_col)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(data_df.sample(size), "dev", text_col=text_col, label_col=label_col)

def get_test_examples(self, filename='val.csv', text_col='text', label_col='label', size=-1):
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
if size == -1:
return self._create_examples(data_df, "test", text_col=text_col, label_col=None)
Expand All @@ -263,8 +282,12 @@ def get_test_examples(self, filename='val.csv', text_col='text', label_col='labe
def get_labels(self, filename='labels.csv'):
"""See base class."""
if self.labels == None:
self.labels = list(pd.read_csv(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
if ".xlsx" in filename:
self.labels = list(pd.read_excel(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
else:
self.labels = list(pd.read_csv(os.path.join(
self.label_dir, filename), header=None)[0].astype('str').values)
return self.labels

def _create_examples(self, df, set_type, text_col, label_col):
Expand Down Expand Up @@ -350,13 +373,13 @@ def load(data_dir, backend='nccl', filename="databunch.pkl"):
def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_file='val.csv', test_data=None,
label_file='labels.csv', text_col='text', label_col='label', bs=32, maxlen=512,
multi_gpu=True, multi_label=False, backend="nccl", model_type='bert', custom_sampler=None):

if isinstance(tokenizer, str):
_,_,tokenizer_class = MODEL_CLASSES[model_type]
# instantiate the new tokeniser object using the tokeniser name
tokenizer = tokenizer_class.from_pretrained(tokenizer, do_lower_case=('uncased' in tokenizer))

self.tokenizer = tokenizer
self.tokenizer = tokenizer
self.data_dir = data_dir
self.maxlen = maxlen
self.bs = bs
Expand Down Expand Up @@ -451,7 +474,7 @@ def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_f
torch.distributed.init_process_group(backend=backend,
init_method="tcp://localhost:23459",
rank=0, world_size=1)

except:
pass

Expand Down
47 changes: 36 additions & 11 deletions fast_bert/data_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,19 @@ def get_train_examples(
):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))

return self._create_examples(
data_df, "train", text_col=text_col, label_col=label_col
)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
return self._create_examples(
data_df.sample(size), "train", text_col=text_col, label_col=label_col
Expand All @@ -245,20 +251,29 @@ def get_dev_examples(
):

if size == -1:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(
data_df, "dev", text_col=text_col, label_col=label_col
)
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
return self._create_examples(
data_df.sample(size), "dev", text_col=text_col, label_col=label_col
)

def get_test_examples(
self, filename="val.csv", text_col="text", label_col="label", size=-1
):
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
if ".xlsx" in filename:
data_df = pd.read_excel(os.path.join(self.data_dir, filename))
else:
data_df = pd.read_csv(os.path.join(self.data_dir, filename))
# data_df['comment_text'] = data_df['comment_text'].apply(cleanHtml)
if size == -1:
return self._create_examples(
Expand All @@ -272,11 +287,18 @@ def get_test_examples(
def get_labels(self, filename="labels.csv"):
"""See base class."""
if self.labels is None:
self.labels = list(
pd.read_csv(os.path.join(self.label_dir, filename), header=None)[0]
.astype("str")
.values
)
if ".xlsx" in filename:
self.labels = list(
pd.read_excel(os.path.join(self.label_dir, filename), header=None)[0]
.astype("str")
.values
)
else:
self.labels = list(
pd.read_csv(os.path.join(self.label_dir, filename), header=None)[0]
.astype("str")
.values
)
return self.labels

def _create_examples(self, df, set_type, text_col, label_col):
Expand Down Expand Up @@ -345,7 +367,10 @@ def __init__(self, data_dir, filename, text_col, label_col):
self.text_col = text_col
self.label_col = label_col

self.data = pd.read_csv(os.path.join(data_dir, filename))
if ".xlsx" in filename:
self.data = pd.read_excel(os.path.join(data_dir, filename))
else:
self.data = pd.read_csv(os.path.join(data_dir, filename))

def __getitem__(self, idx):
return self.data.loc[idx, self.text_col], self.data.loc[idx, self.label_col]
Expand Down