# Differentiable Rule Extraction for Explainable AI
This notebook introduces the concept and an implementation roadmap for Differentiable Rule Extraction, an approach for generating symbolic, human-readable rules from trained neural networks. This is a key task in Explainable Artificial Intelligence (XAI).
## Overview
Modern deep learning systems are accurate but opaque. The goal of this project is to extract interpretable logical rules (e.g., "if-then" statements) from differentiable models like neural nets, without losing too much predictive power.
## Objective
Given a trained neural model \( f_\theta(x) \), develop a symbolic approximation \( g(x) \) that minimizes the combined objective:
\[\mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{fidelity}} + \beta \cdot \mathcal{L}_{\text{sparsity}} + \gamma \cdot \mathcal{L}_{\text{consistency}}\]

where:
- \(\mathcal{L}_{\text{fidelity}}\) measures how close the symbolic model is to the neural model's predictions
- \(\mathcal{L}_{\text{sparsity}}\) penalizes overly complex rules
- \(\mathcal{L}_{\text{consistency}}\) penalizes logical contradictions across examples
## Setup
Install the necessary libraries:

```bash
pip install torch scikit-learn sympy pandas numpy matplotlib
```
Optional: for symbolic regression models or differentiable logic layers, you may use:
```bash
pip install pysr  # PySR: Symbolic Regression
```
## Step 1: Train a Neural Model
```python
import torch
import torch.nn as nn
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load and standardize the data
X, y = load_breast_cancer(return_X_y=True)
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert training data to tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)

# Basic neural model
model = nn.Sequential(
    nn.Linear(X_train.shape[1], 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid()
)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train
for epoch in range(200):
    output = model(X_train)
    loss = criterion(output, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
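A quick optional sanity check of the fit, using only the variables defined above:

```python
# Training accuracy of the neural net (quick sanity check)
with torch.no_grad():
    train_acc = ((model(X_train) > 0.5).float() == y_train).float().mean().item()
print(f"Train accuracy: {train_acc:.3f}")
```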
## Step 2: Extract Intermediate Outputs
To mimic the behavior of the neural net, we extract its soft predictions and train a surrogate model on them. In general these can be logits or probabilities; the model above ends in a sigmoid, so we use the predicted probabilities.
```python
# Soft targets: the network's predicted probabilities on the training set
with torch.no_grad():
    y_soft = model(X_train).numpy().flatten()
```
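If you also want truly intermediate (hidden-layer) outputs rather than the final probabilities, one option with an `nn.Sequential` model is to run only its first layers. This is an optional sketch, not required for the steps below:

```python
# Hidden activations after the first Linear + ReLU (layers 0-1 of the Sequential model)
with torch.no_grad():
    hidden = model[:2](X_train).numpy()  # shape: (n_samples, 16)
```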
## Step 3: Fit a Symbolic Rule Model (e.g., Decision Tree)
```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a shallow tree to the network's binarized soft predictions
tree = DecisionTreeClassifier(max_depth=3)
tree.fit(X_train.numpy(), y_soft > 0.5)
print(export_text(tree))
```
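To make the printed rules easier to read, you can pass the dataset's feature names to `export_text` (a small optional refinement, reusing the imports from the steps above):

```python
from sklearn.datasets import load_breast_cancer

# Print the if-then rules with the original feature names
feature_names = load_breast_cancer().feature_names.tolist()
print(export_text(tree, feature_names=feature_names))
```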
## Step 4: Define Custom Loss Functions
You can now formalize symbolic loss components.
Fidelity loss:
\[\mathcal{L}_{\text{fidelity}} = \frac{1}{N} \sum_{i=1}^N \left( g(x_i) - f_\theta(x_i) \right)^2\]

Sparsity loss:
Let \( R \) be the number of rules (or conditions):
\[\mathcal{L}_{\text{sparsity}} = \lambda \cdot R\]

Consistency loss:
Quantify contradictions in rule coverage (this is domain-specific).
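Below is a minimal sketch of how the fidelity and sparsity terms could be computed for the tree surrogate above. The consistency term is left as a placeholder because it is domain-specific, and the weights `alpha`, `beta`, `gamma`, and `lam` are illustrative values, not prescribed ones:

```python
import numpy as np

# Fidelity: mean squared difference between surrogate and network probabilities
g_probs = tree.predict_proba(X_train.numpy())[:, 1]   # g(x_i): P(class 1) from the tree
fidelity_loss = np.mean((g_probs - y_soft) ** 2)      # f_theta(x_i) is y_soft from Step 2

# Sparsity: penalize the number of conditions (internal split nodes) in the tree
lam = 0.01
n_conditions = int((tree.tree_.children_left != -1).sum())
sparsity_loss = lam * n_conditions

# Consistency: domain-specific; placeholder value here
consistency_loss = 0.0

alpha, beta, gamma = 1.0, 1.0, 1.0
total_loss = alpha * fidelity_loss + beta * sparsity_loss + gamma * consistency_loss
print(f"fidelity={fidelity_loss:.4f}  sparsity={sparsity_loss:.4f}  total={total_loss:.4f}")
```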
## Step 5: Evaluate the Rule Model
```python
from sklearn.metrics import accuracy_score

# Surrogate and neural-net predictions on the held-out test set
y_pred_tree = tree.predict(X_test)
with torch.no_grad():
    y_pred_nn = model(torch.tensor(X_test, dtype=torch.float32)).numpy().flatten() > 0.5

# Accuracy: how well the tree matches the true labels
print("Tree Accuracy vs True Labels:", accuracy_score(y_test, y_pred_tree))
# Fidelity: how often the tree agrees with the neural net's predictions
print("Tree Fidelity vs Neural Net:", accuracy_score(y_pred_nn, y_pred_tree))
```
## Further Exploration
- Swap the decision tree for symbolic regression using PySR or gplearn (see the sketch below)
- Use Neural-Backed Decision Trees for hybrid models
- Try differentiable program induction (e.g., Neural Logic Machines, Logic Tensor Networks)
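For the first option, here is a minimal PySR sketch that regresses the network's soft outputs on the inputs; the operator choices and iteration count are illustrative, not tuned:

```python
from pysr import PySRRegressor

# Symbolic regression on the network's soft predictions (illustrative settings)
sr = PySRRegressor(
    niterations=40,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=["exp"],
)
sr.fit(X_train.numpy(), y_soft)
print(sr)  # shows the best discovered symbolic expressions
```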
## Summary
You now have a pipeline that:
- Trains a neural net
- Extracts soft outputs
- Fits symbolic surrogate rules
- Calculates loss terms to enforce interpretability
- Evaluates fidelity and accuracy
## References
- Ribeiro et al. (2016). "Why Should I Trust You?" (LIME)
- Lundberg & Lee (2017). SHAP values
- Wan et al. (2020). NBDT: Neural-Backed Decision Trees
- Cranmer et al. (2020). PySR: Symbolic Regression
- Dong et al. (2019). Neural Logic Machines