22
33import datetime as dt
44from abc import ABC , abstractmethod
5- from typing import Union , Optional , TYPE_CHECKING , Type , Iterator , Literal
5+ from typing import Union , Optional , TYPE_CHECKING , Type , Iterator , Literal , TypeAlias
66
77if TYPE_CHECKING :
88 from O365 .connection import Protocol
99
10- FilterWord = Union [str , bool , None , dt .date , int , float ]
10+ FilterWord : TypeAlias = Union [str , bool , None , dt .date , int , float ]
1111
1212
1313class QueryBase (ABC ):
@@ -35,6 +35,34 @@ def __and__(self, other):
3535 def __or__ (self , other ):
3636 pass
3737
38+ def get_filter_by_attribute (self , attribute : str ) -> Optional [str ]:
39+ """
40+ Returns a filter value by attribute name. It will match the attribute to the start of each filter attribute
41+ and return the first found.
42+ :param attribute: the attribute you want to search
43+ :return: The value applied to that attribute or None
44+ """
45+ search_object : Optional [QueryFilter ] = getattr (self , "_filter_instance" ) or getattr (self , "filters" )
46+ if search_object is not None :
47+ # CompositeFilter, IterableFilter, ModifierQueryFilter (negate, group)
48+ return search_object .get_filter_by_attribute (attribute )
49+
50+ search_object : Optional [list [QueryFilter ]] = getattr (self , "_filter_instances" )
51+ if search_object is not None :
52+ # ChainFilter
53+ for filter_obj in search_object :
54+ result = filter_obj .get_filter_by_attribute (attribute )
55+ if result is not None :
56+ return result
57+ return None
58+
59+ search_object : Optional [str ] = getattr (self , "_attribute" )
60+ if search_object is not None :
61+ # LogicalFilter or FunctionFilter
62+ if search_object .lower ().startswith (attribute .lower ()):
63+ return getattr (self , "_word" )
64+ return None
65+
3866
3967class QueryFilter (QueryBase , ABC ):
4068 __slots__ = ()
@@ -428,7 +456,7 @@ def _parse_filter_word(self, word: FilterWord) -> str:
428456 # bools are treated as lower case bools
429457 parsed_word = str (word ).lower ()
430458 elif word is None :
431- parsed_word = ' null'
459+ parsed_word = " null"
432460 elif isinstance (word , dt .date ):
433461 if isinstance (word , dt .datetime ):
434462 if word .tzinfo is None :
@@ -451,9 +479,9 @@ def _get_attribute_from_mapping(self, attribute: str) -> str:
451479 """
452480 mapping = self ._attribute_mapping .get (attribute )
453481 if mapping :
454- attribute = '/' .join (
482+ attribute = "/" .join (
455483 [self .protocol .convert_case (step ) for step in
456- mapping .split ('/' )])
484+ mapping .split ("/" )])
457485 else :
458486 attribute = self .protocol .convert_case (attribute )
459487 return attribute
@@ -475,7 +503,7 @@ def equals(self, attribute: str, word: FilterWord) -> LogicalFilter:
475503 :param word: word to compare with
476504 :return: a QueryFilter instance that can render the OData this logical operation
477505 """
478- return self .logical_operation ('eq' , attribute , word )
506+ return self .logical_operation ("eq" , attribute , word )
479507
480508 def unequal (self , attribute : str , word : FilterWord ) -> LogicalFilter :
481509 """ Return an unequal check
@@ -484,7 +512,7 @@ def unequal(self, attribute: str, word: FilterWord) -> LogicalFilter:
484512 :param word: word to compare with
485513 :return: a QueryFilter instance that can render the OData this logical operation
486514 """
487- return self .logical_operation ('ne' , attribute , word )
515+ return self .logical_operation ("ne" , attribute , word )
488516
489517 def greater (self , attribute : str , word : FilterWord ) -> LogicalFilter :
490518 """ Return a 'greater than' check
@@ -493,7 +521,7 @@ def greater(self, attribute: str, word: FilterWord) -> LogicalFilter:
493521 :param word: word to compare with
494522 :return: a QueryFilter instance that can render the OData this logical operation
495523 """
496- return self .logical_operation ('gt' , attribute , word )
524+ return self .logical_operation ("gt" , attribute , word )
497525
498526 def greater_equal (self , attribute : str , word : FilterWord ) -> LogicalFilter :
499527 """ Return a 'greater than or equal to' check
@@ -502,7 +530,7 @@ def greater_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
502530 :param word: word to compare with
503531 :return: a QueryFilter instance that can render the OData this logical operation
504532 """
505- return self .logical_operation ('ge' , attribute , word )
533+ return self .logical_operation ("ge" , attribute , word )
506534
507535 def less (self , attribute : str , word : FilterWord ) -> LogicalFilter :
508536 """ Return a 'less than' check
@@ -511,7 +539,7 @@ def less(self, attribute: str, word: FilterWord) -> LogicalFilter:
511539 :param word: word to compare with
512540 :return: a QueryFilter instance that can render the OData this logical operation
513541 """
514- return self .logical_operation ('lt' , attribute , word )
542+ return self .logical_operation ("lt" , attribute , word )
515543
516544 def less_equal (self , attribute : str , word : FilterWord ) -> LogicalFilter :
517545 """ Return a 'less than or equal to' check
@@ -520,7 +548,7 @@ def less_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
520548 :param word: word to compare with
521549 :return: a QueryFilter instance that can render the OData this logical operation
522550 """
523- return self .logical_operation ('le' , attribute , word )
551+ return self .logical_operation ("le" , attribute , word )
524552
525553 def function_operation (self , operation : str , attribute : str , word : FilterWord ) -> FunctionFilter :
526554 """ Apply a function operation
@@ -539,7 +567,7 @@ def contains(self, attribute: str, word: FilterWord) -> FunctionFilter:
539567 :param word: value to feed the function
540568 :return: a QueryFilter instance that can render the OData function operation
541569 """
542- return self .function_operation (' contains' , attribute , word )
570+ return self .function_operation (" contains" , attribute , word )
543571
544572 def startswith (self , attribute : str , word : FilterWord ) -> FunctionFilter :
545573 """ Adds a startswith word check
@@ -548,7 +576,7 @@ def startswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
548576 :param word: value to feed the function
549577 :return: a QueryFilter instance that can render the OData function operation
550578 """
551- return self .function_operation (' startswith' , attribute , word )
579+ return self .function_operation (" startswith" , attribute , word )
552580
553581 def endswith (self , attribute : str , word : FilterWord ) -> FunctionFilter :
554582 """ Adds a endswith word check
@@ -557,7 +585,7 @@ def endswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
557585 :param word: value to feed the function
558586 :return: a QueryFilter instance that can render the OData function operation
559587 """
560- return self .function_operation (' endswith' , attribute , word )
588+ return self .function_operation (" endswith" , attribute , word )
561589
562590 def iterable_operation (self , operation : str , collection : str , filter_instance : QueryFilter ,
563591 * , item_name : str = "a" ) -> IterableFilter :
@@ -598,7 +626,7 @@ def any(self, collection: str, filter_instance: QueryFilter, *, item_name: str =
598626 :return: a QueryFilter instance that can render the OData iterable operation
599627 """
600628
601- return self .iterable_operation (' any' , collection = collection ,
629+ return self .iterable_operation (" any" , collection = collection ,
602630 filter_instance = filter_instance , item_name = item_name )
603631
604632
@@ -618,7 +646,7 @@ def all(self, collection: str, filter_instance: QueryFilter, *, item_name: str =
618646 :return: a QueryFilter instance that can render the OData iterable operation
619647 """
620648
621- return self .iterable_operation (' all' , collection = collection ,
649+ return self .iterable_operation (" all" , collection = collection ,
622650 filter_instance = filter_instance , item_name = item_name )
623651
624652 @staticmethod
@@ -635,7 +663,7 @@ def chain_and(self, *filter_instances: QueryFilter, group: bool = False) -> Quer
635663 """
636664 if not all (isinstance (item , QueryFilter ) for item in filter_instances ):
637665 raise ValueError ("'filter_instances' parameter must contain only QueryFilter instances" )
638- chain = ChainFilter (operation = ' and' , filter_instances = list (filter_instances ))
666+ chain = ChainFilter (operation = " and" , filter_instances = list (filter_instances ))
639667 if group :
640668 return self .group (chain )
641669 else :
@@ -650,7 +678,7 @@ def chain_or(self, *filter_instances: QueryFilter, group: bool = False) -> Query
650678 """
651679 if not all (isinstance (item , QueryFilter ) for item in filter_instances ):
652680 raise ValueError ("'filter_instances' parameter must contain only QueryFilter instances" )
653- chain = ChainFilter (operation = 'or' , filter_instances = list (filter_instances ))
681+ chain = ChainFilter (operation = "or" , filter_instances = list (filter_instances ))
654682 if group :
655683 return self .group (chain )
656684 else :
0 commit comments