Skip to content
Open
202 changes: 117 additions & 85 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""
Apriori Algorithm is a Association rule mining technique, also known as market basket
analysis, aims to discover interesting relationships or associations among a set of
items in a transactional or relational database.
Apriori Algorithm with Association Rules (support, confidence, lift).

For example, Apriori Algorithm states: "If a customer buys item A and item B, then they
are likely to buy item C." This rule suggests a relationship between items A, B, and C,
indicating that customers who purchased A and B are more likely to also purchase item C.
This implementation finds:
- Frequent itemsets
- Association rules with minimum confidence and lift

WIKI: https://en.wikipedia.org/wiki/Apriori_algorithm
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import defaultdict
from itertools import combinations


Expand All @@ -24,90 +22,124 @@ def load_data() -> list[list[str]]:
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(itemset: list, candidates: list, length: int) -> list:
"""
Prune candidate itemsets that are not frequent.
The goal of pruning is to filter out candidate itemsets that are not frequent. This
is done by checking if all the (k-1) subsets of a candidate itemset are present in
the frequent itemsets of the previous iteration (valid subsequences of the frequent
itemsets from the previous iteration).

Prunes candidate itemsets that are not frequent.

>>> itemset = ['X', 'Y', 'Z']
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
>>> prune(itemset, candidates, 2)
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]

>>> itemset = ['1', '2', '3', '4']
>>> candidates = ['1', '2', '4']
>>> prune(itemset, candidates, 3)
[]
"""
pruned = []
for candidate in candidates:
is_subsequence = True
for item in candidate:
if item not in itemset or itemset.count(item) < length - 1:
is_subsequence = False
class Apriori:
"""Apriori algorithm class with support, confidence, and lift filtering."""

def __init__(
self,
transactions: list[list[str]],
min_support: float = 0.25,
min_confidence: float = 0.5,
min_lift: float = 1.0,
) -> None:
self.transactions: list[set[str]] = [set(t) for t in transactions]
self.min_support: float = min_support
self.min_confidence: float = min_confidence
self.min_lift: float = min_lift
self.itemsets: list[dict[frozenset, float]] = []
self.rules: list[tuple[frozenset, frozenset, float, float]] = []

self.find_frequent_itemsets()
self.generate_association_rules()

def _get_support(self, itemset: frozenset) -> float:
"""Return support of an itemset."""
return sum(1 for t in self.transactions if itemset.issubset(t)) / len(
self.transactions
)

def confidence(self, antecedent: frozenset, consequent: frozenset) -> float:
"""Calculate confidence of a rule A -> B."""
support_antecedent = self._get_support(antecedent)
support_both = self._get_support(antecedent | consequent)
return support_both / support_antecedent if support_antecedent > 0 else 0.0

def lift(self, antecedent: frozenset, consequent: frozenset) -> float:
"""Calculate lift of a rule A -> B."""
support_consequent = self._get_support(consequent)
conf = self.confidence(antecedent, consequent)
return conf / support_consequent if support_consequent > 0 else 0.0

def find_frequent_itemsets(self) -> list[dict[frozenset, float]]:
"""Generate all frequent itemsets."""
item_counts: dict[frozenset, int] = defaultdict(int)
for t in self.transactions:
for item in t:
item_counts[frozenset([item])] += 1

total: int = len(self.transactions)
current_itemsets: dict[frozenset, float] = {
k: v / total
for k, v in item_counts.items()
if v / total >= self.min_support
}
if current_itemsets:
self.itemsets.append(current_itemsets)

k: int = 2
while current_itemsets:
candidates: set[frozenset] = set()
keys: list[frozenset] = list(current_itemsets.keys())
for i in range(len(keys)):
for j in range(i + 1, len(keys)):
union = keys[i] | keys[j]
if len(union) == k and all(
frozenset(sub) in current_itemsets
for sub in combinations(union, k - 1)
):
candidates.add(union)

freq_candidates: dict[frozenset, float] = {
c: self._get_support(c)
for c in candidates
if self._get_support(c) >= self.min_support
}
if not freq_candidates:
break
if is_subsequence:
pruned.append(candidate)
return pruned


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
"""
Returns a list of frequent itemsets and their support counts.

>>> data = [['A', 'B', 'C'], ['A', 'B'], ['A', 'C'], ['A', 'D'], ['B', 'C']]
>>> apriori(data, 2)
[(['A', 'B'], 1), (['A', 'C'], 2), (['B', 'C'], 2)]

>>> data = [['1', '2', '3'], ['1', '2'], ['1', '3'], ['1', '4'], ['2', '3']]
>>> apriori(data, 3)
[]
"""
itemset = [list(transaction) for transaction in data]
frequent_itemsets = []
length = 1

while itemset:
# Count itemset support
counts = [0] * len(itemset)
for transaction in data:
for j, candidate in enumerate(itemset):
if all(item in transaction for item in candidate):
counts[j] += 1

# Prune infrequent itemsets
itemset = [item for i, item in enumerate(itemset) if counts[i] >= min_support]

# Append frequent itemsets (as a list to maintain order)
for i, item in enumerate(itemset):
frequent_itemsets.append((sorted(item), counts[i]))

length += 1
itemset = prune(itemset, list(combinations(itemset, length)), length)

return frequent_itemsets
self.itemsets.append(freq_candidates)
current_itemsets = freq_candidates
k += 1

return self.itemsets

def generate_association_rules(
self,
) -> list[tuple[frozenset, frozenset, float, float]]:
"""Generate association rules with min confidence and lift."""
for level in self.itemsets:
for itemset in level:
if len(itemset) < 2:
continue
for i in range(1, len(itemset)):
for antecedent in combinations(itemset, i):
antecedent_set = frozenset(antecedent)
consequent_set = itemset - antecedent_set
conf = self.confidence(antecedent_set, consequent_set)
lft = self.lift(antecedent_set, consequent_set)
if conf >= self.min_confidence and lft >= self.min_lift:
self.rules.append(
(antecedent_set, consequent_set, conf, lft)
)
return self.rules


if __name__ == "__main__":
"""
Apriori algorithm for finding frequent itemsets.

Args:
data: A list of transactions, where each transaction is a list of items.
min_support: The minimum support threshold for frequent itemsets.

Returns:
A list of frequent itemsets along with their support counts.
"""
import doctest

doctest.testmod()

# user-defined threshold or minimum support level
frequent_itemsets = apriori(data=load_data(), min_support=2)
print("\n".join(f"{itemset}: {support}" for itemset, support in frequent_itemsets))
transactions = load_data()
model = Apriori(transactions, min_support=0.25, min_confidence=0.1, min_lift=0.0)

print("Frequent itemsets:")
for level in model.itemsets:
for items, sup in level.items():
print(f"{set(items)}: {sup:.2f}")

print("\nAssociation Rules:")
for rule in model.rules:
antecedent, consequent, conf, lift = rule
print(
f"{set(antecedent)} -> {set(consequent)}, conf={conf:.2f}, lift={lift:.2f}"
)