Skip to content

Commit 07b1a3f

Browse files
committed
unified new and old query with a get_filter_by_attribute method so they will be compatible
1 parent 5936a93 commit 07b1a3f

File tree

3 files changed

+75
-38
lines changed

3 files changed

+75
-38
lines changed

O365/calendar.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,28 +1825,19 @@ def get_events(self, limit=25, *, query=None, order_by=None, batch=None,
18251825
if query and not isinstance(query, str):
18261826
# extract start and end from query because
18271827
# those are required by a calendarView
1828-
for query_data in query._filters:
1829-
if not isinstance(query_data, list):
1830-
continue
1831-
attribute = query_data[0]
1832-
# the 2nd position contains the filter data
1833-
# and the 3rd position in filter_data contains the value
1834-
word = query_data[2][3]
1835-
1836-
if attribute.lower().startswith('start/'):
1837-
start = word.replace("'", '') # remove the quotes
1838-
query.remove_filter('start')
1839-
if attribute.lower().startswith('end/'):
1840-
end = word.replace("'", '') # remove the quotes
1841-
query.remove_filter('end')
1828+
start = query.get_filter_by_attribute('start/')
1829+
end = query.get_filter_by_attribute('start/')
18421830

1843-
if start is None or end is None:
1844-
raise ValueError(
1845-
"When 'include_recurring' is True you must provide a 'start' and 'end' datetimes inside a Query instance.")
1831+
if start:
1832+
start = start.replace("'", '') # remove the quotes
1833+
query.remove_filter('start')
1834+
if end:
1835+
end = end.replace("'", '') # remove the quotes
1836+
query.remove_filter('end')
18461837

1847-
if end < start:
1848-
raise ValueError('When using "include_recurring=True", the date asigned to the "end" datetime'
1849-
' should be greater or equal than the date asigned to the "start" datetime.')
1838+
if start is None or end is None:
1839+
raise ValueError("When 'include_recurring' is True you must provide "
1840+
"a 'start' and 'end' datetime inside a 'Query' instance.")
18501841

18511842
params[self._cc('startDateTime')] = start
18521843
params[self._cc('endDateTime')] = end

O365/utils/query.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import datetime as dt
44
from abc import ABC, abstractmethod
5-
from typing import Union, Optional, TYPE_CHECKING, Type, Iterator, Literal
5+
from typing import Union, Optional, TYPE_CHECKING, Type, Iterator, Literal, TypeAlias
66

77
if TYPE_CHECKING:
88
from O365.connection import Protocol
99

10-
FilterWord = Union[str, bool, None, dt.date, int, float]
10+
FilterWord: TypeAlias = Union[str, bool, None, dt.date, int, float]
1111

1212

1313
class QueryBase(ABC):
@@ -35,6 +35,34 @@ def __and__(self, other):
3535
def __or__(self, other):
3636
pass
3737

38+
def get_filter_by_attribute(self, attribute: str) -> Optional[str]:
39+
"""
40+
Returns a filter value by attribute name. It will match the attribute to the start of each filter attribute
41+
and return the first found.
42+
:param attribute: the attribute you want to search
43+
:return: The value applied to that attribute or None
44+
"""
45+
search_object: Optional[QueryFilter] = getattr(self, "_filter_instance") or getattr(self, "filters")
46+
if search_object is not None:
47+
# CompositeFilter, IterableFilter, ModifierQueryFilter (negate, group)
48+
return search_object.get_filter_by_attribute(attribute)
49+
50+
search_object: Optional[list[QueryFilter]] = getattr(self, "_filter_instances")
51+
if search_object is not None:
52+
# ChainFilter
53+
for filter_obj in search_object:
54+
result = filter_obj.get_filter_by_attribute(attribute)
55+
if result is not None:
56+
return result
57+
return None
58+
59+
search_object: Optional[str] = getattr(self, "_attribute")
60+
if search_object is not None:
61+
# LogicalFilter or FunctionFilter
62+
if search_object.lower().startswith(attribute.lower()):
63+
return getattr(self, "_word")
64+
return None
65+
3866

3967
class QueryFilter(QueryBase, ABC):
4068
__slots__ = ()
@@ -428,7 +456,7 @@ def _parse_filter_word(self, word: FilterWord) -> str:
428456
# bools are treated as lower case bools
429457
parsed_word = str(word).lower()
430458
elif word is None:
431-
parsed_word = 'null'
459+
parsed_word = "null"
432460
elif isinstance(word, dt.date):
433461
if isinstance(word, dt.datetime):
434462
if word.tzinfo is None:
@@ -451,9 +479,9 @@ def _get_attribute_from_mapping(self, attribute: str) -> str:
451479
"""
452480
mapping = self._attribute_mapping.get(attribute)
453481
if mapping:
454-
attribute = '/'.join(
482+
attribute = "/".join(
455483
[self.protocol.convert_case(step) for step in
456-
mapping.split('/')])
484+
mapping.split("/")])
457485
else:
458486
attribute = self.protocol.convert_case(attribute)
459487
return attribute
@@ -475,7 +503,7 @@ def equals(self, attribute: str, word: FilterWord) -> LogicalFilter:
475503
:param word: word to compare with
476504
:return: a QueryFilter instance that can render the OData this logical operation
477505
"""
478-
return self.logical_operation('eq', attribute, word)
506+
return self.logical_operation("eq", attribute, word)
479507

480508
def unequal(self, attribute: str, word: FilterWord) -> LogicalFilter:
481509
""" Return an unequal check
@@ -484,7 +512,7 @@ def unequal(self, attribute: str, word: FilterWord) -> LogicalFilter:
484512
:param word: word to compare with
485513
:return: a QueryFilter instance that can render the OData this logical operation
486514
"""
487-
return self.logical_operation('ne', attribute, word)
515+
return self.logical_operation("ne", attribute, word)
488516

489517
def greater(self, attribute: str, word: FilterWord) -> LogicalFilter:
490518
""" Return a 'greater than' check
@@ -493,7 +521,7 @@ def greater(self, attribute: str, word: FilterWord) -> LogicalFilter:
493521
:param word: word to compare with
494522
:return: a QueryFilter instance that can render the OData this logical operation
495523
"""
496-
return self.logical_operation('gt', attribute, word)
524+
return self.logical_operation("gt", attribute, word)
497525

498526
def greater_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
499527
""" Return a 'greater than or equal to' check
@@ -502,7 +530,7 @@ def greater_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
502530
:param word: word to compare with
503531
:return: a QueryFilter instance that can render the OData this logical operation
504532
"""
505-
return self.logical_operation('ge', attribute, word)
533+
return self.logical_operation("ge", attribute, word)
506534

507535
def less(self, attribute: str, word: FilterWord) -> LogicalFilter:
508536
""" Return a 'less than' check
@@ -511,7 +539,7 @@ def less(self, attribute: str, word: FilterWord) -> LogicalFilter:
511539
:param word: word to compare with
512540
:return: a QueryFilter instance that can render the OData this logical operation
513541
"""
514-
return self.logical_operation('lt', attribute, word)
542+
return self.logical_operation("lt", attribute, word)
515543

516544
def less_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
517545
""" Return a 'less than or equal to' check
@@ -520,7 +548,7 @@ def less_equal(self, attribute: str, word: FilterWord) -> LogicalFilter:
520548
:param word: word to compare with
521549
:return: a QueryFilter instance that can render the OData this logical operation
522550
"""
523-
return self.logical_operation('le', attribute, word)
551+
return self.logical_operation("le", attribute, word)
524552

525553
def function_operation(self, operation: str, attribute: str, word: FilterWord) -> FunctionFilter:
526554
""" Apply a function operation
@@ -539,7 +567,7 @@ def contains(self, attribute: str, word: FilterWord) -> FunctionFilter:
539567
:param word: value to feed the function
540568
:return: a QueryFilter instance that can render the OData function operation
541569
"""
542-
return self.function_operation('contains', attribute, word)
570+
return self.function_operation("contains", attribute, word)
543571

544572
def startswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
545573
""" Adds a startswith word check
@@ -548,7 +576,7 @@ def startswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
548576
:param word: value to feed the function
549577
:return: a QueryFilter instance that can render the OData function operation
550578
"""
551-
return self.function_operation('startswith', attribute, word)
579+
return self.function_operation("startswith", attribute, word)
552580

553581
def endswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
554582
""" Adds a endswith word check
@@ -557,7 +585,7 @@ def endswith(self, attribute: str, word: FilterWord) -> FunctionFilter:
557585
:param word: value to feed the function
558586
:return: a QueryFilter instance that can render the OData function operation
559587
"""
560-
return self.function_operation('endswith', attribute, word)
588+
return self.function_operation("endswith", attribute, word)
561589

562590
def iterable_operation(self, operation: str, collection: str, filter_instance: QueryFilter,
563591
*, item_name: str = "a") -> IterableFilter:
@@ -598,7 +626,7 @@ def any(self, collection: str, filter_instance: QueryFilter, *, item_name: str =
598626
:return: a QueryFilter instance that can render the OData iterable operation
599627
"""
600628

601-
return self.iterable_operation('any', collection=collection,
629+
return self.iterable_operation("any", collection=collection,
602630
filter_instance=filter_instance, item_name=item_name)
603631

604632

@@ -618,7 +646,7 @@ def all(self, collection: str, filter_instance: QueryFilter, *, item_name: str =
618646
:return: a QueryFilter instance that can render the OData iterable operation
619647
"""
620648

621-
return self.iterable_operation('all', collection=collection,
649+
return self.iterable_operation("all", collection=collection,
622650
filter_instance=filter_instance, item_name=item_name)
623651

624652
@staticmethod
@@ -635,7 +663,7 @@ def chain_and(self, *filter_instances: QueryFilter, group: bool = False) -> Quer
635663
"""
636664
if not all(isinstance(item, QueryFilter) for item in filter_instances):
637665
raise ValueError("'filter_instances' parameter must contain only QueryFilter instances")
638-
chain = ChainFilter(operation='and', filter_instances=list(filter_instances))
666+
chain = ChainFilter(operation="and", filter_instances=list(filter_instances))
639667
if group:
640668
return self.group(chain)
641669
else:
@@ -650,7 +678,7 @@ def chain_or(self, *filter_instances: QueryFilter, group: bool = False) -> Query
650678
"""
651679
if not all(isinstance(item, QueryFilter) for item in filter_instances):
652680
raise ValueError("'filter_instances' parameter must contain only QueryFilter instances")
653-
chain = ChainFilter(operation='or', filter_instances=list(filter_instances))
681+
chain = ChainFilter(operation="or", filter_instances=list(filter_instances))
654682
if group:
655683
return self.group(chain)
656684
else:

O365/utils/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,3 +1306,21 @@ def close_group(self):
13061306
else:
13071307
raise RuntimeError("No filters present. Can't close a group")
13081308
return self
1309+
1310+
def get_filter_by_attribute(self, attribute):
1311+
""" Returns a filter word applied to an attribute """
1312+
1313+
attribute = attribute.lower()
1314+
1315+
# iterate over the filters to find the corresponding attribute
1316+
for query_data in self._filters:
1317+
if not isinstance(query_data, list):
1318+
continue
1319+
filter_attribute = query_data[0]
1320+
# the 2nd position contains the filter data
1321+
# and the 3rd position in filter_data contains the value
1322+
word = query_data[2][3]
1323+
1324+
if filter_attribute.lower().startswith(attribute):
1325+
return word
1326+
return None

0 commit comments

Comments
 (0)