In-Browser Machine Learning: Building a Random Forest Classifier in JavaScript
I had an afternoon to myself while my wife and child attended yet another birthday party (seriously, how many friends does a five-year-old need?). With no distractions, I decided to embark on a six-hour coding marathon.
In hindsight, this project was a super fun learning experience in understanding the inner workings of Random Forests and the delightful quirks of PURE JavaScript (COUGH NOT MY MOST EXPERIENCED LANGUAGE AND ITS ECCENTRICITIES DRIVE ME UP A WALL). To be brutally honest, this would not have been possible without the tireless patience of ChatGPT, arguably the world's most competent rubber duck.
Why Did I Choose to Rewrite a Random Forest in JavaScript? Am I Some Kind of Masochist?
Firstly, I'm extremely stressed out. When stress hits, I find solace in tackling complex problems; nothing says relaxation like reinventing the wheel (see ml-random-forest js) in a language that tests my patience.
Secondly, I wanted a challenge. Python is my go-to language for data science, thanks to its rich ecosystem of libraries like scikit-learn, pandas, and numpy. JavaScript, on the other hand, doesn't quite measure up in that department. Sure, it has libraries, but let's be real... they're not the same, are they? Despite this, I love the convenience of JavaScript and Observable for bringing beautiful, interactive content to the web.
Third(ly?), in my data science career, I've noticed that most tasks require simple solutions, and simple models are largely sufficient (#MaintainROI). Running machine learning models directly in the browser without server-side dependencies might not have massive practical applications, unless you're into flashy demos and educational tools. I'd love to be proven wrong!
1. Let's just get to the neato visuals
To test the classifier, I introduced a donut hole classification example... because I was totally eating and drinking me some 'dunks when I decided to embark on this. The goal is to classify points inside a circle (the hole) differently from those in the surrounding donut-shaped area.
1.1 Data Generation
Let's generate some synthetic data. Because I just got Dunkin', let's make this some kind of donut-hole classification. And thank you ChatGPT for whipping this one up in its entirety...
function generateDonutHoleData(nOuter, nInner, innerRadius, outerRadius) {
  const X = [];
  const y = [];
  // Pick a random angle and a random radius between radiusMin and radiusMax.
  function randomPointInCircle(radiusMin, radiusMax) {
    const angle = Math.random() * 2 * Math.PI;
    const radius = radiusMin + Math.random() * (radiusMax - radiusMin);
    const x = radius * Math.cos(angle);
    const y = radius * Math.sin(angle);
    return [x, y];
  }
  // Class 0: the donut ring.
  for (let i = 0; i < nOuter; i++) {
    const [x, yVal] = randomPointInCircle(innerRadius, outerRadius);
    X.push([x, yVal]);
    y.push(0);
  }
  // Class 1: the donut hole.
  for (let i = 0; i < nInner; i++) {
    const [x, yVal] = randomPointInCircle(0, innerRadius);
    X.push([x, yVal]);
    y.push(1);
  }
  return { X, y };
}
This function creates a challenging dataset for the classifier to learn non-linear boundaries.
1.2 Training and Visualization
Donut-hole Parameters:
Random forest classifier parameters:
Ok cool. Now let's project the decision boundary over the 'donut-data' to visualize our trained model.
Here's the trained forest object.
Neat. Now, let's visualize it. Note that trees here are connected to the terminal node. I don't want to mess with that visual (for now), so just like, get over it or whatever and stare at my beautiful FOREST.
Ok. Now let's make some predictions and check the performance.
Should we make an AUC chart? I guess.
The accuracy , recall , and precision .
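For reference, all three numbers fall out of a confusion matrix. The sketch below is not the notebook's own cell: classificationMetrics and its positive argument are hypothetical names, and it assumes class 1 (the donut hole) is the positive class. Labels are compared as strings because predict hands its labels back as strings.

```javascript
// Hypothetical helper, not part of the original class.
// Assumes binary labels; `positive` names the class treated as "positive".
function classificationMetrics(yTrue, yPred, positive = "1") {
  let tp = 0, fp = 0, tn = 0, fn = 0;
  for (let i = 0; i < yTrue.length; i++) {
    const actual = String(yTrue[i]) === positive;
    const predicted = String(yPred[i]) === positive;
    if (actual && predicted) tp++;
    else if (!actual && predicted) fp++;
    else if (actual && !predicted) fn++;
    else tn++;
  }
  // Return 0 instead of NaN when a denominator is empty.
  const safeDiv = (a, b) => (b === 0 ? 0 : a / b);
  return {
    accuracy: safeDiv(tp + tn, yTrue.length),
    precision: safeDiv(tp, tp + fp),
    recall: safeDiv(tp, tp + fn)
  };
}

// Made-up labels: 2 true positives, 1 false negative, 1 true negative, 1 false positive.
const metrics = classificationMetrics([1, 1, 0, 0, 1], ["1", "0", "0", "1", "1"]);
console.log(metrics); // accuracy 0.6, precision and recall both 2/3
```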
2. How Does Each Part of This Function Work?
Let's dive into the RandomForestClassifier class. I'll walk you through its methods in the order they are executed during training and prediction, explaining how they fit together, and sprinkle in some commentary to keep things interesting.
2.1 Constructor
The journey begins with initializing the Random Forest classifier. The constructor sets up the model with user-defined parameters, controlling the complexity and randomness of the forest. This setup allows us to balance the bias-variance trade-off, eventually finding the sweet spot between underfitting and overfitting.
constructor({ nEstimators = 10, maxDepth = 5, minSize = 1, sampleSize = 1.0,
              maxFeatures = null, decimalPrecision = 2 } = {}) {
  this.nEstimators = nEstimators;
  this.maxDepth = maxDepth;
  this.minSize = minSize;
  this.sampleSize = sampleSize;
  this.maxFeatures = maxFeatures;
  this.decimalPrecision = decimalPrecision;
  this.trees = [];
  this.classLabels = []; // Keep track of all class labels
}
Parameters:
nEstimators: Number of trees in the forest. Because more trees equal more fun, right?
maxDepth: Maximum depth of each tree. Let's not grow infinitely tall trees; we're not building skyscrapers.
minSize: Group size at or below which a node stops splitting and becomes a leaf. Sometimes, you just have to know when to stop splitting hairs.
sampleSize: Proportion of the dataset to use for each tree (with replacement). Variety is the spice of life (and machine learning).
maxFeatures: Number of features to consider when looking for the best split. If null, it defaults to the square root of the number of features (rounded down), because that seems entirely not a made-up thing to do because one guy/girl did it once and it worked.
decimalPrecision: Number of decimal places used when rounding numerical values. Because floating-point errors are the stuff of nightmares.
These settings are crucial for controlling overfitting and ensuring that each tree in the forest is sufficiently unique.
2.2 Fitting the Model
Training the Random Forest involves building multiple decision trees. The fit method is the entry point for this process.
fit(X, y) {
  // Append each label to its feature row so a row carries its own target.
  const dataset = X.map((row, idx) => [...row, y[idx]]);
  this.classLabels = [...new Set(y.map(label => String(label)))]; // Store class labels as strings
  this.trees = []; // Start fresh so calling fit twice doesn't stack old trees on new ones
  for (let i = 0; i < this.nEstimators; i++) {
    const sample = this.subsample(dataset, this.sampleSize);
    const tree = this.buildTree(sample, this.maxDepth, this.minSize);
    this.trees.push(tree);
  }
}
Parameters:
X: Feature matrix, an array of input vectors.
y: Target values corresponding to each input vector in X.
First, it combines the features X and labels y into a single dataset. Each data point becomes an array of features followed by its label. Then, it stores all unique class labels (as strings) in this.classLabels for later use. For each tree, it generates a bootstrap sample using the subsample method. This sample is used to build a tree via the buildTree method, and the resulting tree is added to the forest.
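To make the data preparation concrete, here is what those first two lines of fit produce for a tiny, made-up input (the three points are purely illustrative):

```javascript
// Three made-up points: two in the hole (label 1), one in the ring (label 0).
const X = [[0.2, 0.1], [1.5, -0.3], [0.0, 0.4]];
const y = [1, 0, 1];

// Each row becomes [feature0, feature1, label].
const dataset = X.map((row, idx) => [...row, y[idx]]);

// A Set keeps first-seen order, so the labels come out as "1" then "0".
const classLabels = [...new Set(y.map(label => String(label)))];

console.log(dataset);     // [[0.2, 0.1, 1], [1.5, -0.3, 0], [0, 0.4, 1]]
console.log(classLabels); // ["1", "0"]
```

Note the mix: the rows keep their numeric labels, while classLabels holds strings.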
2.2.1 subsample(dataset, ratio)
Generates a bootstrap sample of the dataset with replacement. Each tree is trained on a random subset of the data, which helps create diverse trees and reduces overfitting.
subsample(dataset, ratio) {
  const nSample = Math.round(dataset.length * ratio);
  const sample = [];
  for (let i = 0; i < nSample; i++) {
    // Draw with replacement: the same row can be picked more than once.
    const index = Math.floor(Math.random() * dataset.length);
    sample.push(dataset[index]);
  }
  return sample;
}
Parameters:
dataset: The original dataset from which to sample.
ratio: The proportion of the dataset to include in the sample.
This method randomly selects data points from the dataset to create a sample of a specified size (ratio of the original dataset). Because it samples with replacement, some data points may appear multiple times, while others may be omitted. This randomness is key to building uncorrelated trees in the forest.
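How many points get omitted? A quick experiment gives a feel for it. This is a standalone sketch that samples indices rather than rows; bootstrapIndices is a hypothetical helper, not a method of the class.

```javascript
// Hypothetical helper: draw round(n * ratio) indices from 0..n-1 with replacement.
function bootstrapIndices(n, ratio = 1.0) {
  const nSample = Math.round(n * ratio);
  const indices = [];
  for (let i = 0; i < nSample; i++) {
    indices.push(Math.floor(Math.random() * n));
  }
  return indices;
}

const n = 10000;
const drawn = bootstrapIndices(n);
const uniqueFraction = new Set(drawn).size / n;

// With ratio = 1, the expected share of distinct rows approaches 1 - 1/e (about 0.632).
console.log(uniqueFraction.toFixed(3));
```

So with sampleSize at 1.0, each tree sees roughly 63% of the distinct rows and never sees the rest, which is where the variety between trees comes from.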
2.3 Building Trees
The buildTree method initiates the construction of a decision tree. It starts by finding the best split for the root node and then recursively splits child nodes.
buildTree(train, maxDepth, minSize) {
  const root = this.getSplit(train);
  this.split(root, maxDepth, minSize, 1); // The root sits at depth 1
  return root;
}
Parameters:
train: The training dataset to build the tree.
maxDepth: Maximum depth of the tree.
minSize: Group size at or below which a node becomes a leaf instead of being split.
It uses the getSplit method to determine the optimal feature and value to split the data at the root node. Then, it calls the split method to recursively build the tree from there.
2.4 Finding the Best Split
The getSplit method determines the best feature and value to split the dataset on, minimizing Gini impurity.
getSplit(dataset) {
  const classValues = [...new Set(dataset.map(row => row[row.length - 1]))];
  let bestIndex, bestValue, bestScore = Infinity, bestGroups;
  const nFeatures = dataset[0].length - 1;
  const features = this.getRandomFeatures(nFeatures);
  // Try every observed value of every candidate feature as a threshold.
  for (const index of features) {
    for (const row of dataset) {
      const groups = this.testSplit(index, row[index], dataset);
      const gini = this.giniImpurity(groups, classValues);
      if (gini < bestScore) {
        bestIndex = index;
        bestValue = row[index];
        bestScore = gini;
        bestGroups = groups;
      }
    }
  }
  if (bestGroups === undefined) {
    return this.toTerminal(dataset);
  }
  return {
    index: bestIndex,
    value: this.round(bestValue),
    gini: bestScore,
    groups: bestGroups
  };
}
Parameters:
dataset: The dataset to find the best split for.
By considering a random subset of features (thanks to getRandomFeatures), it ensures each tree is a unique snowflake. The goal is to find the most "pure" split, reducing the impurity like a water filter for your data. It iterates over possible splits, using testSplit to divide the dataset and giniImpurity to evaluate the quality of each split. Like some sort of dystopian future I fear we're headed towards, this function tracks and returns the split with the lowest Gini impurity (which is traditionally called 'best', but that's not really the right way to think about it... IS IT?).
2.4.1 getRandomFeatures(nFeatures)
Selects a random subset of features to consider at each split.
getRandomFeatures(nFeatures) {
  let maxFeatures = this.maxFeatures;
  if (!maxFeatures) {
    maxFeatures = Math.max(1, Math.floor(Math.sqrt(nFeatures)));
  }
  maxFeatures = Math.min(maxFeatures, nFeatures); // Ensure we don't select more features than available
  const features = [];
  // Keep drawing until we have maxFeatures distinct indices.
  while (features.length < maxFeatures) {
    const index = Math.floor(Math.random() * nFeatures);
    if (!features.includes(index)) {
      features.push(index);
    }
  }
  return features;
}
Parameters:
nFeatures: Total number of features in the dataset.
This introduces randomness into the model, which is crucial for the diversity of the trees in the forest. It helps reduce correlation among trees, thereby improving overall performance. Think of it as diversifying your investment portfolio but with data features.
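The draw-and-retry loop is fine for a handful of features. If the feature count ever got large, one alternative (my assumption, not what the class does) is a partial Fisher-Yates shuffle, which picks k distinct indices in exactly k swaps. sampleFeatureIndices is a hypothetical name.

```javascript
// Hypothetical alternative to the rejection loop: partial Fisher-Yates shuffle.
function sampleFeatureIndices(nFeatures, k, random = Math.random) {
  const pool = Array.from({ length: nFeatures }, (_, i) => i);
  const count = Math.min(k, nFeatures);
  for (let i = 0; i < count; i++) {
    // Swap position i with a random position from i to the end of the pool.
    const j = i + Math.floor(random() * (nFeatures - i));
    [pool[i], pool[j]] = [pool[j], pool[i]];
  }
  return pool.slice(0, count);
}

// Nine features with the square-root default gives three features per split.
const picked = sampleFeatureIndices(9, Math.floor(Math.sqrt(9)));
console.log(picked); // three distinct indices between 0 and 8
```

Worth knowing for the donut demo: with only two features, the default works out to Math.floor(Math.sqrt(2)) = 1, so every split looks at a single randomly chosen coordinate.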
2.4.2 testSplit(index, value, dataset)
Splits the dataset into two groups based on the specified feature index and value.
testSplit(index, value, dataset) {
  const left = [], right = [];
  for (const row of dataset) {
    // Strictly less goes left; equal or greater goes right.
    if (row[index] < value) {
      left.push(row);
    } else {
      right.push(row);
    }
  }
  return [left, right];
}
Parameters:
index: The index of the feature to split on.
value: The value of the feature to split at.
dataset: The dataset to split.
This function divides the dataset into two groups: those that meet the split condition and those that don't. It's like dividing your Halloween candy stash into piles to maximize happiness...except with data and probably less eventual diabetes.
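The comparison is a strict less-than, which has two consequences worth seeing on a toy dataset (the rows below are made up): rows equal to the threshold go right, and splitting at the smallest value leaves the left pile empty.

```javascript
// Same logic as the method above, written as a standalone function for the demo.
function testSplit(index, value, dataset) {
  const left = [], right = [];
  for (const row of dataset) {
    if (row[index] < value) left.push(row);
    else right.push(row);
  }
  return [left, right];
}

// Rows are [feature, label].
const rows = [[1, 0], [2, 0], [2, 1], [3, 1]];

const [left, right] = testSplit(0, 2, rows);
console.log(left);  // [[1, 0]]
console.log(right); // [[2, 0], [2, 1], [3, 1]]

// Splitting at the minimum value puts everything on the right.
const [emptyLeft, everything] = testSplit(0, 1, rows);
console.log(emptyLeft.length, everything.length); // 0 4
```

That empty-left case is exactly what the no-split check in split has to handle later.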
2.4.3 giniImpurity(groups, classes)
Computes the Gini impurity for a split: how often a randomly chosen element would be mislabeled if it were labeled at random according to the class mix of its group, weighted by group size.
giniImpurity(groups, classes) {
  const nInstances = groups.reduce((sum, group) => sum + group.length, 0);
  let gini = 0.0;
  for (const group of groups) {
    const size = group.length;
    if (size === 0) continue; // Avoid dividing by zero.
    const classCounts = {};
    for (const row of group) {
      const classVal = row[row.length - 1];
      classCounts[classVal] = (classCounts[classVal] || 0) + 1;
    }
    // Sum of squared class proportions within this group.
    let score = 0.0;
    for (const classVal of classes) {
      const proportion = (classCounts[classVal] || 0) / size;
      score += proportion * proportion;
    }
    // Weight the group's impurity by its share of the rows.
    gini += (1.0 - score) * (size / nInstances);
  }
  return this.round(gini);
}
Parameters:
groups: An array containing the left and right groups after a split.
classes: A list of unique class values in the dataset.
This method calculates the impurity of the groups created by a split. A lower Gini impurity indicates a better split. In other words, we're trying to make each node as "pure" as possible, much like trying to keep toddlers from sticking their fingers in their mouths, getting sick, destroying everyone's ability to rest and recover, only to eventually recover and do the same thing again literally the next day. Futile.
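A worked example makes the arithmetic less abstract. The sketch below repeats the calculation without the rounding step; weightedGini is a hypothetical standalone version and the rows are made up.

```javascript
// Hypothetical standalone version of the calculation, without rounding.
function weightedGini(groups) {
  const total = groups.reduce((sum, g) => sum + g.length, 0);
  let gini = 0;
  for (const group of groups) {
    if (group.length === 0) continue;
    const counts = {};
    for (const row of group) {
      const label = row[row.length - 1];
      counts[label] = (counts[label] || 0) + 1;
    }
    let sumSquares = 0;
    for (const c of Object.values(counts)) {
      sumSquares += (c / group.length) ** 2;
    }
    gini += (1 - sumSquares) * (group.length / total);
  }
  return gini;
}

// Rows are [feature, label].
const pureLeft = [[0.1, 0], [0.2, 0]];                       // impurity 0
const mixedRight = [[0.6, 0], [0.7, 1], [0.8, 1], [0.9, 1]]; // 1 - (0.25^2 + 0.75^2) = 0.375

// Weighted: 0 * (2/6) + 0.375 * (4/6), which is about 0.25
const mixedScore = weightedGini([pureLeft, mixedRight]);
console.log(mixedScore);

// A perfect split scores 0; a coin-flip split of two classes scores 0.5.
const perfectScore = weightedGini([[[0.1, 0]], [[0.9, 1]]]);
const coinFlipScore = weightedGini([[[0.1, 0], [0.2, 1]], [[0.8, 0], [0.9, 1]]]);
console.log(perfectScore, coinFlipScore); // 0 0.5
```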
2.5 Recursive Splitting
The split method recursively divides the dataset into smaller subsets, building the tree structure. Think of it as a Russian nesting doll, but it maxes out at the depth provided in the constructor to prevent infinite recursion.
split(node, maxDepth, minSize, depth) {
  if (node.isTerminal) {
    return; // We've reached a leaf node.
  }
  const [left, right] = node.groups;
  delete node.groups; // Remove groups to free up memory.
  // Check for a no-split condition
  if (!left.length || !right.length) {
    node.left = node.right = this.toTerminal(left.concat(right));
    return;
  }
  // Max depth reached
  if (depth >= maxDepth) {
    node.left = this.toTerminal(left);
    node.right = this.toTerminal(right);
    return;
  }
  // Left child
  if (left.length <= minSize) {
    node.left = this.toTerminal(left);
  } else {
    node.left = this.getSplit(left);
    this.split(node.left, maxDepth, minSize, depth + 1);
  }
  // Right child
  if (right.length <= minSize) {
    node.right = this.toTerminal(right);
  } else {
    node.right = this.getSplit(right);
    this.split(node.right, maxDepth, minSize, depth + 1);
  }
}
Parameters:
node: The current node to split.
maxDepth: Maximum depth of the tree.
minSize: Group size at or below which a child becomes a leaf instead of being split again.
depth: Current depth of the tree.
This method ensures that the tree doesn't grow indefinitely. Each recursive call increases the depth by one, and when it reaches maxDepth, the recursion stops. Additionally, if a child group has minSize or fewer samples, it becomes a terminal node. This helps prevent overfitting and keeps the tree manageable. It's like deciding that further debate is pointless and settling on an answer, Dad.
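A quick sanity check on the depth limit: since buildTree starts the recursion at depth 1, a finished tree should never have more than maxDepth levels of decision nodes. The helper below, treeDepth, is hypothetical (not part of the class) and the tree is hand-built in the same node shape the class produces.

```javascript
// Hypothetical helper: count the levels of decision nodes on the longest path.
function treeDepth(node) {
  if (node.isTerminal) return 0;
  return 1 + Math.max(treeDepth(node.left), treeDepth(node.right));
}

// Hand-built example: a root split whose right child splits once more.
const tree = {
  index: 0, value: 0.5,
  left: { isTerminal: true, value: 1 },
  right: {
    index: 1, value: -0.2,
    left: { isTerminal: true, value: 0 },
    right: { isTerminal: true, value: 1 }
  }
};

console.log(treeDepth(tree)); // 2
```

Running it over a trained forest and comparing against maxDepth is a cheap way to confirm the stopping rules are doing their job.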
2.5.1 toTerminal(group)
Creates a terminal node (leaf) by assigning the most common class in the group.
toTerminal(group) {
  const outcomes = group.map(row => row[row.length - 1]);
  const counts = {};
  let maxCount = 0;
  let prediction;
  // Tally labels and remember the most frequent one seen so far.
  for (const value of outcomes) {
    counts[value] = (counts[value] || 0) + 1;
    if (counts[value] > maxCount) {
      maxCount = counts[value];
      prediction = value;
    }
  }
  return { isTerminal: true, value: prediction };
}
Parameters:
group: The dataset at the current node.
This method determines the class that appears most frequently in the group and creates a terminal node with that prediction.
2.6 Making Predictions
After training, we use the model to make predictions on new data. The predict method aggregates predictions from all trees.
predict(X, plotPath = false) {
  return X.map(row => {
    // One vote per tree.
    const predictions = this.trees.map(tree => this.predictTree(tree, row, plotPath));
    const counts = {};
    for (const pred of predictions) {
      counts[pred] = (counts[pred] || 0) + 1;
    }
    let maxCount = 0;
    let majorityClass = null;
    for (const [key, count] of Object.entries(counts)) {
      if (count > maxCount) {
        maxCount = count;
        majorityClass = key;
      }
    }
    return majorityClass;
  });
}
Parameters:
X: Feature matrix of input data to predict.
plotPath (optional): If true, logs the path taken through the tree for each prediction.
For each data point, it collects predictions from all trees using the predictTree method. It then holds an arboreal majority vote: the most common prediction among the trees becomes the final prediction.
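One JavaScript quirk hides in that vote. The tally lives in a plain object, and object keys are strings, so the winning label comes back as "1" rather than 1 even though the training labels were numbers. The standalone sketch below reproduces just the voting step; majorityVote is a hypothetical name.

```javascript
// Standalone copy of the voting step (hypothetical helper name).
function majorityVote(predictions) {
  const counts = {};
  for (const pred of predictions) {
    counts[pred] = (counts[pred] || 0) + 1;
  }
  let maxCount = 0;
  let majorityClass = null;
  for (const [key, count] of Object.entries(counts)) {
    if (count > maxCount) {
      maxCount = count;
      majorityClass = key;
    }
  }
  return majorityClass;
}

// Five trees vote with numeric labels.
const winner = majorityVote([1, 0, 1, 1, 0]);
console.log(typeof winner);        // "string"
console.log(winner === 1);         // false
console.log(Number(winner) === 1); // true
```

Anything that compares predictions against the original numeric labels needs to convert one side first, or every comparison will come out false.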
2.6.1 predictTree(node, row, plotPath = false, depth = 0)
Traverses a single tree to make a prediction for a given data point.
predictTree(node, row, plotPath = false, depth = 0) {
  if (node.isTerminal) {
    if (plotPath) console.log(`${'| '.repeat(depth)}Leaf: Predict ${node.value}`);
    return node.value;
  }
  if (plotPath) {
    console.log(`${'| '.repeat(depth)}Node: X${node.index} < ${node.value} (Gini: ${node.gini})`);
  }
  // Same rule as testSplit: strictly less goes left.
  if (row[node.index] < node.value) {
    return this.predictTree(node.left, row, plotPath, depth + 1);
  } else {
    return this.predictTree(node.right, row, plotPath, depth + 1);
  }
}
Parameters:
node: The current node in the tree.
row: The data point to predict.
plotPath (optional): If true, logs the path taken through the tree.
depth: Current depth in the tree (used for logging purposes).
This function starts at the root and moves left or right based on the feature value until it reaches a leaf node. If plotPath is true, it logs the path taken through the tree.
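If the recursion is the part giving you a headache, the same walk can be written as a loop. This is an equivalent sketch without the logging, not the class's method; predictIterative is a hypothetical name and the tree is hand-built.

```javascript
// Hypothetical loop-based version of the traversal (no path logging).
function predictIterative(root, row) {
  let node = root;
  while (!node.isTerminal) {
    node = row[node.index] < node.value ? node.left : node.right;
  }
  return node.value;
}

// Hand-built tree: split on X0 < 1.0, then on X1 < 0 for the right branch.
const toyTree = {
  index: 0, value: 1.0,
  left: { isTerminal: true, value: 1 },
  right: {
    index: 1, value: 0,
    left: { isTerminal: true, value: 0 },
    right: { isTerminal: true, value: 1 }
  }
};

console.log(predictIterative(toyTree, [0.3, 0.2]));  // 1 (X0 < 1.0, go left)
console.log(predictIterative(toyTree, [1.0, -0.5])); // 0 (X0 equals the threshold, go right; X1 < 0, go left)
console.log(predictIterative(toyTree, [2.0, 0.5]));  // 1 (right, then right)
```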
2.7 Predicting Probabilities
To compute metrics like the ROC curve and AUC, we need probability estimates. The predictProba method provides class probability estimates based on the proportion of trees predicting each class.
predictProba(X) {
  return X.map(row => {
    const predictions = this.trees.map(tree => this.predictTree(tree, row));
    const counts = {};
    for (const pred of predictions) {
      const classLabel = String(pred); // Ensure the label is a string
      counts[classLabel] = (counts[classLabel] || 0) + 1;
    }
    const total = this.trees.length;
    const probabilities = {};
    // Report every known class, including those no tree voted for.
    for (const classLabel of this.classLabels) {
      probabilities[classLabel] = counts[classLabel] ? counts[classLabel] / total : 0;
    }
    return probabilities;
  });
}
This method ensures that all possible class labels are included in the probability output, even if their probability is zero.
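One way to turn predictProba output into an AUC without drawing the whole ROC curve is the pairwise-ranking definition: AUC is the fraction of (positive, negative) pairs in which the positive gets the higher score, with ties counting as half. This is a sketch, not the notebook's chart code; rocAuc is a hypothetical helper and the probabilities are made up.

```javascript
// Hypothetical helper: AUC via pairwise comparison of positive and negative scores.
function rocAuc(yTrue, scores, positive = "1") {
  const pos = [];
  const neg = [];
  yTrue.forEach((label, i) => {
    (String(label) === positive ? pos : neg).push(scores[i]);
  });
  if (pos.length === 0 || neg.length === 0) return NaN; // AUC is undefined with one class
  let wins = 0;
  for (const p of pos) {
    for (const n of neg) {
      if (p > n) wins += 1;
      else if (p === n) wins += 0.5; // ties count as half
    }
  }
  return wins / (pos.length * neg.length);
}

// Made-up output in the shape predictProba returns: one object per row.
const proba = [
  { "0": 0.1, "1": 0.9 },
  { "0": 0.6, "1": 0.4 },
  { "0": 0.4, "1": 0.6 },
  { "0": 0.2, "1": 0.8 }
];
const yTrue = [1, 0, 1, 0];

const auc = rocAuc(yTrue, proba.map(p => p["1"]));
console.log(auc); // 0.75 (3 of the 4 positive/negative pairs are ranked correctly)
```

With the default nEstimators of 10, every probability is a multiple of 0.1, so ties are common and the half-credit rule matters. The double loop is quadratic, which is fine for a demo-sized dataset.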
2.8 Helper Methods
2.8.1 round(value)
Rounds a number to the specified decimal precision. Ensures consistency and readability in numerical outputs.
round(value) {
  // toFixed returns a string, so parse it back into a number.
  return parseFloat(value.toFixed(this.decimalPrecision));
}
Parameters:
value: The numerical value to round.
This method keeps numerical values neat and tidy. It doesn't make floating-point error go away, but it does hide the infamous JavaScript floating-point shenanigans that can turn your elegant algorithm into a dumpster fire.
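Here is what that looks like in practice, using a standalone copy of the helper with the precision passed as an argument (that signature is my assumption for the demo; the class reads it from this.decimalPrecision).

```javascript
// Standalone copy for illustration; the precision is an explicit argument here.
function round(value, decimalPrecision = 2) {
  return parseFloat(value.toFixed(decimalPrecision));
}

console.log(0.1 + 0.2);        // 0.30000000000000004
console.log(round(0.1 + 0.2)); // 0.3

// Side effect: values that differ only past the second decimal become equal.
console.log(round(0.331) === round(0.334)); // true
```

That last line matters because giniImpurity returns a rounded score. With the default precision of 2, candidate splits whose impurity differs only in the third decimal tie, and getSplit keeps whichever one it saw first.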
2.9 Model Persistence
2.9.1 saveModel()
Serializes the model to a JSON string for saving.
saveModel() {
  return JSON.stringify({
    nEstimators: this.nEstimators,
    maxDepth: this.maxDepth,
    minSize: this.minSize,
    sampleSize: this.sampleSize,
    decimalPrecision: this.decimalPrecision,
    maxFeatures: this.maxFeatures,
    classLabels: this.classLabels, // predictProba needs these after a reload
    trees: this.trees
  });
}
This allows the trained model to be saved and reloaded later without retraining.
2.9.2 loadModel(modelJson)
Loads the model from a JSON string.
loadModel(modelJson) {
  const model = JSON.parse(modelJson);
  this.nEstimators = model.nEstimators;
  this.maxDepth = model.maxDepth;
  this.minSize = model.minSize;
  this.sampleSize = model.sampleSize;
  this.decimalPrecision = model.decimalPrecision;
  this.maxFeatures = model.maxFeatures;
  this.classLabels = model.classLabels || []; // Older saves may not include labels
  this.trees = model.trees;
}
Parameters:
modelJson: A JSON string representing the saved model.
Restores the model state, making it ready for predictions without retraining.
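Because a trained tree is just nested plain objects, it survives a trip through JSON unchanged. The sketch below checks that on a hand-built, single-tree state. The property names mirror what saveModel writes, plus classLabels, which predictProba reads and which therefore needs to survive the round trip too.

```javascript
// Hand-built stand-in for a trained model's state (values are made up).
const savedState = {
  nEstimators: 1,
  maxDepth: 5,
  classLabels: ["0", "1"],
  trees: [{
    index: 0, value: 0.5, gini: 0.25,
    left: { isTerminal: true, value: 1 },
    right: { isTerminal: true, value: 0 }
  }]
};

const json = JSON.stringify(savedState);
const restored = JSON.parse(json);

// Serializing the restored object gives back the same string.
console.log(JSON.stringify(restored) === json); // true
console.log(restored.trees[0].left.value);      // 1
console.log(restored.classLabels);              // ["0", "1"]
```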
2.10 Visualization and Debugging Methods
2.10.1 printTree(node, depth = 0)
Prints the structure of a tree.
printTree(node, depth = 0) {
  if (node.isTerminal) {
    console.log(`${'| '.repeat(depth)}[Leaf] Predict: ${node.value}`);
  } else {
    console.log(`${'| '.repeat(depth)}[X${node.index} < ${node.value}]`);
    this.printTree(node.left, depth + 1);
    this.printTree(node.right, depth + 1);
  }
}
Useful for understanding how the tree makes decisions and for debugging purposes.
2.10.2 convertToHierarchy(node, depth = 0)
Converts a tree into a hierarchical structure for visualization.
convertToHierarchy(node, depth = 0) {
  if (node.isTerminal) {
    return { name: `Leaf: ${node.value}` };
  }
  return {
    name: `X${node.index} < ${node.value}`,
    children: [
      this.convertToHierarchy(node.left, depth + 1),
      this.convertToHierarchy(node.right, depth + 1)
    ]
  };
}
This can be used with visualization libraries to create tree diagrams.
2.10.3 convertForestToHierarchy()
Converts the entire forest into a hierarchical structure.
convertForestToHierarchy() {
  return this.trees.map((tree, index) => ({
    name: `Tree ${index + 1}`,
    children: [this.convertToHierarchy(tree)]
  }));
}
Allows visualization of all trees in the forest, helping to bring order to the complexity.
2.10.4 getPaths(node, path = "")
Retrieves all decision paths in a tree.
getPaths(node, path = "") {
  if (node.isTerminal) {
    return [`${path}/Leaf: ${node.value}`];
  }
  const leftPaths = this.getPaths(node.left, `${path}/X${node.index} < ${node.value}`);
  const rightPaths = this.getPaths(node.right, `${path}/X${node.index} >= ${node.value}`);
  return [...leftPaths, ...rightPaths];
}
Provides a detailed view of all possible paths through the tree.
2.10.5 convertForestToPaths()
Retrieves decision paths from all trees in the forest.
convertForestToPaths() {
  let paths = [];
  this.trees.forEach((tree, index) => {
    const treePaths = this.getPaths(tree, `Tree ${index + 1}`);
    paths = [...paths, ...treePaths];
  });
  return paths;
}
Gives a comprehensive view of how the forest makes predictions.
2.11 Classification Domain Generation
Generates a grid of predictions over the feature space for visualization.
generateClassificationDomain(scatterData, stepSize = 0.1) {
  const xMin = Math.min(...scatterData.map(d => d.x)) - 1;
  const xMax = Math.max(...scatterData.map(d => d.x)) + 1;
  const yMin = Math.min(...scatterData.map(d => d.y)) - 1;
  const yMax = Math.max(...scatterData.map(d => d.y)) + 1;
  // Fix the grid size up front and loop over integer indices, so the counts
  // can't drift from accumulated floating-point steps.
  const gridWidth = Math.round((xMax - xMin) / stepSize) + 1;
  const gridHeight = Math.round((yMax - yMin) / stepSize) + 1;
  const grid = [];
  const values = [];
  // Row-major order: each run of gridWidth entries shares one y value.
  for (let j = 0; j < gridHeight; j++) {
    const y = yMin + j * stepSize;
    for (let i = 0; i < gridWidth; i++) {
      const x = xMin + i * stepSize;
      grid.push([x, y]);
      values.push(this.predict([[x, y]])[0]);
    }
  }
  const grid2D = Array.from({ length: gridHeight }, (_, i) => grid.slice(i * gridWidth, (i + 1) * gridWidth));
  const values2D = Array.from({ length: gridHeight }, (_, i) => values.slice(i * gridWidth, (i + 1) * gridWidth));
  return {
    gridPredictions: values,
    gridWidth: gridWidth,
    gridHeight: gridHeight,
    grid2D: grid2D,
    values2D: values2D
  };
}
Parameters:
scatterData: An array of data points with x and y properties.
stepSize: The resolution of the grid.
This method helps in visualizing the decision boundaries learned by the model. It creates a grid over the feature space and predicts the class for each point in the grid. Predictions are stored row by row (y in the outer loop, x in the inner loop), so gridWidth and gridHeight describe the layout of the flat gridPredictions array.
3. What's Next?
I absolutely love decision trees. They're interpretable, intuitive, and just plain cool.
I could implement a boosting approach, but I think I'll dip my toes into some neural networks. Besides, the recursive insanity with decision trees has given me an insane headache.