diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c73b552..30742a8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,10 @@ jobs: if: matrix.python-version != '3.6' && matrix.python-version != '3.7' run: pip install -r requirements-dev.txt + - name: Install training dependencies + if: matrix.python-version != '3.6' && matrix.python-version != '3.7' + run: pip install -r training/requirements.txt + - name: Install dependencies (legacy versions) if: matrix.python-version == '3.6' || matrix.python-version == '3.7' run: | diff --git a/.gitignore b/.gitignore index ab5d3e5..f634b25 100644 --- a/.gitignore +++ b/.gitignore @@ -116,3 +116,4 @@ Makefile # eval UD/ +training/**/data \ No newline at end of file diff --git a/eval/README.rst b/eval/README.rst deleted file mode 100644 index b341409..0000000 --- a/eval/README.rst +++ /dev/null @@ -1,13 +0,0 @@ -Instructions to run the evaluation ----------------------------------- - -The scores are calculated on `Universal Dependencies `_ treebanks on single word tokens (including some contractions but not merged prepositions). They can be reproduced by concatenating all available UD files and by using the script ``udscore.py``. - -1. Download the main archive containing the `Universal Dependencies `_ treebanks -2. Extract relevant data (language and if applicable specific treebank, see notes in the results table) -3. Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de-gsd-all.conllu``) -4. Store the files at the expected location (``tests/UD/``) and/or edit the corresponding paths in ``udscore.py`` -5. Install the evaluation dependencies (``pip install -r eval-requirements.txt``) -6. Run the script, e.g. from the home directory ``python3 eval/udscore.py`` -7. 
Results are displayed in the terminal and errors are written in a CSV file - diff --git a/eval/eval-requirements.txt b/eval/eval-requirements.txt deleted file mode 100644 index 634aded..0000000 --- a/eval/eval-requirements.txt +++ /dev/null @@ -1 +0,0 @@ -conllu>=4.5.2 diff --git a/eval/udscore.py b/eval/udscore.py deleted file mode 100644 index 0238a12..0000000 --- a/eval/udscore.py +++ /dev/null @@ -1,134 +0,0 @@ -import csv -import time - -from collections import Counter -from os import path - -from conllu import parse_incr # type: ignore -from simplemma import Lemmatizer -from simplemma.strategies.dictionaries import DefaultDictionaryFactory -from simplemma.strategies.default import DefaultStrategy - - -data_files = [ - ("bg", "tests/UD/bg-btb-all.conllu"), - ("cs", "tests/UD/cs-pdt-all.conllu"), # longer to process - ("da", "tests/UD/da-ddt-all.conllu"), - ("de", "tests/UD/de-gsd-all.conllu"), - ("el", "tests/UD/el-gdt-all.conllu"), - ("en", "tests/UD/en-gum-all.conllu"), - ("es", "tests/UD/es-gsd-all.conllu"), - ("et", "tests/UD/et-edt-all.conllu"), - ("fi", "tests/UD/fi-tdt-all.conllu"), - ("fr", "tests/UD/fr-gsd-all.conllu"), - ("ga", "tests/UD/ga-idt-all.conllu"), - ("hi", "tests/UD/hi-hdtb-all.conllu"), - ("hu", "tests/UD/hu-szeged-all.conllu"), - ("hy", "tests/UD/hy-armtdp-all.conllu"), - ("id", "tests/UD/id-csui-all.conllu"), - ("it", "tests/UD/it-isdt-all.conllu"), - ("la", "tests/UD/la-proiel-all.conllu"), - ("lt", "tests/UD/lt-alksnis-all.conllu"), - ("lv", "tests/UD/lv-lvtb-all.conllu"), - ("nb", "tests/UD/nb-bokmaal-all.conllu"), - ("nl", "tests/UD/nl-alpino-all.conllu"), - ("pl", "tests/UD/pl-pdb-all.conllu"), - ("pt", "tests/UD/pt-gsd-all.conllu"), - ("ru", "tests/UD/ru-gsd-all.conllu"), - ("sk", "tests/UD/sk-snk-all.conllu"), - ("tr", "tests/UD/tr-boun-all.conllu"), -] - -# doesn't work: right-to-left? 
-# data_files = [ -# ('he', 'tests/UD/he-htb-all.conllu'), -# ('ur', 'tests/UD/ur-udtb-all.conllu'), -# ] - -# data_files = [ -# ('de', 'tests/UD/de-gsd-all.conllu'), -#] - - -for filedata in data_files: - total, focus_total, _greedy, nongreedy, zero, focus_zero, focus, focus_nongreedy = ( - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ) - errors, flag = [], False - language, filename = filedata[0], filedata[1] - - start = time.time() - _dictionary_factory = DefaultDictionaryFactory() - strategies = DefaultStrategy(greedy=False) - lemmatizer = Lemmatizer( - lemmatization_strategy=DefaultStrategy( - greedy=False, dictionary_factory=_dictionary_factory - ), - ) - greedy_lemmatizer = Lemmatizer( - lemmatization_strategy=DefaultStrategy( - greedy=True, dictionary_factory=_dictionary_factory - ), - ) - print("==", filedata, "==") - with open(filename, "r", encoding="utf-8") as data_file: - for tokenlist in parse_incr(data_file): - for token in tokenlist: - error_flag = False - if token["lemma"] == "_": # or token['upos'] in ('PUNCT', 'SYM') - # flag = True - continue - - initial = bool(token["id"] == 1) - token_form = token["form"].lower() if initial else token["form"] - - candidate = lemmatizer.lemmatize(token_form, lang=language) - greedy_candidate = greedy_lemmatizer.lemmatize(token_form, lang=language) - - if token["upos"] in ("ADJ", "NOUN"): - focus_total += 1 - if token["form"] == token["lemma"]: - focus_zero += 1 - if greedy_candidate == token["lemma"]: - focus += 1 - if candidate == token["lemma"]: - focus_nongreedy += 1 - total += 1 - if token["form"] == token["lemma"]: - zero += 1 - if greedy_candidate == token["lemma"]: - _greedy += 1 - else: - error_flag = True - if candidate == token["lemma"]: - nongreedy += 1 - else: - error_flag = True - if error_flag: - errors.append( - (token["form"], token["lemma"], candidate, greedy_candidate) - ) - with open( - f'{path.basename(filename).replace("conllu","csv")}', "w", encoding="utf-8" - ) as csvfile: - writer = 
csv.writer(csvfile) - writer.writerow(("form", "lemma", "candidate", "greedy_candidate")) - writer.writerows(errors) - - print("exec time:\t %.3f" % (time.time() - start)) - print("token count:\t", total) - print("greedy:\t\t %.3f" % (_greedy / total)) - print("non-greedy:\t %.3f" % (nongreedy / total)) - print("baseline:\t %.3f" % (zero / total)) - print("ADJ+NOUN greedy:\t\t %.3f" % (focus / focus_total)) - print("ADJ+NOUN non-greedy:\t\t %.3f" % (focus_nongreedy / focus_total)) - print("ADJ+NOUN baseline:\t\t %.3f" % (focus_zero / focus_total)) - mycounter = Counter(errors) - print(mycounter.most_common(20)) diff --git a/simplemma/language_detector.py b/simplemma/language_detector.py index f6fab14..52369a6 100644 --- a/simplemma/language_detector.py +++ b/simplemma/language_detector.py @@ -198,18 +198,14 @@ def proportion_in_target_languages( Returns: float: The proportion of text in the target language(s). """ - tokens = self._token_sampler.sample_text(text) - if len(tokens) == 0: - return 0 - - in_target = 0 - for token in tokens: - for lang_code in self._lang: - candidate = self._lemmatization_strategy.get_lemma(token, lang_code) - if candidate is not None: - in_target += 1 - break - return in_target / len(tokens) + return sum( + percentage + for ( + lang_code, + percentage, + ) in self.proportion_in_each_language(text).items() + if lang_code != "unk" + ) def main_language( self, diff --git a/tests/test_language_detector.py b/tests/test_language_detector.py index ad68f0c..4affdd1 100644 --- a/tests/test_language_detector.py +++ b/tests/test_language_detector.py @@ -108,15 +108,6 @@ def test_in_target_language() -> None: == 1.0 ) - langs = ("en", "de") - text = "It was a true gift" - assert ( - LanguageDetector(lang=langs).proportion_in_target_languages(text) - == in_target_language(text, lang=langs) - == 1.0 - ) - in_target_language("It was a true gift", lang=("en", "de")) - def test_main_language(): text = "Dieser Satz ist auf Deutsch." 
diff --git a/training/README.rst b/training/README.rst new file mode 100644 index 0000000..ba4c630 --- /dev/null +++ b/training/README.rst @@ -0,0 +1,15 @@ +Instructions to run the evaluation +---------------------------------- + +The scores are calculated on `Universal Dependencies <https://universaldependencies.org/>`_ treebanks on single word tokens (including some contractions but not merged prepositions). They can be reproduced by the following steps: + +1. Install the evaluation dependencies (``pip install -r training/requirements.txt``) +2. Update ``DATA_URL`` in ``training/download-eval-data.py`` to point to the latest treebanks archive from `Universal Dependencies <https://universaldependencies.org/>`_ (or the version that you wish to use). +3. Run ``python3 training/download-eval-data.py`` which will + 1. Download the archive + 2. Extract relevant data (language and if applicable specific treebank, see notes in the results table) + 3. Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de_gsd.conllu``) + 4. Store the files at the expected location (``training/data/UD/``) +4. Run the script, e.g. from the home directory ``python3 training/evaluate_simplema.py`` +5. Results are stored at ``training/data/results/results_summary.csv``. Also, errors are written in a CSV file for each dataset under the ``training/data/results`` folder. 
+
diff --git a/training/download-eval-data.py b/training/download-eval-data.py
new file mode 100644
index 0000000..5451392
--- /dev/null
+++ b/training/download-eval-data.py
@@ -0,0 +1,103 @@
+"""Download the Universal Dependencies treebanks and build the evaluation data.
+
+The archive at ``DATA_URL`` is saved and extracted under ``DATA_FOLDER``. For
+every treebank whose language is in ``SUPPORTED_LANGUAGES``, all of its CoNLL-U
+files are then concatenated into a single file under ``CLEAN_DATA_FOLDER``.
+"""
+
+from typing import Iterable, List, Tuple
+from os import mkdir, path, scandir
+import re
+import logging
+import tarfile
+import requests
+from glob import glob
+
+from simplemma.strategies.dictionaries.dictionary_factory import SUPPORTED_LANGUAGES
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+DATA_URL = "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-5150/ud-treebanks-v2.12.tgz?sequence=1&isAllowed=y"
+DATA_FOLDER = path.join(path.dirname(__file__), "data")
+DATA_FILE = path.join(DATA_FOLDER, "ud-treeebanks.tgz")
+CLEAN_DATA_FOLDER = path.join(DATA_FOLDER, "UD")
+
+
+def get_dirs(file_name: str) -> List[str]:
+    """Return the names of the directories found directly in ``file_name``."""
+    return [entry.name for entry in scandir(file_name) if entry.is_dir()]
+
+
+def get_files(file_name: str) -> List[str]:
+    """Return the names of the files found directly in ``file_name``."""
+    return [entry.name for entry in scandir(file_name) if entry.is_file()]
+
+
+def get_relevant_language_data_folders(data_folder) -> Iterable[Tuple[str, str, str]]:
+    """Yield the treebank folders of ``data_folder`` that are in a supported language.
+
+    The dataset name is the part of a CoNLL-U file name that precedes ``-ud``
+    and the language code is the part of that name before the first ``_``.
+    Folders without CoNLL-U files or with unexpected file names are skipped.
+
+    Args:
+        data_folder: Folder holding one sub-folder per treebank.
+
+    Yields:
+        Tuples ``(lang, dataset_name, lang_data_folder)``.
+    """
+    for lang_folder in get_dirs(data_folder):
+        lang_data_folder = path.join(data_folder, lang_folder)
+        conllu_files = glob(path.join(lang_data_folder, "*.conllu"))
+        if not conllu_files:
+            continue
+        # Match on the base name so that the pattern does not depend on the
+        # path separator of the operating system.
+        matches_files = re.search(r"^(.*)-ud.*$", path.basename(conllu_files[0]))
+        if matches_files is not None:
+            dataset_name = matches_files.groups()[0]
+            lang = dataset_name.split("_")[0]
+
+            if lang in SUPPORTED_LANGUAGES:
+                yield (lang, dataset_name, lang_data_folder)
+
+
+if path.exists(DATA_FOLDER) or path.exists(CLEAN_DATA_FOLDER):
+    raise Exception(
+        "Data folder seems to be already present. Delete it before creating new data."
+    )
+
+mkdir(DATA_FOLDER)
+mkdir(CLEAN_DATA_FOLDER)
+
+log.info("Downloading evaluation data...")
+response = requests.get(DATA_URL)
+# Stop on an HTTP error instead of saving the error page as the archive.
+response.raise_for_status()
+with open(DATA_FILE, "wb") as archive_file:
+    archive_file.write(response.content)
+
+log.info("Uncompressing evaluation data...")
+with tarfile.open(DATA_FILE) as tar:
+    # NOTE(review): extractall() trusts the member paths of the archive; only
+    # point DATA_URL at a source you trust.
+    tar.extractall(DATA_FOLDER)
+# glob() already returns paths that start with DATA_FOLDER, so the result must
+# not be joined to DATA_FOLDER a second time.
+uncompressed_data_folder = glob(path.join(DATA_FOLDER, "ud-treebanks-*"))[0]
+
+log.info("Filtering files...")
+for lang, dataset_name, dataset__folder in get_relevant_language_data_folders(
+    uncompressed_data_folder
+):
+    log.info(lang + " - " + dataset__folder)
+    # Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de_gsd.conllu``)
+    lang_clean_data_file = path.join(CLEAN_DATA_FOLDER, f"{dataset_name}.conllu")
+    log.debug(f"Procressing data for {dataset_name}")
+    # The evaluation script reads these files as UTF-8, so do not rely on the
+    # platform default encoding here.
+    with open(lang_clean_data_file, "w", encoding="utf-8") as outfile:
+        for file in sorted(glob(path.join(dataset__folder, "*.conllu"))):
+            with open(file, encoding="utf-8") as infile:
+                for line in infile:
+                    outfile.write(line)
diff --git a/training/evaluate_simplema.py b/training/evaluate_simplema.py
new file mode 100644
index 0000000..4cd0dec
--- /dev/null
+++ b/training/evaluate_simplema.py
@@ -0,0 +1,162 @@
+"""Evaluate simplemma on the data prepared by ``download-eval-data.py``.
+
+Every file in ``CLEAN_DATA_FOLDER`` is lemmatized in greedy and non-greedy mode
+and the results are compared with the reference lemmas. One summary row per
+dataset is written to ``results_summary.csv`` and the mismatches of each
+dataset are written to a CSV file of their own, all in ``RESULTS_FOLDER``.
+"""
+
+import csv
+import time
+
+from collections import Counter
+from os import mkdir, path, rmdir, scandir, unlink
+import logging
+
+from conllu import parse_incr  # type: ignore
+
+from simplemma import Lemmatizer
+from simplemma.strategies.dictionaries import DefaultDictionaryFactory
+from simplemma.strategies.default import DefaultStrategy
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+DATA_FOLDER = path.join(path.dirname(__file__), "data")
+CLEAN_DATA_FOLDER = path.join(DATA_FOLDER, "UD")
+RESULTS_FOLDER = path.join(DATA_FOLDER, "results")
+
+# This check must come before the folder is listed: scandir() on a missing
+# folder raises FileNotFoundError, so the message below would never be shown.
+if not path.exists(CLEAN_DATA_FOLDER):
+    raise Exception(
+        "It doesn't seem like data was downloaded and precessed for evaluation."
+    )
+
+# (language code, file name) pairs; the language code is the part of the file
+# name before the first "_". Sorted to process the datasets in a stable order.
+data_files = sorted(
+    (data_file.name.split("_")[0], data_file.name)
+    for data_file in scandir(CLEAN_DATA_FOLDER)
+)
+
+if path.exists(RESULTS_FOLDER):
+    # Empty the folder left by a previous run before recreating it.
+    for result_file in scandir(RESULTS_FOLDER):
+        unlink(result_file)
+    rmdir(RESULTS_FOLDER)
+mkdir(RESULTS_FOLDER)
+
+# newline="" is what the csv module asks for on the files it writes to.
+with open(
+    path.join(RESULTS_FOLDER, "results_summary.csv"),
+    "w",
+    encoding="utf-8",
+    newline="",
+) as csv_results_file:
+    csv_results_file_writer = csv.writer(csv_results_file)
+    csv_results_file_writer.writerow(
+        (
+            "dataset",
+            "exec time",
+            "token count",
+            "greedy",
+            "non-greedy",
+            "baseline",
+            "ADJ+NOUN greedy",
+            "ADJ+NOUN non-greedy",
+            "ADJ+NOUN baseline",
+        )
+    )
+
+    for language, filename in data_files:
+        total = 0  # tokens whose reference lemma is not "_"
+        focus_total = 0  # those among them tagged ADJ or NOUN
+        greedy = 0
+        nongreedy = 0
+        zero = 0  # baseline: the form already equals the lemma
+        focus = 0
+        focus_nongreedy = 0
+        focus_zero = 0
+        errors = []
+
+        start = time.time()
+        _dictionary_factory = DefaultDictionaryFactory()
+        lemmatizer = Lemmatizer(
+            lemmatization_strategy=DefaultStrategy(
+                greedy=False, dictionary_factory=_dictionary_factory
+            ),
+        )
+        greedy_lemmatizer = Lemmatizer(
+            lemmatization_strategy=DefaultStrategy(
+                greedy=True, dictionary_factory=_dictionary_factory
+            ),
+        )
+        log.info(f"Evaluating dataset: {filename}")
+        with open(
+            path.join(CLEAN_DATA_FOLDER, filename), "r", encoding="utf-8"
+        ) as data_file:
+            for tokens in parse_incr(data_file):
+                for token in tokens:
+                    error_flag = False
+                    if token["lemma"] == "_":  # or token['upos'] in ('PUNCT', 'SYM')
+                        continue
+
+                    # The token with id 1 is lowercased before lemmatization.
+                    initial = bool(token["id"] == 1)
+                    token_form = token["form"].lower() if initial else token["form"]
+
+                    candidate = lemmatizer.lemmatize(token_form, lang=language)
+                    greedy_candidate = greedy_lemmatizer.lemmatize(
+                        token_form, lang=language
+                    )
+
+                    if token["upos"] in ("ADJ", "NOUN"):
+                        focus_total += 1
+                        if token["form"] == token["lemma"]:
+                            focus_zero += 1
+                        if greedy_candidate == token["lemma"]:
+                            focus += 1
+                        if candidate == token["lemma"]:
+                            focus_nongreedy += 1
+                    total += 1
+                    if token["form"] == token["lemma"]:
+                        zero += 1
+                    if greedy_candidate == token["lemma"]:
+                        greedy += 1
+                    else:
+                        error_flag = True
+                    if candidate == token["lemma"]:
+                        nongreedy += 1
+                    else:
+                        error_flag = True
+                    if error_flag:
+                        errors.append(
+                            (token["form"], token["lemma"], candidate, greedy_candidate)
+                        )
+
+        # A dataset without any counted token gets no summary row.
+        if total > 0:
+            csv_results_file_writer.writerow(
+                (
+                    filename.replace(".conllu", ""),
+                    time.time() - start,
+                    total,
+                    greedy / total,
+                    nongreedy / total,
+                    zero / total,
+                    (focus / focus_total) if focus_total > 0 else 0,
+                    (focus_nongreedy / focus_total) if focus_total > 0 else 0,
+                    (focus_zero / focus_total) if focus_total > 0 else 0,
+                )
+            )
+
+        with open(
+            path.join(RESULTS_FOLDER, filename.replace("conllu", "csv")),
+            "w",
+            encoding="utf-8",
+            newline="",
+        ) as csvfile:
+            writer = csv.writer(csvfile)
+            writer.writerow(("form", "lemma", "candidate", "greedy_candidate"))
+            writer.writerows(errors)
diff --git a/training/requirements.txt b/training/requirements.txt
new file mode 100644
index 0000000..7b806b7
--- /dev/null
+++ b/training/requirements.txt
@@ -0,0 +1,8 @@
+certifi==2023.7.22
+charset-normalizer==3.2.0
+conllu>=4.5.2
+idna==3.4
+requests==2.31.0
+types-requests==2.31.0.2
+types-urllib3==1.26.25.14
+urllib3==2.0.4