Skip to content

Commit a9b508c

Browse files
authored
add kmeans (#25)
1 parent 5317ae6 commit a9b508c

File tree

4 files changed

+147
-0
lines changed

4 files changed

+147
-0
lines changed

build.zig

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ pub fn build(b: *std.Build) void {
208208
.name = "X25519+Kyber768Draft00.zig",
209209
.category = "web/tls",
210210
});
211+
// Machine Learning
212+
if (std.mem.eql(u8, op, "machine_learning/k_means_clustering.zig"))
213+
buildAlgorithm(b, .{
214+
.optimize = optimize,
215+
.target = target,
216+
.name = "k_means_clustering.zig",
217+
.category = "machine_learning",
218+
});
211219
}
212220

213221
fn buildAlgorithm(b: *std.Build, info: BInfo) void {
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/// Errors reported by `KMeans`.
/// `SmallK` is returned when the data set holds fewer points than the
/// requested number of clusters.
const KMeansError = error{SmallK};
2+
3+
/// A point in the plane with single-precision coordinates.
const Point2D = struct {
    x: f32,
    y: f32,

    const Self = @This();

    /// The origin; acts as the starting value when summing points.
    const zero: Self = .{ .x = 0, .y = 0 };

    /// Returns true when both coordinates compare exactly equal.
    /// A NaN coordinate on either side makes the result false.
    fn eq(self: Self, other: Self) bool {
        if (self.x != other.x) return false;
        return self.y == other.y;
    }

    /// Component-wise sum of two points.
    fn add(self: Self, other: Self) Self {
        const sum_x = self.x + other.x;
        const sum_y = self.y + other.y;
        return .{ .x = sum_x, .y = sum_y };
    }

    /// Divides both coordinates by `scalar`.
    fn div(self: Self, scalar: f32) Self {
        const scaled_x = self.x / scalar;
        const scaled_y = self.y / scalar;
        return .{ .x = scaled_x, .y = scaled_y };
    }
};
17+
/// A representative point paired with the number of data points
/// counted toward it.
const Cluster = struct {
    point: Point2D,
    count: usize,

    /// Empty accumulator: point at the origin and a count of zero.
    const zero: Cluster = .{ .point = Point2D.zero, .count = 0 };

    /// Two clusters are equal when their points compare equal and they
    /// carry the same count.
    fn eq(self: Cluster, other: Cluster) bool {
        if (!self.point.eq(other.point)) return false;
        return self.count == other.count;
    }
};
25+
26+
/// Squared Euclidean distance between `a` and `b`.
/// No square root is taken; callers in this file only compare distances.
fn distanceSquared(a: Point2D, b: Point2D) f32 {
    const dx = a.x - b.x;
    const dy = a.y - b.y;
    return dx * dx + dy * dy;
}
31+
32+
/// Returns the index of the cluster whose point is closest to `point`.
///
/// `clusters` must be indexable, expose `.len`, hold elements with a
/// `point: Point2D` field, and contain at least one element. Ties go to
/// the lowest index because only a strictly smaller distance replaces
/// the current best.
fn calculateNearest(point: Point2D, clusters: anytype) usize {
    var best_idx: usize = 0;
    var best_distance = distanceSquared(point, clusters[0].point);

    var idx: usize = 1;
    while (idx < clusters.len) : (idx += 1) {
        const candidate = distanceSquared(clusters[idx].point, point);
        if (candidate < best_distance) {
            best_idx = idx;
            best_distance = candidate;
        }
    }
    return best_idx;
}
44+
45+
/// Partitions `data` into `k` clusters with the iterative k-means
/// procedure and returns the final centroids with their member counts.
///
/// The first `k` data points seed the centroids. Each round assigns every
/// point to its nearest centroid, recomputes each centroid as the mean of
/// its members, and stops once a round reproduces the previous round's
/// clusters exactly.
///
/// A cluster that receives no points keeps its previous centroid (with a
/// count of 0). Dividing its empty sum by zero would instead produce a NaN
/// centroid, which never compares equal to itself and would keep the
/// convergence check from ever succeeding.
///
/// `k` is expected to be at least 1.
///
/// Returns `KMeansError.SmallK` when `data` holds fewer than `k` points.
pub fn KMeans(data: []const Point2D, comptime k: usize) KMeansError![k]Cluster {
    if (data.len < k) {
        return KMeansError.SmallK;
    }

    // Seed each centroid with one of the first k data points.
    var old_clusters: [k]Cluster = undefined;
    for (&old_clusters, data[0..k]) |*cluster, seed| {
        cluster.* = .{ .point = seed, .count = 0 };
    }

    while (true) {
        // Accumulate per-cluster coordinate sums and member counts.
        var new_clusters: [k]Cluster = .{Cluster.zero} ** k;
        for (data) |point| {
            const cluster_idx = calculateNearest(point, old_clusters);
            const new = &new_clusters[cluster_idx];
            new.point = new.point.add(point);
            new.count += 1;
        }

        // Turn sums into means; an empty cluster stays where it was.
        for (&new_clusters, old_clusters) |*cluster, old| {
            if (cluster.count == 0) {
                cluster.point = old.point;
                continue;
            }
            const count_as_f32: f32 = @floatFromInt(cluster.count);
            cluster.point = cluster.point.div(count_as_f32);
        }

        // Converged once nothing changed compared to the previous round.
        const converged = for (old_clusters, new_clusters) |old, new| {
            if (!old.eq(new)) break false;
        } else true;
        if (converged) return new_clusters;

        old_clusters = new_clusters;
    }
}
78+
79+
const std = @import("std");
80+
const expectEqual = std.testing.expectEqual;
81+
// Exercises KMeans on small hand-checked inputs: a single point, duplicate
// points, two points averaged into one cluster, and two well-separated
// columns of points split into two clusters.
test "Kmeans" {
    // One point, one cluster: the centroid is the point itself.
    {
        const data = [_]Point2D{
            .{ .x = 34.0, .y = 34.0 },
        };
        const expected = [_]Cluster{
            .{ .point = .{ .x = 34.0, .y = 34.0 }, .count = 1 },
        };
        try expectEqual(expected, try KMeans(&data, 1));
    }

    // Two identical points, one cluster.
    {
        const data = [_]Point2D{
            .{ .x = 33.0, .y = 33.0 },
            .{ .x = 33.0, .y = 33.0 },
        };
        const expected = [_]Cluster{
            .{ .point = .{ .x = 33.0, .y = 33.0 }, .count = 2 },
        };
        try expectEqual(expected, try KMeans(&data, 1));
    }

    // Two distinct points, one cluster: the centroid is their mean.
    {
        const data = [_]Point2D{
            .{ .x = 32.0, .y = 33.0 },
            .{ .x = 34.0, .y = 35.0 },
        };
        const expected = [_]Cluster{
            .{ .point = .{ .x = 33.0, .y = 34.0 }, .count = 2 },
        };
        try expectEqual(expected, try KMeans(&data, 1));
    }

    // Two vertical columns of three points each, two clusters.
    {
        const data = [_]Point2D{
            .{ .x = 0.0, .y = 1.0 },
            .{ .x = 2.0, .y = 1.0 },
            .{ .x = 0.0, .y = 0.5 },
            .{ .x = 0.0, .y = 0.0 },
            .{ .x = 2.0, .y = 0.5 },
            .{ .x = 2.0, .y = 0.0 },
        };
        const expected = [_]Cluster{
            .{ .point = .{ .x = 0.0, .y = 0.5 }, .count = 3 },
            .{ .point = .{ .x = 2.0, .y = 0.5 }, .count = 3 },
        };
        try expectEqual(expected, try KMeans(&data, 2));
    }
}

runall.cmd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,7 @@ rem Web
4949
%ZIG_TEST% -Dalgorithm=web/httpServer %Args%
5050
%ZIG_TEST% -Dalgorithm=web/tls1_3 %Args%
5151

52+
rem Machine Learning
53+
%ZIG_TEST% -Dalgorithm=machine_learning/k_means_clustering.zig %Args%
54+
5255
rem Add more...

runall.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,6 @@ $ZIG_TEST -Dalgorithm=web/httpClient $Args
4949
$ZIG_TEST -Dalgorithm=web/httpServer $Args
5050
$ZIG_TEST -Dalgorithm=web/tls1_3 $Args
5151

52+
## Machine Learning
53+
$ZIG_TEST -Dalgorithm=machine_learning/k_means_clustering.zig $Args
5254
## Add more...

0 commit comments

Comments
 (0)