Skip to content

Commit 8e02638

Browse files
authored
Merge pull request #309 from FluxML/doc-touch
✨ Small improvement to EntityEmbedder docs.
2 parents 2919aa9 + fecdd07 commit 8e02638

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

src/mlj_embedder_interface.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ In the following example we wrap a `NeuralNetworkClassifier` as an `EntityEmbedd
110110
that it can be used to supply continuously encoded features to a nearest neighbor model,
111111
which does not support categorical features.
112112
113+
## Simple Example
113114
```julia
114115
using MLJ
115116
@@ -129,21 +130,46 @@ EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
129130
# Flux model to do learn the entity embeddings:
130131
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
131132
132-
# Other supervised model type, requiring `Continuous` features:
133-
KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels
134-
135133
# Instantiate the models:
136134
clf = NeuralNetworkClassifier(embedding_dims=Dict(:b => 2, :c => 3))
137135
emb = EntityEmbedder(clf)
138136
139-
# For illustrative purposes, train the embedder on its own:
137+
# Train and transform the data using the embedder:
140138
mach = machine(emb, X, y)
141139
fit!(mach)
142140
Xnew = transform(mach, X)
143141
144-
# And compare feature scitypes:
142+
# Compare schemas before and after transformation
145143
schema(X)
146144
schema(Xnew)
145+
```
146+
147+
## Using with Downstream Models (Pipeline)
148+
```julia
149+
using MLJ
150+
151+
# Setup some data
152+
N = 400
153+
X = (
154+
a = rand(Float32, N),
155+
b = categorical(rand("abcde", N)),
156+
c = categorical(rand("ABCDEFGHIJ", N), ordered = true),
157+
)
158+
159+
y = categorical(rand("YN", N));
160+
161+
# Initiate model
162+
EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
163+
164+
# Flux model to do learn the entity embeddings:
165+
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
166+
167+
# Other supervised model type, requiring `Continuous` features:
168+
KNNClassifier = @load KNNClassifier pkg=NearestNeighborModels
169+
170+
# Instantiate the models:
171+
clf = NeuralNetworkClassifier(embedding_dims=Dict(:b => 2, :c => 3))
172+
emb = EntityEmbedder(clf)
147173
148174
# Now construct the pipeline:
149175
pipe = emb |> KNNClassifier()

0 commit comments

Comments
 (0)