maintenance: simplify code and add missing type annotations #144

Merged · 5 commits · Aug 8, 2024

simplemma/strategies/defaultrules/de.py (9 additions, 13 deletions)

@@ -21,10 +21,10 @@
 )

 PLUR_ORTH_DE = re.compile(r"(?:Innen|\*innen|\*Innen|-innen|_innen)$")
-PP_DE = re.compile(r"^.{2,}ge.+?[^aes]t(?:e|em|er|es)$")
+
+PP_DE = re.compile(r"^(.{2,}ge.+?[^aes]t)(?:e|em|er|es)$")

 ENDING_CHARS_DE = {"e", "m", "n", "r", "s"}
-ENDING_DE = re.compile(r"(?:e|em|er|es)$")


 def apply_de(token: str) -> Optional[str]:
@@ -35,24 +35,20 @@ def apply_de(token: str) -> Optional[str]:
     # nouns
     if token[0].isupper():
         # noun endings/suffixes: regex search
-        match = NOUN_ENDINGS_DE.search(token)
-        if match:
+        if match := NOUN_ENDINGS_DE.search(token):
             # apply pattern
-            ending = next((g for g in match.groups() if g is not None), None)
-            if ending:
-                return token[: -len(ending)]
-            # lemma identified
-            return token
+            ending = next((g for g in match.groups() if g), None)
+            return token[: -len(ending)] if ending else token
         # inclusive speech
         # Binnen-I: ArbeitnehmerInnenschutzgesetz?
         if PLUR_ORTH_DE.search(token):
             return PLUR_ORTH_DE.sub(":innen", token)

     # mostly adjectives and verbs
     elif token[-1] in ENDING_CHARS_DE:
-        if ADJ_ENDINGS_DE.match(token):
-            return ADJ_ENDINGS_DE.sub(r"\1\2", token).lower()
-        if PP_DE.search(token):
-            return ENDING_DE.sub("", token).lower()
+        if adj_match := ADJ_ENDINGS_DE.match(token):
+            return (adj_match[1] + adj_match[2]).lower()
+        if pp_match := PP_DE.match(token):
+            return pp_match[1].lower()

     return None
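
The two changes work together: PP_DE now captures the participle stem in group 1, so the lemma falls straight out of the match and the separate ENDING_DE substitution becomes unnecessary. A minimal runnable sketch of just that path; only the PP_DE pattern comes from the codebase, while the helper name and sample tokens are illustrative:

```python
import re
from typing import Optional

# PP_DE as in the patched de.py: group 1 now captures the participle stem.
PP_DE = re.compile(r"^(.{2,}ge.+?[^aes]t)(?:e|em|er|es)$")


def lemmatize_pp(token: str) -> Optional[str]:
    # The walrus operator binds the match inside the condition, so the
    # stem is read from group 1 instead of stripping the ending again.
    if pp_match := PP_DE.match(token):
        return pp_match[1].lower()
    return None


print(lemmatize_pp("eingekaufte"))  # -> "eingekauft"
print(lemmatize_pp("Haus"))         # -> None (not a participle form)
```
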
simplemma/strategies/dictionaries/dictionary_factory.py (6 additions, 6 deletions)

@@ -14,7 +14,7 @@
 from functools import lru_cache
 from os import listdir, path
 from pathlib import Path
-from typing import ByteString, Dict, Mapping, Protocol
+from typing import ByteString, Dict, Iterator, Mapping, Protocol

 DATA_FOLDER = str(Path(__file__).parent / "data")
 SUPPORTED_LANGUAGES = [
@@ -83,17 +83,17 @@ class MappingStrToByteString(Mapping[str, str]):

     __slots__ = ["_dict"]

-    def __init__(self, dictionary: Dict[bytes, bytes]):
+    def __init__(self, dictionary: Dict[bytes, bytes]) -> None:
         self._dict = dictionary

-    def __getitem__(self, item: str):
+    def __getitem__(self, item: str) -> str:
         return self._dict[item.encode()].decode()

-    def __iter__(self):
+    def __iter__(self) -> Iterator[str]:
         for key in self._dict:
             yield key.decode()

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._dict)
@@ -107,7 +107,7 @@ class DefaultDictionaryFactory(DictionaryFactory):

     __slots__ = ["_load_dictionary_from_disk"]

-    def __init__(self, cache_max_size: int = 8):
+    def __init__(self, cache_max_size: int = 8) -> None:
         """
         Initialize the DefaultDictionaryFactory.
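
With the return types filled in, MappingStrToByteString type-checks as a complete Mapping[str, str] even though it stores bytes internally; Iterator had to be imported because __iter__ now declares Iterator[str]. A self-contained sketch of the annotated wrapper with a small usage check (the dictionary entry is made up):

```python
from typing import Dict, Iterator, Mapping


class MappingStrToByteString(Mapping[str, str]):
    """Str view over a bytes-to-bytes dict, mirroring dictionary_factory.py."""

    __slots__ = ["_dict"]

    def __init__(self, dictionary: Dict[bytes, bytes]) -> None:
        self._dict = dictionary

    def __getitem__(self, item: str) -> str:
        return self._dict[item.encode()].decode()

    def __iter__(self) -> Iterator[str]:
        for key in self._dict:
            yield key.decode()

    def __len__(self) -> int:
        return len(self._dict)


# Encoding happens only at the boundary; callers see plain str everywhere.
mapping: Mapping[str, str] = MappingStrToByteString({b"maisons": b"maison"})
assert mapping["maisons"] == "maison"
assert list(mapping) == ["maisons"]
```
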
simplemma/strategies/dictionaries/trie_directory_factory.py (12 additions, 12 deletions)

@@ -2,15 +2,15 @@
 from collections.abc import MutableMapping
 from functools import lru_cache
 from pathlib import Path
-from typing import List, Mapping, Optional
+from typing import Any, Iterator, List, Mapping, Optional

 try:
     from marisa_trie import BytesTrie, HUGE_CACHE  # type: ignore[import-not-found]
     from platformdirs import user_cache_dir
 except ImportError:

     class BytesTrie:  # type: ignore[no-redef]
-        def __init__(self):
+        def __init__(self) -> None:
             raise ImportError("marisa_trie and platformdirs packages not installed")
@@ -24,26 +24,25 @@ def __init__(self):
 logger = logging.getLogger(__name__)


-class TrieWrapDict(MutableMapping):
+class TrieWrapDict(MutableMapping):  # Python > 3.8: [str, Any]
     """Wrapper around BytesTrie to make them behave like dicts."""

-    def __init__(self, trie: BytesTrie):
+    def __init__(self, trie: BytesTrie) -> None:
         self._trie = trie

-    def __getitem__(self, item):
+    def __getitem__(self, item: str) -> Any:
         return self._trie[item][0].decode()

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: Any, value: Any) -> None:
         raise NotImplementedError

-    def __delitem__(self, key):
+    def __delitem__(self, key: Any) -> None:
         raise NotImplementedError

-    def __iter__(self):
-        for key in self._trie.iterkeys():
-            yield key
+    def __iter__(self) -> Iterator[str]:
+        yield from self._trie.iterkeys()

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._trie)
@@ -56,7 +55,7 @@ class TrieDictionaryFactory(DictionaryFactory):
     lookup performance isn't as good as with dicts.
     """

-    __slots__: List[str] = []
+    __slots__: List[str] = ["cache_max_size", "disk_cache_dir", "use_disk_cache"]

     def __init__(
         self,
@@ -127,4 +126,5 @@ def get_dictionary(
         self,
         lang: str,
     ) -> Mapping[str, str]:
+        "Retrieves a dictionary for the specified language."
        return self._get_dictionary(lang)
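
The TrieWrapDict annotations follow the same pattern, with `yield from` replacing the manual for/yield loop; the trailing comment notes that subscripting MutableMapping as [str, Any] only works on Python 3.9 and later. A dependency-free sketch of the same shape, backed by a plain dict so it runs without marisa_trie (the class name and sample data are illustrative):

```python
from collections.abc import MutableMapping
from typing import Any, Iterator


class ReadOnlyWrapDict(MutableMapping):  # Python > 3.8: MutableMapping[str, Any]
    """Read-only dict wrapper with the same method surface as TrieWrapDict."""

    def __init__(self, data: dict) -> None:
        self._data = data

    def __getitem__(self, item: str) -> Any:
        return self._data[item]

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError  # the underlying store is read-only

    def __delitem__(self, key: Any) -> None:
        raise NotImplementedError

    def __iter__(self) -> Iterator[str]:
        # "yield from" delegates directly, replacing the old for/yield loop.
        yield from self._data

    def __len__(self) -> int:
        return len(self._data)


wrapped = ReadOnlyWrapDict({"maisons": "maison"})
assert list(wrapped) == ["maisons"] and len(wrapped) == 1
```
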
simplemma/strategies/dictionary_lookup.py (1 addition, 2 deletions)

@@ -43,8 +43,7 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
         """
         # Search the language data, reverse case to extend coverage.
         dictionary = self._dictionary_factory.get_dictionary(lang)
-        result = dictionary.get(token)
-        if result:
+        if result := dictionary.get(token):
             return result
         # Try upper or lowercase.
         token = token.lower() if token[0].isupper() else token.capitalize()
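
The walrus form collapses the bind-then-test pair into a single expression without changing behaviour. A toy version of the lookup-then-swap-case flow (the dictionary content is invented):

```python
word_pairs = {"maisons": "maison"}  # stand-in for a loaded language dictionary

token = "Maisons"
if result := word_pairs.get(token):
    print(result)
else:
    # Reverse the case to extend coverage, as the surrounding code does.
    token = token.lower() if token[0].isupper() else token.capitalize()
    print(word_pairs.get(token))  # -> "maison"
```
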
simplemma/strategies/greedy_dictionary_lookup.py (7 additions, 9 deletions)

@@ -58,19 +58,17 @@ def get_lemma(self, token: str, lang: str) -> str:
             return token

         dictionary = self._dictionary_factory.get_dictionary(lang)
-        candidate = token
-        for _ in range(self._steps):
-            if candidate not in dictionary:
-                break
-
-            new_candidate = dictionary[candidate]
+        for _ in range(self._steps):
+            candidate = dictionary.get(token)

             if (
-                len(new_candidate) > len(candidate)
-                or levenshtein_dist(new_candidate, candidate) > self._distance
+                not candidate
+                or len(candidate) > len(token)
+                or levenshtein_dist(candidate, token) > self._distance
             ):
                 break

-            candidate = new_candidate
+            token = candidate

-        return candidate
+        return token
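
The rewritten loop folds the membership test into dictionary.get and reuses token as the running candidate, so the failure, growth, and distance checks all sit in one condition. A runnable sketch with a toy two-step lemma chain; the compact levenshtein_dist is a stand-in for simplemma's helper, and the steps/distance defaults here are illustrative:

```python
def levenshtein_dist(a: str, b: str) -> int:
    # Compact dynamic-programming edit distance (stand-in for the real helper).
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (ca != cb)))
        prev = curr
    return prev[-1]


def greedy_lemma(token: str, dictionary: dict, steps: int = 2, distance: int = 5) -> str:
    # Stop when the lookup fails, the candidate grows, or it drifts too far;
    # otherwise keep chasing the chain of lemmas.
    for _ in range(steps):
        candidate = dictionary.get(token)
        if (
            not candidate
            or len(candidate) > len(token)
            or levenshtein_dist(candidate, token) > distance
        ):
            break
        token = candidate
    return token


chain = {"getestete": "getestet", "getestet": "testen"}  # toy lemma chain
print(greedy_lemma("getestete", chain))  # -> "testen" after two steps
```
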
simplemma/strategies/hyphen_removal.py (4 additions, 6 deletions)

@@ -10,8 +10,7 @@
 from .lemmatization_strategy import LemmatizationStrategy

 HYPHENS = {"-", "_"}
-HYPHENS_FOR_REGEX = "".join(HYPHENS)
-HYPHEN_REGEX = re.compile(rf"([{HYPHENS_FOR_REGEX}])")
+HYPHEN_REGEX = re.compile(rf"([{''.join(HYPHENS)}])")


 class HyphenRemovalStrategy(LemmatizationStrategy):
@@ -69,9 +68,8 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
             return lemma

         # decompose
-        lemma = self._dictionary_lookup.get_lemma(token_parts[-1], lang)
-        if lemma is not None:
-            token_parts[-1] = lemma
-            return "".join(token_parts)
+        last_part_lemma = self._dictionary_lookup.get_lemma(token_parts[-1], lang)
+        if last_part_lemma is not None:
+            return "".join(token_parts[:-1] + [last_part_lemma])

         return None
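
Besides being shorter, the new version stops reusing the name lemma from the earlier branch and builds the result without mutating token_parts in place. The mechanics in isolation; the token and lookup table are invented, and re.split keeps the separator because HYPHEN_REGEX has a capture group:

```python
import re

HYPHENS = {"-", "_"}
HYPHEN_REGEX = re.compile(rf"([{''.join(HYPHENS)}])")

token_parts = HYPHEN_REGEX.split("Mode-Wörter")  # -> ['Mode', '-', 'Wörter']
toy_lemmas = {"Wörter": "Wort"}  # hypothetical dictionary lookup

# Lemmatize only the last part and rebuild without mutating token_parts.
last_part_lemma = toy_lemmas.get(token_parts[-1])
if last_part_lemma is not None:
    print("".join(token_parts[:-1] + [last_part_lemma]))  # -> "Mode-Wort"
```
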
simplemma/strategies/prefix_decomposition.py (3 additions, 7 deletions)

@@ -58,15 +58,11 @@ def get_lemma(self, token: str, lang: str) -> Optional[str]:
             return None

         prefix_match = self._known_prefixes[lang].match(token)
-        if not prefix_match:
+        if not prefix_match or prefix_match[1] == token:
             return None
-        prefix = prefix_match[1]
-
-        if prefix == token:
-            return None
+        prefix = prefix_match[1]

         subword = self._dictionary_lookup.get_lemma(token[len(prefix) :], lang)
-        if subword is None:
-            return None
-
-        return prefix + subword.lower()
+        return prefix + subword.lower() if subword else None
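
Both early exits now share one guard, and the final None check folds into a conditional expression. A self-contained sketch of the decomposition logic; the prefix regex and lemma table are invented stand-ins for _known_prefixes[lang] and the dictionary strategy:

```python
import re
from typing import Optional

KNOWN_PREFIXES = re.compile(r"(ab|auf|ein|ver)")  # illustrative prefix set
TOY_LEMMAS = {"kaufte": "kaufen"}  # illustrative lookup table


def decompose(token: str) -> Optional[str]:
    prefix_match = KNOWN_PREFIXES.match(token)
    # One guard covers both exits: no prefix, or the prefix is the whole token.
    if not prefix_match or prefix_match[1] == token:
        return None
    prefix = prefix_match[1]
    subword = TOY_LEMMAS.get(token[len(prefix):])
    # The conditional expression folds the old "if subword is None" block away.
    return prefix + subword.lower() if subword else None


print(decompose("einkaufte"))  # -> "einkaufen"
print(decompose("ein"))        # -> None (prefix equals the token)
```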