Skip to content

Commit

Permalink
updated check value and fixed customer certego.routing_history
Browse files Browse the repository at this point in the history
  • Loading branch information
giorgia-fusco8 committed Mar 12, 2024
1 parent 3f028dc commit eebb0eb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 40 deletions.
2 changes: 1 addition & 1 deletion routing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_routing_history_no_customer(self):
self.routing.match(self.test_event_1)
self.routing.match(self.test_event_1, type_="customers")
self.assertTrue(self.test_event_1["certego"]["routing_history"]["Workshop"])
self.assertFalse(self.test_event_1["certego"]["routing_history"]["customer"]) #TODO FIX Customer should never be on history
self.assertNotIn("customer", self.test_event_1["certego"]["routing_history"])

def test_routing_history_stream_none(self):
self.routing.load_from_dicts([load_test_data("test_rule_1_equals")])
Expand Down
106 changes: 68 additions & 38 deletions routingfilter/filters/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import re
from abc import ABC, abstractmethod
from typing import Optional
from typing import NoReturn, Optional

import macaddress
from IPy import IP
Expand All @@ -19,21 +19,25 @@ def __init__(self, key, value, **kwargs):
def match(self, event: DictQuery) -> bool:
return NotImplemented

def _check_value(self) -> Optional[Exception]:
@abstractmethod
def _check_value(self) -> Exception | NoReturn:
"""
Check if values in self._value are correct and raise an exception if they are incorrect.
Check if values in self._value are correct and raise an exception if they are incorrect. If necessary, it converts value in lower case.
:return: no value or raise an exception
:rtype: Optional[Exception]
:rtype: NoReturn | Exception
"""
return None
return NotImplemented


class AllFilter(AbstractFilter):
def __init__(self):
key = value = []
super().__init__(key, value)

def _check_value(self) -> Exception | NoReturn:
return

def match(self, event: DictQuery) -> bool:
"""
Return always true.
Expand All @@ -51,6 +55,9 @@ def __init__(self, key):
value = []
super().__init__(key, value)

def _check_value(self) -> Exception | NoReturn:
return

def match(self, event: DictQuery) -> bool:
"""
Return True if one of the key exists in the event.
Expand Down Expand Up @@ -83,6 +90,13 @@ class EqualFilter(AbstractFilter):
def __init__(self, key, value):
super().__init__(key, value)

def _check_value(self) -> Exception | NoReturn:
tmp = []
for value in self._value:
value = value.lower() if isinstance(value, str) else str(value)
tmp.append(value)
self._value = tmp

def match(self, event: DictQuery):
"""
Check if at least a key matches at least one value.
Expand All @@ -92,16 +106,12 @@ def match(self, event: DictQuery):
:return: true if event matches, false otherwise
:rtype: bool
"""
filter_value = []
for value in self._value:
value = value.lower() if isinstance(value, str) else str(value)
filter_value.append(value)
for key in self._key:
event_value = event.get(key, [])
event_value = event_value if isinstance(event_value, list) else [event_value]
for value in event_value:
value = value.lower() if isinstance(value, str) else str(value)
if value in filter_value:
if value in self._value:
return True
return False

Expand All @@ -120,6 +130,13 @@ def match(self, event: DictQuery) -> bool:


class StartswithFilter(AbstractFilter):
def _check_value(self) -> Exception | NoReturn:
tmp = []
for prefix in self._value:
prefix = str(prefix).lower()
tmp.append(prefix)
self._value = tmp

def match(self, event: DictQuery) -> bool:
"""
Return True if at least one event value corresponding to a key starts with one of the value.
Expand Down Expand Up @@ -148,13 +165,19 @@ def _check_startswith(self, value: str) -> bool:
"""
value = value.lower()
for prefix in self._value:
prefix = str(prefix).lower()
if value.startswith(prefix):
return True
return False


class EndswithFilter(AbstractFilter):
def _check_value(self) -> Exception | NoReturn:
tmp = []
for suffix in self._value:
suffix = str(suffix).lower()
tmp.append(suffix)
self._value = tmp

def match(self, event: DictQuery) -> bool:
"""
Return True if at least one event value corresponding to a key ends with one of the value.
Expand All @@ -172,7 +195,7 @@ def match(self, event: DictQuery) -> bool:
return True
return False

def _check_endswith(self, value):
def _check_endswith(self, value: str) -> bool:
"""
Check if the value end with one of the suffix given.
Expand All @@ -183,13 +206,19 @@ def _check_endswith(self, value):
"""
value = value.lower()
for suffix in self._value:
suffix = str(suffix).lower()
if str(value).endswith(suffix):
if value.endswith(suffix):
return True
return False


class KeywordFilter(AbstractFilter):
def _check_value(self) -> Exception | NoReturn:
tmp = []
for keyword in self._value:
keyword = str(keyword).lower()
tmp.append(keyword)
self._value = tmp

def match(self, event: DictQuery) -> bool:
"""
Return True if at least one value is present in the event value of corresponding key.
Expand All @@ -216,16 +245,14 @@ def _check_keyword(self, value: str) -> bool:
:return: true or false
:rtype: bool
"""
value = value.lower()
for keyword in self._value:
keyword = str(keyword).lower()
if keyword in value:
if keyword in value.lower():
return True
return False


class RegexpFilter(AbstractFilter):
def _check_value(self) -> Optional[Exception]:
def _check_value(self) -> Exception | NoReturn:
"""
Check if values in self._value are valid regexes.
Expand All @@ -238,7 +265,6 @@ def _check_value(self) -> Optional[Exception]:
except re.error as e:
self.logger.error(f"Invalid regex {value}, during check of value list {self._value}. Error message: {e}")
raise ValueError(f"Regex check failed: error for value {value}. Error message: {e}")
return None

