Skip to content

Commit 7fccbd9

Browse files
committed
Added Decision Tree (ID3, binary classification) implementation
1 parent e2c20ed commit 7fccbd9

File tree

3 files changed

+260
-0
lines changed

3 files changed

+260
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
using NUnit.Framework;
2+
using Algorithms.MachineLearning;
3+
using System;
4+
5+
namespace Algorithms.Tests.MachineLearning;
6+
7+
[TestFixture]
public class DecisionTreeTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        var sut = new DecisionTree();

        Assert.Throws<ArgumentException>(() => sut.Fit(Array.Empty<int[]>(), Array.Empty<int>()));
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        int[][] samples = { new[] { 1, 2 } };
        int[] labels = { 1, 0 }; // two labels for a single sample
        var sut = new DecisionTree();

        Assert.Throws<ArgumentException>(() => sut.Fit(samples, labels));
    }

    [Test]
    public void Predict_ThrowsIfNotTrained()
    {
        var sut = new DecisionTree();

        Assert.Throws<InvalidOperationException>(() => sut.Predict(new[] { 1, 2 }));
    }

    [Test]
    public void Predict_ThrowsOnFeatureMismatch()
    {
        var sut = new DecisionTree();
        sut.Fit(new[] { new[] { 1, 2 } }, new[] { 1 });

        // Trained on two features, queried with only one.
        Assert.Throws<ArgumentException>(() => sut.Predict(new[] { 1 }));
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // Truth table for logical OR.
        int[][] samples =
        {
            new[] { 0, 0 },
            new[] { 0, 1 },
            new[] { 1, 0 },
            new[] { 1, 1 },
        };
        int[] labels = { 0, 1, 1, 1 };
        var sut = new DecisionTree();

        sut.Fit(samples, labels);

        Assert.That(sut.Predict(new[] { 0, 0 }), Is.EqualTo(0));
        Assert.That(sut.Predict(new[] { 0, 1 }), Is.EqualTo(1));
        Assert.That(sut.Predict(new[] { 1, 0 }), Is.EqualTo(1));
        Assert.That(sut.Predict(new[] { 1, 1 }), Is.EqualTo(1));
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        var sut = new DecisionTree();

        sut.Fit(new[] { new[] { 1, 2, 3 } }, new[] { 1 });

        Assert.That(sut.FeatureCount, Is.EqualTo(3));
    }

    [Test]
    public void Predict_FallbacksToZeroForUnseenValue()
    {
        var sut = new DecisionTree();
        sut.Fit(new[] { new[] { 0, 0 }, new[] { 1, 1 } }, new[] { 0, 1 });

        // Feature 0 never saw the value 2 during training.
        Assert.That(sut.Predict(new[] { 2, 0 }), Is.EqualTo(0));
    }
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
5+
namespace Algorithms.MachineLearning;
6+
7+
/// <summary>
/// Simple Decision Tree for binary classification using the ID3 algorithm.
/// Supports categorical features (int values).
/// </summary>
public class DecisionTree
{
    // Root of the trained tree; null until Fit succeeds.
    private Node? root;

    /// <summary>
    /// Gets the number of features used in training (0 if the model is untrained).
    /// </summary>
    public int FeatureCount => root?.FeatureCount ?? 0;

    /// <summary>
    /// Trains the decision tree using the ID3 algorithm.
    /// </summary>
    /// <param name="x">2D array of features (samples x features), categorical (int).</param>
    /// <param name="y">Array of labels (0 or 1).</param>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> or <paramref name="y"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// The input is empty, the number of labels does not match the number of samples,
    /// or the samples do not all have the same number of features.
    /// </exception>
    public void Fit(int[][] x, int[] y)
    {
        ArgumentNullException.ThrowIfNull(x);
        ArgumentNullException.ThrowIfNull(y);

        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        // Reject jagged input early; otherwise a short row crashes deep inside BuildTree.
        int featureCount = x[0].Length;
        if (x.Any(row => row.Length != featureCount))
        {
            throw new ArgumentException("All samples must have the same number of features.");
        }

        root = BuildTree(x, y, Enumerable.Range(0, featureCount).ToList());
    }

    /// <summary>
    /// Predicts the class label (0 or 1) for a single sample.
    /// </summary>
    /// <param name="x">Feature values of the sample; length must equal <see cref="FeatureCount"/>.</param>
    /// <returns>The predicted label, or 0 when the sample contains a feature value unseen in training.</returns>
    /// <exception cref="InvalidOperationException">The model has not been trained.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
    /// <exception cref="ArgumentException">The sample's feature count differs from the training data.</exception>
    public int Predict(int[] x)
    {
        if (root is null)
        {
            throw new InvalidOperationException("Model not trained.");
        }

        ArgumentNullException.ThrowIfNull(x);

        if (x.Length != FeatureCount)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Traverse(root, x);
    }

