@@ -82,14 +82,27 @@ are still presented a target variable in training, but they behave as transforme
 pipelines. They are entity embedding transformers, in the sense of the article, "Entity
 Embeddings of Categorical Variables" by Cheng Guo, Felix Berkhahn.
 
- The atomic `model` must be an instance of `MLJFlux.NeuralNetworkClassifier`,
- `MLJFlux.NeuralNetworkBinaryClassifier`, `MLJFlux.NeuralNetworkRegressor`, or
- `MLJFlux.MultitargetNeuralNetworkRegressor`. Hyperparameters of the atomic model, in
- particular `builder` and `embedding_dims`, will effect embedding performance.
+ # Training data
 
- The wrapped model is bound to a machine and trained exactly as the wrapped supervised
- `model`, and supports the same form of training data. In particular, a training target
- must be supplied.
+ In MLJ (or MLJBase), bind an instance `embed_model` to data with
+
+     mach = machine(embed_model, X, y)
+
+ Here:
+
+ - `embed_model` is an instance of `EntityEmbedder`, which wraps a supervised MLJFlux
+   model, `model`, which must be an instance of one of these:
+   `MLJFlux.NeuralNetworkClassifier`, `MLJFlux.NeuralNetworkBinaryClassifier`,
+   `MLJFlux.NeuralNetworkRegressor`, or `MLJFlux.MultitargetNeuralNetworkRegressor`.
97+
98+ - `X` is any table of input features supported by the model being wrapped. Features to be
99+ transformed must have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)`
100+ to check scitypes.
101+
102+ - `y` is the target, which can be any `AbstractVector` supported by the model being
103+ wrapped.
104+
105+ Train the machine using `fit!(mach)`.
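+
+ # Operations
+
+ - `transform(mach, Xnew)`: as for any MLJ transformer, return a table in which the
+   `Multiclass` and `OrderedFactor` features of `Xnew` are replaced by their learned
+   continuous embeddings (a usage sketch; `Xnew` is assumed to be a table with the
+   same schema as `X`):
+
+       Xsmall = transform(mach, Xnew)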
 
 # Examples
 
@@ -107,6 +120,7 @@ X = (
     b = categorical(rand("abcde", N)),
     c = categorical(rand("ABCDEFGHIJ", N), ordered = true),
 )
+
 y = categorical(rand("YN", N));
 
 # Initiate model