def match(self, event: DictQuery) -> bool:
"""
Expand Down Expand Up @@ -276,7 +302,7 @@ class NetworkFilter(AbstractFilter):
def __init__(self, key, value):
super().__init__(key, value)

def _check_value(self) -> Optional[Exception]:
def _check_value(self) -> Exception | NoReturn:
"""
Check if the values in self._value are valid IP addresses.
Expand Down Expand Up @@ -351,17 +377,20 @@ class DomainFilter(AbstractFilter):
def __init__(self, key, value):
super().__init__(key, value)

def _check_value(self) -> Optional[Exception]:
def _check_value(self) -> Exception | NoReturn:
"""
Check if values in self._value are string.
:return: none or error generated
:rtype: bool
"""
for value in self._value:
if not isinstance(value, str):
raise ValueError(f"Domain check failed: value {value} is not a string.")
return None
tmp = []
for domain in self._value:
if not isinstance(domain, str):
raise ValueError(f"Domain check failed: value {domain} is not a string.")
domain = str(domain).lower()
tmp.append(domain)
self._value = tmp

def match(self, event: DictQuery) -> bool:
"""
Expand Down Expand Up @@ -390,8 +419,7 @@ def _check_domain(self, value: str) -> bool:
"""
value = value.lower()
for domain in self._value:
domain = str(domain).lower()
if value == domain or str(value).endswith(f".{domain}"):
if value == domain or value.endswith(f".{domain}"):
return True
return False

Expand All @@ -402,32 +430,32 @@ def __init__(self, key, value, comparator_type):
self._check_comparator_type()
super().__init__(key, value)

def _check_value(self) -> Optional[Exception]:
def _check_value(self) -> Exception | NoReturn:
"""
Check if values in self._value are float.
:return: none or error generated
:rtype: Optional[Exception]
:rtype: Exception | NoReturn
"""
tmp = []
for value in self._value:
try:
float(value)
tmp.append(float(value))
except ValueError:
self.logger.error(f"Comparator check failed: value {value} of list {self._value} is not a float")
raise ValueError(f"Comparator check failed: value {value} is not a float")
return None
self._value = tmp

def _check_comparator_type(self) -> Optional[Exception]:
def _check_comparator_type(self) -> Exception | NoReturn:
"""
Check if comparator is valid.
:return: none or error generated
:rtype: Optional[Exception]
:rtype: Exception | NoReturn
"""
if self._comparator_type not in ["GREATER", "LESS", "GREATER_EQ", "LESS_EQ"]:
self.logger.error(f"Comparator check failed: value {self._comparator_type} is not valid.")
raise ValueError(f"Comparator type check failed. {self._comparator_type} is not a valid comparator.")
return None

def match(self, event: DictQuery) -> bool:
"""
Expand Down Expand Up @@ -457,7 +485,6 @@ def _compare(self, value: float) -> bool:
"""
for term in self._value:
try:
term = float(term)
value = float(value)
except ValueError as e:
self.logger.debug(f"Error in parsing value to float in comparator filter: {e}. ")
Expand All @@ -482,19 +509,22 @@ class TypeofFilter(AbstractFilter):
def __init__(self, key, value):
super().__init__(key, value)

def _check_value(self) -> Optional[Exception]:
def _check_value(self) -> Exception | NoReturn:
"""
Check if value is a correct type.
:return: no value or raised an exception
:rtype: Optional[Exception]
:rtype: NoReturn | Exception
"""
valid_type = ["str", "int", "float", "bool", "list", "dict", "ip", "mac"]
tmp = []
for value in self._value:
value = str(value).lower()
if value not in valid_type:
self.logger.error(f"Type check failed: value {value} of list {self._value} is invalid.")
raise ValueError(f"Type check failed: value {value} is invalid.")
return None
tmp.append(value)
self._value = tmp

def match(self, event: DictQuery) -> bool:
"""
Expand Down
3 changes: 2 additions & 1 deletion routingfilter/filters/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def match(self, event: DictQuery) -> Results | None:
if key in routing_history:
output_copy.pop(key)
else:
routing_history.update({key: now})
if key != "customer":
routing_history.update({key: now})
results = Results(rules=self.uid, output=output_copy)
return results

Expand Down

0 comments on commit eebb0eb

Please sign in to comment.