diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c73b552..30742a8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -69,6 +69,10 @@ jobs:
if: matrix.python-version != '3.6' && matrix.python-version != '3.7'
run: pip install -r requirements-dev.txt
+ - name: Install training dependencies
+ if: matrix.python-version != '3.6' && matrix.python-version != '3.7'
+ run: pip install -r training/requirements-dev.txt
+
- name: Install dependencies (legacy versions)
if: matrix.python-version == '3.6' || matrix.python-version == '3.7'
run: |
diff --git a/.gitignore b/.gitignore
index ab5d3e5..f634b25 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,4 @@ Makefile
# eval
UD/
+training/**/data
\ No newline at end of file
diff --git a/eval/README.rst b/eval/README.rst
deleted file mode 100644
index b341409..0000000
--- a/eval/README.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-Instructions to run the evaluation
-----------------------------------
-
-The scores are calculated on `Universal Dependencies `_ treebanks on single word tokens (including some contractions but not merged prepositions). They can be reproduced by concatenating all available UD files and by using the script ``udscore.py``.
-
-1. Download the main archive containing the `Universal Dependencies `_ treebanks
-2. Extract relevant data (language and if applicable specific treebank, see notes in the results table)
-3. Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de-gsd-all.conllu``)
-4. Store the files at the expected location (``tests/UD/``) and/or edit the corresponding paths in ``udscore.py``
-5. Install the evaluation dependencies (``pip install -r eval-requirements.txt``)
-6. Run the script, e.g. from the home directory ``python3 eval/udscore.py``
-7. Results are displayed in the terminal and errors are written in a CSV file
-
diff --git a/eval/eval-requirements.txt b/eval/eval-requirements.txt
deleted file mode 100644
index 634aded..0000000
--- a/eval/eval-requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-conllu>=4.5.2
diff --git a/eval/udscore.py b/eval/udscore.py
deleted file mode 100644
index 0238a12..0000000
--- a/eval/udscore.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import csv
-import time
-
-from collections import Counter
-from os import path
-
-from conllu import parse_incr # type: ignore
-from simplemma import Lemmatizer
-from simplemma.strategies.dictionaries import DefaultDictionaryFactory
-from simplemma.strategies.default import DefaultStrategy
-
-
-data_files = [
- ("bg", "tests/UD/bg-btb-all.conllu"),
- ("cs", "tests/UD/cs-pdt-all.conllu"), # longer to process
- ("da", "tests/UD/da-ddt-all.conllu"),
- ("de", "tests/UD/de-gsd-all.conllu"),
- ("el", "tests/UD/el-gdt-all.conllu"),
- ("en", "tests/UD/en-gum-all.conllu"),
- ("es", "tests/UD/es-gsd-all.conllu"),
- ("et", "tests/UD/et-edt-all.conllu"),
- ("fi", "tests/UD/fi-tdt-all.conllu"),
- ("fr", "tests/UD/fr-gsd-all.conllu"),
- ("ga", "tests/UD/ga-idt-all.conllu"),
- ("hi", "tests/UD/hi-hdtb-all.conllu"),
- ("hu", "tests/UD/hu-szeged-all.conllu"),
- ("hy", "tests/UD/hy-armtdp-all.conllu"),
- ("id", "tests/UD/id-csui-all.conllu"),
- ("it", "tests/UD/it-isdt-all.conllu"),
- ("la", "tests/UD/la-proiel-all.conllu"),
- ("lt", "tests/UD/lt-alksnis-all.conllu"),
- ("lv", "tests/UD/lv-lvtb-all.conllu"),
- ("nb", "tests/UD/nb-bokmaal-all.conllu"),
- ("nl", "tests/UD/nl-alpino-all.conllu"),
- ("pl", "tests/UD/pl-pdb-all.conllu"),
- ("pt", "tests/UD/pt-gsd-all.conllu"),
- ("ru", "tests/UD/ru-gsd-all.conllu"),
- ("sk", "tests/UD/sk-snk-all.conllu"),
- ("tr", "tests/UD/tr-boun-all.conllu"),
-]
-
-# doesn't work: right-to-left?
-# data_files = [
-# ('he', 'tests/UD/he-htb-all.conllu'),
-# ('ur', 'tests/UD/ur-udtb-all.conllu'),
-# ]
-
-# data_files = [
-# ('de', 'tests/UD/de-gsd-all.conllu'),
-#]
-
-
-for filedata in data_files:
- total, focus_total, _greedy, nongreedy, zero, focus_zero, focus, focus_nongreedy = (
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- )
- errors, flag = [], False
- language, filename = filedata[0], filedata[1]
-
- start = time.time()
- _dictionary_factory = DefaultDictionaryFactory()
- strategies = DefaultStrategy(greedy=False)
- lemmatizer = Lemmatizer(
- lemmatization_strategy=DefaultStrategy(
- greedy=False, dictionary_factory=_dictionary_factory
- ),
- )
- greedy_lemmatizer = Lemmatizer(
- lemmatization_strategy=DefaultStrategy(
- greedy=True, dictionary_factory=_dictionary_factory
- ),
- )
- print("==", filedata, "==")
- with open(filename, "r", encoding="utf-8") as data_file:
- for tokenlist in parse_incr(data_file):
- for token in tokenlist:
- error_flag = False
- if token["lemma"] == "_": # or token['upos'] in ('PUNCT', 'SYM')
- # flag = True
- continue
-
- initial = bool(token["id"] == 1)
- token_form = token["form"].lower() if initial else token["form"]
-
- candidate = lemmatizer.lemmatize(token_form, lang=language)
- greedy_candidate = greedy_lemmatizer.lemmatize(token_form, lang=language)
-
- if token["upos"] in ("ADJ", "NOUN"):
- focus_total += 1
- if token["form"] == token["lemma"]:
- focus_zero += 1
- if greedy_candidate == token["lemma"]:
- focus += 1
- if candidate == token["lemma"]:
- focus_nongreedy += 1
- total += 1
- if token["form"] == token["lemma"]:
- zero += 1
- if greedy_candidate == token["lemma"]:
- _greedy += 1
- else:
- error_flag = True
- if candidate == token["lemma"]:
- nongreedy += 1
- else:
- error_flag = True
- if error_flag:
- errors.append(
- (token["form"], token["lemma"], candidate, greedy_candidate)
- )
- with open(
- f'{path.basename(filename).replace("conllu","csv")}', "w", encoding="utf-8"
- ) as csvfile:
- writer = csv.writer(csvfile)
- writer.writerow(("form", "lemma", "candidate", "greedy_candidate"))
- writer.writerows(errors)
-
- print("exec time:\t %.3f" % (time.time() - start))
- print("token count:\t", total)
- print("greedy:\t\t %.3f" % (_greedy / total))
- print("non-greedy:\t %.3f" % (nongreedy / total))
- print("baseline:\t %.3f" % (zero / total))
- print("ADJ+NOUN greedy:\t\t %.3f" % (focus / focus_total))
- print("ADJ+NOUN non-greedy:\t\t %.3f" % (focus_nongreedy / focus_total))
- print("ADJ+NOUN baseline:\t\t %.3f" % (focus_zero / focus_total))
- mycounter = Counter(errors)
- print(mycounter.most_common(20))
diff --git a/simplemma/language_detector.py b/simplemma/language_detector.py
index f6fab14..52369a6 100644
--- a/simplemma/language_detector.py
+++ b/simplemma/language_detector.py
@@ -198,18 +198,14 @@ def proportion_in_target_languages(
Returns:
float: The proportion of text in the target language(s).
"""
- tokens = self._token_sampler.sample_text(text)
- if len(tokens) == 0:
- return 0
-
- in_target = 0
- for token in tokens:
- for lang_code in self._lang:
- candidate = self._lemmatization_strategy.get_lemma(token, lang_code)
- if candidate is not None:
- in_target += 1
- break
- return in_target / len(tokens)
+ return sum(
+ percentage
+ for (
+ lang_code,
+ percentage,
+ ) in self.proportion_in_each_language(text).items()
+ if lang_code != "unk"
+ )
def main_language(
self,
diff --git a/tests/test_language_detector.py b/tests/test_language_detector.py
index ad68f0c..4affdd1 100644
--- a/tests/test_language_detector.py
+++ b/tests/test_language_detector.py
@@ -108,15 +108,6 @@ def test_in_target_language() -> None:
== 1.0
)
- langs = ("en", "de")
- text = "It was a true gift"
- assert (
- LanguageDetector(lang=langs).proportion_in_target_languages(text)
- == in_target_language(text, lang=langs)
- == 1.0
- )
- in_target_language("It was a true gift", lang=("en", "de"))
-
def test_main_language():
text = "Dieser Satz ist auf Deutsch."
diff --git a/training/README.rst b/training/README.rst
new file mode 100644
index 0000000..ba4c630
--- /dev/null
+++ b/training/README.rst
@@ -0,0 +1,15 @@
+Instructions to run the evaluation
+----------------------------------
+
+The scores are calculated on `Universal Dependencies `_ treebanks on single word tokens (including some contractions but not merged prepositions). They can be reproduced by the following steps:
+
+1. Install the evaluation dependencies (``pip install -r training/requirements.txt``)
+2. Update ``DATA_URL`` in ``training/download-eval-data.py`` to point to the latest treebanks archive from `Universal Dependencies ` (or the version that you which to use).
+3. Run ``python3 training/download-eval-data.py`` which will
+ 1. Download the archive
+ 2. Extract relevant data (language and if applicable specific treebank, see notes in the results table)
+ 3. Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de-gsd-all.conllu``)
+ 4. Store the files at the expected location (``training/data/UD/``)
+4. Run the script, e.g. from the home directory ``python3 training/evaluate_simplema.py``
+5. Results are stored at ``training/data/results/results_summary.csv``. Also, errors are written in a CSV file for each dataset under the ``data/results``folder.
+
diff --git a/training/download-eval-data.py b/training/download-eval-data.py
new file mode 100644
index 0000000..5451392
--- /dev/null
+++ b/training/download-eval-data.py
@@ -0,0 +1,71 @@
+from typing import Iterable, List, Tuple
+from os import mkdir, path, scandir
+import re
+import logging
+import tarfile
+import requests
+from glob import glob
+
+from simplemma.strategies.dictionaries.dictionary_factory import SUPPORTED_LANGUAGES
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+DATA_URL = "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-5150/ud-treebanks-v2.12.tgz?sequence=1&isAllowed=y"
+DATA_FOLDER = path.join(path.dirname(__file__), "data")
+DATA_FILE = path.join(DATA_FOLDER, "ud-treeebanks.tgz")
+CLEAN_DATA_FOLDER = path.join(DATA_FOLDER, "UD")
+
+
+def get_dirs(file_name: str) -> List[str]:
+ return [dir.name for dir in scandir(file_name) if dir.is_dir()]
+
+
+def get_files(file_name: str) -> List[str]:
+ return [dir.name for dir in scandir(file_name) if dir.is_file()]
+
+
+def get_relevant_language_data_folders(data_folder) -> Iterable[Tuple[str, str, str]]:
+ for lang_folder in get_dirs(data_folder):
+ lang_data_folder = path.join(uncompressed_data_folder, lang_folder)
+ conllu_file = glob(path.join(lang_data_folder, "*.conllu"))[0]
+ matches_files = re.search("^.*/(.*)-ud.*$", conllu_file)
+ if matches_files is not None:
+ dataset_name = matches_files.groups()[0]
+ lang = dataset_name.split("_")[0]
+
+ if lang in SUPPORTED_LANGUAGES:
+ yield (lang, dataset_name, lang_data_folder)
+
+
+if path.exists(DATA_FOLDER) or path.exists(CLEAN_DATA_FOLDER):
+ raise Exception(
+ "Data folder seems to be already present. Delete it before creating new data."
+ )
+
+mkdir(DATA_FOLDER)
+mkdir(CLEAN_DATA_FOLDER)
+
+log.info("Downloading evaluation data...")
+response = requests.get(DATA_URL)
+open(DATA_FILE, "wb").write(response.content)
+
+log.info("Uncompressing evaluation data...")
+with tarfile.open(DATA_FILE) as tar:
+ tar.extractall(DATA_FOLDER)
+uncompressed_data_folder = path.join(
+ DATA_FOLDER, glob(f"{DATA_FOLDER}/ud-treebanks-*")[0]
+)
+
+log.info("Filtering files...")
+for lang, dataset_name, dataset__folder in get_relevant_language_data_folders(
+ uncompressed_data_folder
+):
+ log.info(lang + " - " + dataset__folder)
+ # Concatenate the train, dev and test data into a single file (e.g. ``cat de_gsd*.conllu > de-gsd-all.conllu``)
+ lang_clean_data_file = path.join(CLEAN_DATA_FOLDER, f"{dataset_name}.conllu")
+ log.debug(f"Procressing data for {dataset_name}")
+ with open(lang_clean_data_file, "w") as outfile:
+ for file in glob(path.join(dataset__folder, "*.conllu")):
+ with open(file) as infile:
+ for line in infile:
+ outfile.write(line)
diff --git a/training/evaluate_simplema.py b/training/evaluate_simplema.py
new file mode 100644
index 0000000..4cd0dec
--- /dev/null
+++ b/training/evaluate_simplema.py
@@ -0,0 +1,152 @@
+import csv
+import time
+
+from collections import Counter
+from os import mkdir, path, rmdir, scandir, unlink
+import logging
+
+from conllu import parse_incr # type: ignore
+
+from simplemma import Lemmatizer
+from simplemma.strategies.dictionaries import DefaultDictionaryFactory
+from simplemma.strategies.default import DefaultStrategy
+
+log = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+DATA_FOLDER = path.join(path.dirname(__file__), "data")
+CLEAN_DATA_FOLDER = path.join(DATA_FOLDER, "UD")
+RESULTS_FOLDER = path.join(DATA_FOLDER, "results")
+
+data_files = [
+ (data_file.name.split("_")[0], data_file.name)
+ for data_file in scandir(CLEAN_DATA_FOLDER)
+]
+
+if not path.exists(CLEAN_DATA_FOLDER):
+ raise Exception(
+ "It doesn't seem like data was downloaded and precessed for evaluation."
+ )
+
+if path.exists(RESULTS_FOLDER):
+ for result_file in scandir(RESULTS_FOLDER):
+ unlink(result_file)
+ rmdir(RESULTS_FOLDER)
+mkdir(RESULTS_FOLDER)
+
+with open(path.join(RESULTS_FOLDER, "results_summary.csv"), "w") as csv_results_file:
+ csv_results_file_writer = csv.writer(csv_results_file)
+ csv_results_file_writer.writerow(
+ (
+ "dataset",
+ "exec time",
+ "token count",
+ "greedy",
+ "non-greedy",
+ "baseline",
+ "ADJ+NOUN greedy",
+ "ADJ+NOUN non-greedy",
+ "ADJ+NOUN baseline",
+ )
+ )
+
+ for language, filename in data_files:
+ total = 0
+ focus_total = 0
+ greedy = 0
+ nongreedy = 0
+ zero = 0
+ focus = 0
+ focus_nongreedy = 0
+ focus_zero = 0
+ errors = []
+
+ start = time.time()
+ _dictionary_factory = DefaultDictionaryFactory()
+ strategies = DefaultStrategy(greedy=False)
+ lemmatizer = Lemmatizer(
+ lemmatization_strategy=DefaultStrategy(
+ greedy=False, dictionary_factory=_dictionary_factory
+ ),
+ )
+ greedy_lemmatizer = Lemmatizer(
+ lemmatization_strategy=DefaultStrategy(
+ greedy=True, dictionary_factory=_dictionary_factory
+ ),
+ )
+ log.info(f"Evaluating dataset: {filename}")
+ with open(
+ path.join(CLEAN_DATA_FOLDER, filename), "r", encoding="utf-8"
+ ) as data_file:
+ for tokens in parse_incr(data_file):
+ for token in tokens:
+ error_flag = False
+ if token["lemma"] == "_": # or token['upos'] in ('PUNCT', 'SYM')
+ continue
+
+ initial = bool(token["id"] == 1)
+ token_form = token["form"].lower() if initial else token["form"]
+
+ candidate = lemmatizer.lemmatize(token_form, lang=language)
+ greedy_candidate = greedy_lemmatizer.lemmatize(
+ token_form, lang=language
+ )
+
+ if token["upos"] in ("ADJ", "NOUN"):
+ focus_total += 1
+ if token["form"] == token["lemma"]:
+ focus_zero += 1
+ if greedy_candidate == token["lemma"]:
+ focus += 1
+ if candidate == token["lemma"]:
+ focus_nongreedy += 1
+ total += 1
+ if token["form"] == token["lemma"]:
+ zero += 1
+ if greedy_candidate == token["lemma"]:
+ greedy += 1
+ else:
+ error_flag = True
+ if candidate == token["lemma"]:
+ nongreedy += 1
+ else:
+ error_flag = True
+ if error_flag:
+ errors.append(
+ (token["form"], token["lemma"], candidate, greedy_candidate)
+ )
+
+ if total > 0:
+ csv_results_file_writer.writerow(
+ (
+ filename.replace(".conllu", ""),
+ time.time() - start,
+ total,
+ (greedy / total) if total > 0 else 0,
+ (nongreedy / total) if total > 0 else 0,
+ (zero / total) if total > 0 else 0,
+ (focus / focus_total) if focus_total > 0 else 0,
+ (focus_nongreedy / focus_total) if focus_total > 0 else 0,
+ (focus_zero / focus_total) if focus_total > 0 else 0,
+ )
+ )
+
+ with open(
+ path.join(RESULTS_FOLDER, filename.replace("conllu", "csv")),
+ "w",
+ encoding="utf-8",
+ ) as csvfile:
+ writer = csv.writer(csvfile)
+ writer.writerow(("form", "lemma", "candidate", "greedy_candidate"))
+ writer.writerows(errors)
+
+ # print("exec time:\t %.3f" % (time.time() - start))
+ # print("token count:\t", total)
+ # print("greedy:\t\t %.3f" % (greedy / total))
+ # print("non-greedy:\t %.3f" % (nongreedy / total))
+ # print("baseline:\t %.3f" % (zero / total))
+ # print("ADJ+NOUN greedy:\t\t %.3f" % (focus / focus_total))
+ # print("ADJ+NOUN non-greedy:\t\t %.3f" % (focus_nongreedy / focus_total))
+ # print("ADJ+NOUN baseline:\t\t %.3f" % (focus_zero / focus_total))
+ # mycounter = Counter(errors)
+ # print(mycounter.most_common(20))
diff --git a/training/requirements.txt b/training/requirements.txt
new file mode 100644
index 0000000..7b806b7
--- /dev/null
+++ b/training/requirements.txt
@@ -0,0 +1,8 @@
+certifi==2023.7.22
+charset-normalizer==3.2.0
+conllu>=4.5.2
+idna==3.4
+requests==2.31.0
+types-requests==2.31.0.2
+types-urllib3==1.26.25.14
+urllib3==2.0.4