diff --git a/examples/xor_network/train.py b/examples/xor_network/train.py index c5cd6799..a24180a4 100644 --- a/examples/xor_network/train.py +++ b/examples/xor_network/train.py @@ -45,14 +45,27 @@ class XORNetwork: """A simple 2-layer network for solving XOR.""" def __init__(self): + from tinytorch.core.autograd import Variable + # Architecture: 2 -> 4 -> 1 self.hidden = Dense(2, 4) self.output = Dense(4, 1) self.relu = ReLU() self.sigmoid = Sigmoid() + + # Convert parameters to Variables for training + self.hidden.weights = Variable(self.hidden.weights, requires_grad=True) + self.hidden.bias = Variable(self.hidden.bias, requires_grad=True) + self.output.weights = Variable(self.output.weights, requires_grad=True) + self.output.bias = Variable(self.output.bias, requires_grad=True) def forward(self, x): """Forward pass through the network.""" + # Convert input to Variable if it isn't already + from tinytorch.core.autograd import Variable + if not hasattr(x, 'requires_grad'): + x = Variable(x, requires_grad=True) + x = self.hidden(x) x = self.relu(x) x = self.output(x) @@ -61,16 +74,8 @@ class XORNetwork: def parameters(self): """Get all trainable parameters.""" - params = [] - if hasattr(self.hidden, 'weights'): - params.append(self.hidden.weights) - if hasattr(self.hidden, 'bias') and self.hidden.bias is not None: - params.append(self.hidden.bias) - if hasattr(self.output, 'weights'): - params.append(self.output.weights) - if hasattr(self.output, 'bias') and self.output.bias is not None: - params.append(self.output.bias) - return params + return [self.hidden.weights, self.hidden.bias, + self.output.weights, self.output.bias] def train(model, X, y, epochs=1000, lr=0.5): @@ -111,8 +116,12 @@ def train(model, X, y, epochs=1000, lr=0.5): def evaluate(model, X, y): """Evaluate model accuracy.""" predictions = model.forward(X) - predicted_classes = (predictions.data > 0.5).astype(int) - correct = np.sum(predicted_classes == y.data) + # Handle Variable data extraction + pred_data = predictions.data.data if hasattr(predictions.data, 'data') else predictions.data + y_data = y.data.data if hasattr(y.data, 'data') else y.data + + predicted_classes = (pred_data > 0.5).astype(int) + correct = np.sum(predicted_classes == y_data) return correct / y.shape[0] @@ -140,10 +149,20 @@ def main(): print("Input | Target | Prediction | Correct") print("-" * 40) + # Handle Variable data extraction for printing + X_data = X.data.data if hasattr(X.data, 'data') else X.data + y_data = y.data.data if hasattr(y.data, 'data') else y.data + pred_data = predictions.data.data if hasattr(predictions.data, 'data') else predictions.data + for i in range(X.shape[0]): - x_input = X.data[i] - target = y.data[i, 0] - pred = predictions.data[i, 0] + # Convert to numpy for indexing + X_np = np.array(X_data) + y_np = np.array(y_data) + pred_np = np.array(pred_data) + + x_input = X_np[i] + target = y_np[i, 0] + pred = pred_np[i, 0] correct = "✅" if abs(pred - target) < 0.5 else "❌" print(f"{x_input} | {target} | {pred:.3f} | {correct}")