    // Recursively builds the ID3 tree: stop on pure labels or exhausted features,
    // otherwise split on the feature with the highest information gain.
    private static Node BuildTree(int[][] x, int[] y, List<int> features)
    {
        // All remaining samples share one label: emit a pure leaf.
        if (y.All(l => l == y[0]))
        {
            return new Node { Label = y[0], FeatureCount = x[0].Length };
        }

        // No features left to split on: fall back to the majority label.
        if (features.Count == 0)
        {
            return new Node { Label = MostCommon(y), FeatureCount = x[0].Length };
        }

        int bestFeature = BestFeature(x, y, features);
        var node = new Node { Feature = bestFeature, FeatureCount = x[0].Length };
        node.Children = new();

        // The chosen feature is consumed; compute the remaining set once, outside the loop.
        var subFeatures = features.Where(f => f != bestFeature).ToList();

        // One child per observed value of the chosen feature (single pass, grouped by value).
        foreach (var group in Enumerable.Range(0, x.Length).GroupBy(i => x[i][bestFeature]))
        {
            var subX = group.Select(i => x[i]).ToArray();
            var subY = group.Select(i => y[i]).ToArray();
            node.Children[group.Key] = BuildTree(subX, subY, subFeatures);
        }

        return node;
    }

    // Walks the tree for one sample; leaves return their label directly.
    private static int Traverse(Node node, int[] x)
    {
        if (node.Label is not null)
        {
            return node.Label.Value;
        }

        int value = x[node.Feature!.Value];
        if (node.Children!.TryGetValue(value, out var child))
        {
            return Traverse(child, x);
        }

        // Deliberate fallback: feature values never seen during training predict class 0.
        return 0;
    }

    // Majority label; ties resolve to the label encountered first in y.
    private static int MostCommon(int[] y) => y.GroupBy(l => l).OrderByDescending(g => g.Count()).First().Key;

    // Returns the candidate feature with the highest information gain on (x, y).
    private static int BestFeature(int[][] x, int[] y, List<int> features)
    {
        double baseEntropy = Entropy(y);
        double bestGain = double.MinValue;
        int bestFeature = features[0];
        foreach (var f in features)
        {
            // Weighted entropy of the partition induced by feature f (one pass over the rows).
            double splitEntropy = 0;
            foreach (var group in Enumerable.Range(0, x.Length).GroupBy(i => x[i][f]))
            {
                var subY = group.Select(i => y[i]).ToArray();
                splitEntropy += (double)subY.Length / y.Length * Entropy(subY);
            }

            double gain = baseEntropy - splitEntropy;
            if (gain > bestGain)
            {
                bestGain = gain;
                bestFeature = f;
            }
        }

        return bestFeature;
    }

    // Shannon entropy (base 2) of a binary label array; labels other than 0/1 contribute nothing.
    private static double Entropy(int[] y)
    {
        int n = y.Length;
        if (n == 0)
        {
            return 0;
        }

        double p0 = y.Count(l => l == 0) / (double)n;
        double p1 = y.Count(l => l == 1) / (double)n;
        double e = 0;
        if (p0 > 0)
        {
            e -= p0 * Math.Log2(p0);
        }

        if (p1 > 0)
        {
            e -= p1 * Math.Log2(p1);
        }

        return e;
    }

    // Internal tree node: either a leaf (Label set) or an internal split (Feature + Children set).
    private class Node
    {
        // Index of the feature this node splits on; null for leaves.
        public int? Feature { get; set; }

        // Predicted label for leaves; null for internal nodes.
        public int? Label { get; set; }

        // Number of features in the training data (backs DecisionTree.FeatureCount).
        public int FeatureCount { get; set; }

        // Child subtree per observed feature value; null for leaves.
        public Dictionary<int, Node>? Children { get; set; }
    }
}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ find more than one implementation for the same objective but using different alg
108108
* [CollaborativeFiltering](./Algorithms/RecommenderSystem/CollaborativeFiltering)
109109
* [Machine Learning](./Algorithms/MachineLearning)
110110
* [Linear Regression](./Algorithms/MachineLearning/LinearRegression.cs)
111+
* [Decision Tree](./Algorithms/MachineLearning/DecisionTree.cs)
111112
* [Searches](./Algorithms/Search)
112113
* [A-Star](./Algorithms/Search/AStar/)
113114
* [Binary Search](./Algorithms/Search/BinarySearcher.cs)

0 commit comments

Comments
 (0)