Open
Labels: enhancement (New feature or request), untriaged (New issue has not been triaged)
Summary
I propose adding a minimal extension point that lets users inject a custom RNG into `MLContext` without changing defaults or breaking backward compatibility. This enables deterministic, portable randomness across platforms and languages (e.g., matching C++ `std::mt19937`), and lets advanced users choose an RNG that meets their reproducibility requirements.
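To make the portability point concrete, here is a minimal C# sketch of the 32-bit Mersenne Twister (MT19937) as specified for C++ `std::mt19937`, i.e. the kind of engine a user might want to plug in. The class name `Mt19937` and its `NextUInt32` method are hypothetical and not part of ML.NET; the code follows the standard seeding recurrence, twist, and tempering steps.

```csharp
using System;

// Hypothetical MT19937 engine (not part of ML.NET), intended to be
// bit-compatible with C++ std::mt19937.
public sealed class Mt19937
{
    private const int N = 624;
    private const int M = 397;
    private const uint MatrixA = 0x9908B0DFu;
    private const uint UpperMask = 0x80000000u;
    private const uint LowerMask = 0x7FFFFFFFu;

    private readonly uint[] _mt = new uint[N];
    private int _index;

    // 5489 is the default seed of a default-constructed std::mt19937.
    public Mt19937(uint seed = 5489u)
    {
        _mt[0] = seed;
        for (int i = 1; i < N; i++)
        {
            // Standard MT19937 state initialization recurrence.
            _mt[i] = unchecked(1812433253u * (_mt[i - 1] ^ (_mt[i - 1] >> 30)) + (uint)i);
        }
        _index = N; // Forces a twist before the first output.
    }

    public uint NextUInt32()
    {
        if (_index >= N) Twist();
        uint y = _mt[_index++];
        // Tempering transform.
        y ^= y >> 11;
        y ^= (y << 7) & 0x9D2C5680u;
        y ^= (y << 15) & 0xEFC60000u;
        y ^= y >> 18;
        return y;
    }

    // Regenerates all N state words in place, as in the reference algorithm.
    private void Twist()
    {
        for (int i = 0; i < N; i++)
        {
            uint y = (_mt[i] & UpperMask) | (_mt[(i + 1) % N] & LowerMask);
            uint next = _mt[(i + M) % N] ^ (y >> 1);
            if ((y & 1u) != 0) next ^= MatrixA;
            _mt[i] = next;
        }
        _index = 0;
    }
}
```

As a cross-language check, the C++ standard requires the 10000th output of a default-constructed `std::mt19937` (seed 5489) to be 4123659995, so a faithful port should reproduce that value.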
Motivation
- Reproducibility across ecosystems: many data science stacks provide MT19937 (e.g., C++ `std::mt19937`, NumPy's legacy `RandomState`), so having the same PRNG available in ML.NET makes it easier to compare experiments across stacks.
- Zero impact by default: default behavior remains unchanged and backward compatible.
- Testability: easier to write bitwise-stable tests that don't depend on `System.Random`, whose implementation is not guaranteed to stay the same across .NET versions.
Design
Add a small interface and a constructor overload on `MLContext` that accepts it:
```csharp
#nullable enable
using System;

// Abstraction over the random source used by MLContext.
public interface IRandomSource
{
    int Next();
    int Next(int maxValue);
    int Next(int minValue, int maxValue);
    long NextInt64();
    long NextInt64(long maxValue);
    long NextInt64(long minValue, long maxValue);
    double NextDouble();
    float NextSingle();
    void NextBytes(Span<byte> buffer);
}

public sealed class MLContext
{
    // Existing constructor remains; it now delegates to the new overload.
    public MLContext(int? seed = null) : this(seed, rng: null) { }

    // New overload: a non-null rng takes precedence and seed is ignored.
    public MLContext(int? seed, IRandomSource? rng)
    {
        _rng = rng ?? new RandomSourceAdapter(seed is null ? Random.Shared : new Random(seed.Value));
        // ... existing initialization
    }

    internal IRandomSource RandomSource => _rng;
    private readonly IRandomSource _rng;
}

// Default implementation: forwards every call to System.Random.
// Random.Shared (.NET 6+) is thread-safe; a seeded Random instance is not.
internal sealed class RandomSourceAdapter : IRandomSource
{
    private readonly Random _rand;
    public RandomSourceAdapter(Random rand) => _rand = rand;
    public int Next() => _rand.Next();
    public int Next(int maxValue) => _rand.Next(maxValue);
    public int Next(int minValue, int maxValue) => _rand.Next(minValue, maxValue);
    public long NextInt64() => _rand.NextInt64();
    public long NextInt64(long maxValue) => _rand.NextInt64(maxValue);
    public long NextInt64(long minValue, long maxValue) => _rand.NextInt64(minValue, maxValue);
    public double NextDouble() => _rand.NextDouble();
    public float NextSingle() => _rand.NextSingle();
    public void NextBytes(Span<byte> buffer) => _rand.NextBytes(buffer);
}
```
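The proposal leaves open how a user would satisfy all nine `IRandomSource` members from a generator that only yields raw 32-bit words, such as an MT19937 port. A minimal sketch, assuming a `Func<uint>` that supplies uniformly distributed 32-bit values: the hypothetical `UInt32RandomSource` class below derives bounded integers by rejection sampling (to avoid modulo bias) and floating-point values from the top bits. It mirrors `System.Random`'s exclusive upper bounds but is not bit-compatible with `System.Random` or with any existing ML.NET code.

```csharp
using System;

// Restated from the proposal so this snippet compiles on its own.
public interface IRandomSource
{
    int Next();
    int Next(int maxValue);
    int Next(int minValue, int maxValue);
    long NextInt64();
    long NextInt64(long maxValue);
    long NextInt64(long minValue, long maxValue);
    double NextDouble();
    float NextSingle();
    void NextBytes(Span<byte> buffer);
}

// Hypothetical helper (not part of ML.NET): builds every IRandomSource
// member from a single user-supplied 32-bit generator.
public sealed class UInt32RandomSource : IRandomSource
{
    private readonly Func<uint> _next32;

    public UInt32RandomSource(Func<uint> next32) =>
        _next32 = next32 ?? throw new ArgumentNullException(nameof(next32));

    // Two 32-bit draws combined into 64 bits, high word first.
    private ulong Next64() => ((ulong)_next32() << 32) | _next32();

    // Unbiased integer in [0, bound) by rejection sampling; bound must be > 0.
    private ulong NextBelow(ulong bound)
    {
        // 2^64 mod bound: rejecting draws below this leaves a range that is
        // an exact multiple of bound.
        ulong threshold = unchecked(0UL - bound) % bound;
        ulong x;
        do { x = Next64(); } while (x < threshold);
        return x % bound;
    }

    public int Next() => (int)NextBelow((ulong)int.MaxValue);

    public int Next(int maxValue)
    {
        if (maxValue < 0) throw new ArgumentOutOfRangeException(nameof(maxValue));
        return maxValue == 0 ? 0 : (int)NextBelow((ulong)maxValue);
    }

    public int Next(int minValue, int maxValue)
    {
        if (minValue > maxValue) throw new ArgumentOutOfRangeException(nameof(minValue));
        long range = (long)maxValue - minValue;
        return range == 0 ? minValue : (int)(minValue + (long)NextBelow((ulong)range));
    }

    public long NextInt64() => (long)NextBelow((ulong)long.MaxValue);

    public long NextInt64(long maxValue)
    {
        if (maxValue < 0) throw new ArgumentOutOfRangeException(nameof(maxValue));
        return maxValue == 0 ? 0 : (long)NextBelow((ulong)maxValue);
    }

    public long NextInt64(long minValue, long maxValue)
    {
        if (minValue > maxValue) throw new ArgumentOutOfRangeException(nameof(minValue));
        ulong range = unchecked((ulong)maxValue - (ulong)minValue);
        return range == 0 ? minValue : unchecked((long)((ulong)minValue + NextBelow(range)));
    }

    // 53 random bits scaled into [0, 1).
    public double NextDouble() => (Next64() >> 11) * (1.0 / (1UL << 53));

    // 24 random bits scaled into [0, 1).
    public float NextSingle() => (_next32() >> 8) * (1.0f / (1 << 24));

    // Fills the buffer four bytes per 32-bit draw, least significant byte first.
    public void NextBytes(Span<byte> buffer)
    {
        int i = 0;
        while (i < buffer.Length)
        {
            uint v = _next32();
            for (int k = 0; k < 4 && i < buffer.Length; k++, i++)
            {
                buffer[i] = (byte)v;
                v >>= 8;
            }
        }
    }
}
```

With the proposed overload, a caller could then pass such a source as `new MLContext(seed: null, rng: new UInt32RandomSource(myGenerator.Next32))`, where `myGenerator.Next32` stands for any hypothetical method returning `uint`. Deriving everything from one 32-bit stream keeps the mapping from raw words to doubles and bounded integers explicit, which is what cross-language reproducibility ultimately depends on.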