Decision Tree
A tree-based model that makes decisions through a sequence of if-else questions, splitting the data on feature values to build a hierarchical structure. Decision trees are interpretable and handle non-linear relationships.
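For intuition, a fitted tree is nothing more than nested if-else tests on feature values. A hand-written sketch (the features and thresholds here are made up for illustration, not learned):

```typescript
// Hypothetical two-feature tree deciding whether to play tennis:
// humidity in %, wind speed in km/h. Thresholds are illustrative only.
function predictPlay(humidity: number, wind: number): string {
  if (humidity <= 75) {
    // Low humidity: the decision depends on a second test.
    return wind <= 20 ? "play" : "don't play";
  }
  // High humidity: this branch is already a leaf.
  return "don't play";
}

console.log(predictPlay(60, 10)); // "play"
console.log(predictPlay(90, 10)); // "don't play"
```

Training a tree means discovering these tests and thresholds from data, which is what the implementation below does.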
Implementation
class DecisionTree {
  private root: TreeNode | null = null;
  private maxDepth: number;
  private minSamplesSplit: number;

  constructor(maxDepth: number = 10, minSamplesSplit: number = 2) {
    this.maxDepth = maxDepth;
    this.minSamplesSplit = minSamplesSplit;
  }

  // Gini impurity: 1 - sum of squared class probabilities.
  private calculateGini(y: number[]): number {
    const counts = new Map<number, number>();
    for (const label of y) {
      counts.set(label, (counts.get(label) || 0) + 1);
    }

    let gini = 1;
    const total = y.length;
    for (const count of counts.values()) {
      const p = count / total;
      gini -= p * p;
    }
    return gini;
  }

  // Best threshold for one feature: try midpoints between consecutive
  // unique values and keep the one with the highest impurity reduction.
  private findBestSplit(X: number[][], y: number[], featureIdx: number): {
    threshold: number;
    gain: number;
  } {
    const values = X.map(row => row[featureIdx]).sort((a, b) => a - b);
    const uniqueValues = [...new Set(values)];

    let bestGain = -Infinity;
    let bestThreshold = 0;

    const parentGini = this.calculateGini(y);

    for (let i = 0; i < uniqueValues.length - 1; i++) {
      const threshold = (uniqueValues[i] + uniqueValues[i + 1]) / 2;

      const leftY: number[] = [];
      const rightY: number[] = [];

      for (let j = 0; j < X.length; j++) {
        if (X[j][featureIdx] <= threshold) {
          leftY.push(y[j]);
        } else {
          rightY.push(y[j]);
        }
      }

      if (leftY.length === 0 || rightY.length === 0) continue;

      // Gain = parent impurity minus the size-weighted child impurity.
      const leftGini = this.calculateGini(leftY);
      const rightGini = this.calculateGini(rightY);
      const weightedGini = (leftY.length * leftGini + rightY.length * rightGini) / y.length;
      const gain = parentGini - weightedGini;

      if (gain > bestGain) {
        bestGain = gain;
        bestThreshold = threshold;
      }
    }

    return { threshold: bestThreshold, gain: bestGain };
  }

  fit(X: number[][], y: number[]): void {
    this.root = this.buildTree(X, y, 0);
  }

  private buildTree(X: number[][], y: number[], depth: number): TreeNode {
    // Stopping criteria: max depth reached, node too small, or node pure.
    if (depth >= this.maxDepth || y.length < this.minSamplesSplit || new Set(y).size === 1) {
      return new TreeNode(this.majorityClass(y));
    }

    // Find the best (feature, threshold) pair across all features.
    let bestFeature = 0;
    let bestSplit = { threshold: 0, gain: -Infinity };

    for (let f = 0; f < X[0].length; f++) {
      const split = this.findBestSplit(X, y, f);
      if (split.gain > bestSplit.gain) {
        bestSplit = split;
        bestFeature = f;
      }
    }

    // No split improves impurity: make this node a leaf.
    if (bestSplit.gain <= 0) {
      return new TreeNode(this.majorityClass(y));
    }

    // Partition the data and recurse into both children.
    const leftX: number[][] = [], leftY: number[] = [];
    const rightX: number[][] = [], rightY: number[] = [];

    for (let i = 0; i < X.length; i++) {
      if (X[i][bestFeature] <= bestSplit.threshold) {
        leftX.push(X[i]);
        leftY.push(y[i]);
      } else {
        rightX.push(X[i]);
        rightY.push(y[i]);
      }
    }

    const node = new TreeNode();
    node.feature = bestFeature;
    node.threshold = bestSplit.threshold;
    node.left = this.buildTree(leftX, leftY, depth + 1);
    node.right = this.buildTree(rightX, rightY, depth + 1);

    return node;
  }

  private majorityClass(y: number[]): number {
    const counts = new Map<number, number>();
    for (const label of y) {
      counts.set(label, (counts.get(label) || 0) + 1);
    }
    let maxCount = 0, majority = 0;
    for (const [label, count] of counts) {
      if (count > maxCount) {
        maxCount = count;
        majority = label;
      }
    }
    return majority;
  }

  predict(X: number[][]): number[] {
    return X.map(x => this.predictOne(x));
  }

  // Walk from the root to a leaf, following the threshold tests.
  private predictOne(x: number[]): number {
    let node = this.root;
    if (!node) {
      throw new Error('predict called before fit');
    }
    while (!node.isLeaf()) {
      node = x[node.feature!] <= node.threshold! ? node.left! : node.right!;
    }
    return node.value!;
  }
}

class TreeNode {
  value?: number;
  feature?: number;
  threshold?: number;
  left: TreeNode | null = null;
  right: TreeNode | null = null;

  constructor(value?: number) {
    this.value = value;
  }

  isLeaf(): boolean {
    return this.left === null && this.right === null;
  }
}
Deep Dive
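Before going deeper, the leaf-walk in predictOne can be checked in isolation on a hand-built stump (a minimal structural mirror of TreeNode; the threshold and class labels are illustrative):

```typescript
interface Node {
  feature?: number;
  threshold?: number;
  value?: number;
  left?: Node;
  right?: Node;
}

// Hand-built one-split tree: test feature 0 against threshold 5.
const root: Node = {
  feature: 0,
  threshold: 5,
  left: { value: 0 },  // x[0] <= 5 -> class 0
  right: { value: 1 }, // x[0] >  5 -> class 1
};

// Same loop as predictOne: descend until a node with no children.
function predictOne(node: Node, x: number[]): number {
  while (node.left && node.right) {
    node = x[node.feature!] <= node.threshold! ? node.left : node.right;
  }
  return node.value!;
}

console.log(predictOne(root, [3])); // 0
console.log(predictOne(root, [8])); // 1
```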
Theoretical Foundation
A decision tree recursively splits the data to maximize information gain or, equivalently, to minimize impurity, using metrics such as Gini impurity, entropy (for information gain), or variance reduction for regression. Each internal node tests one feature, branches represent the test outcomes, and leaves hold the predictions. Trees are prone to overfitting, which is addressed by pruning, depth limits, or ensemble methods.
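The implementation above uses Gini impurity; entropy-based information gain is a common drop-in alternative. A self-contained sketch, using the same label-array convention as calculateGini:

```typescript
// Shannon entropy of a label array, in bits.
function entropy(y: number[]): number {
  const counts = new Map<number, number>();
  for (const label of y) counts.set(label, (counts.get(label) || 0) + 1);
  let h = 0;
  for (const count of counts.values()) {
    const p = count / y.length;
    h -= p * Math.log2(p);
  }
  return h;
}

// Information gain: parent entropy minus size-weighted child entropy.
function informationGain(parent: number[], left: number[], right: number[]): number {
  const weighted =
    (left.length * entropy(left) + right.length * entropy(right)) / parent.length;
  return entropy(parent) - weighted;
}

// A split of a 50/50 parent into two pure children recovers 1 full bit.
console.log(informationGain([0, 0, 1, 1], [0, 0], [1, 1])); // 1
```

Swapping this into findBestSplit in place of the Gini computation usually yields very similar trees; the two criteria rarely disagree in practice.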
Complexity
Time
Best: O(n × m × log n)
Average: O(n × m × log n)
Worst: O(n² × m), when splits are highly unbalanced
Space
O(n) tree size
(n = number of samples, m = number of features)
Applications
Industry Use
Medical diagnosis systems
Credit scoring and loan approval
Customer segmentation
Feature selection in data mining
Rule-based expert systems
Fraud detection
Marketing campaign targeting
Related Algorithms
K-Nearest Neighbors (KNN)
A simple, instance-based learning algorithm that classifies new data points based on the k closest training examples. A non-parametric, lazy learning method used for classification and regression.
Linear Regression
A fundamental supervised learning algorithm that models the relationship between a dependent variable and independent variables with a linear equation. Foundational for predictive modeling and statistics.
Logistic Regression
A binary classification algorithm that uses the sigmoid function to model the probability of class membership. Despite its name, it performs classification, not regression. A building block of neural networks.
K-Means Clustering
An unsupervised learning algorithm that partitions n observations into k clusters, assigning each observation to the cluster with the nearest mean. Widely used for data segmentation and pattern discovery.