Not the neurons we want, but the neurons we need
- Activation-Free Non-linearity: Learn complex, non-linear relationships without separate activation functions
- Multiple Frameworks: Full support for Flax (Linen & NNX), Keras, PyTorch, and TensorFlow
- Advanced Layer Types: Dense (YatNMN), Convolutional (YatConv, YatConvTranspose), Attention (YatAttention)
- Recurrent Layers: YatSimpleCell, YatLSTMCell, and YatGRUCell with custom squashing functions
- Custom Squashing Functions: Softermax, softer_sigmoid, soft_tanh for smoother gradients
- DropConnect Support: Built-in regularization for NNX layers
- Novel Operations: Yat-Product and Yat-Conv, based on a distance-similarity tradeoff
- Research-Driven: Based on "Deep Learning 2.0/2.1: Artificial Neurons that Matter"
nmn (Neural-Matter Network) provides cutting-edge neural network layers for multiple frameworks (Flax NNX & Linen, Keras, PyTorch, TensorFlow) that learn non-linearity without requiring traditional activation functions. The library introduces the Yat-Product operation: a novel approach that combines similarity and distance metrics to create inherently non-linear transformations.
- Distance-Similarity Tradeoff: Unlike traditional neurons that rely on dot products and activations, Yat neurons balance alignment (similarity) with proximity (distance)
- Built-in Non-linearity: The squared-ratio operation eliminates the need for ReLU, sigmoid, or tanh activations
- Geometric Interpretation: Maximizes response when weights and inputs are aligned, close, and large in magnitude
- Production-Ready: Comprehensive implementations across major deep learning frameworks
Inspired by the papers:
- Deep Learning 2.0: Artificial Neurons that Matter - Reject Correlation, Embrace Orthogonality
- Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks
Yat-Product: $$ ⵟ(\mathbf{w},\mathbf{x}) := \frac{\langle \mathbf{w}, \mathbf{x} \rangle^2}{\|\mathbf{w} - \mathbf{x}\|^2 + \epsilon} = \frac{\|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{\|\mathbf{w}\|^2 - 2\,\mathbf{w}^\top\mathbf{x} + \|\mathbf{x}\|^2 + \epsilon} = \frac{\|\mathbf{x}\|^2 \|\mathbf{w}\|^2 \cos^2 \theta}{(\mathbf{x}-\mathbf{w})\cdot(\mathbf{x}-\mathbf{w}) + \epsilon}. $$
Explanation:
- $\mathbf{w}$ is the weight vector and $\mathbf{x}$ is the input vector.
- $\langle \mathbf{w}, \mathbf{x} \rangle$ is the dot product between $\mathbf{w}$ and $\mathbf{x}$.
- $\|\mathbf{w} - \mathbf{x}\|^2$ is the squared Euclidean distance between $\mathbf{w}$ and $\mathbf{x}$.
- $\epsilon$ is a small constant for numerical stability.
- $\theta$ is the angle between $\mathbf{w}$ and $\mathbf{x}$.
This operation:
- Numerator: Squares the similarity (dot product) between $\mathbf{w}$ and $\mathbf{x}$, emphasizing strong alignments.
- Denominator: Penalizes large distances, so the response is high only when $\mathbf{w}$ and $\mathbf{x}$ are both similar in direction and close in space.
- No activation needed: The non-linearity is built into the operation itself, allowing the layer to learn complex, non-linear relationships without a separate activation function.
- Geometric view: The output is maximized when $\mathbf{w}$ and $\mathbf{x}$ are both large in norm, closely aligned (small $\theta$), and close together in Euclidean space (see the numerical sketch below).
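To make this concrete, here is a minimal numerical sketch of the Yat-Product for a single weight/input pair, written directly from the definition above. It is an illustration only, not the library's internal implementation; the function name `yat_product` and the example values are ours.

```python
import jax.numpy as jnp

def yat_product(w, x, epsilon=1e-5):
    """Squared dot product divided by squared Euclidean distance (plus epsilon)."""
    dot = jnp.dot(w, x)
    sq_dist = jnp.sum((w - x) ** 2)
    return dot**2 / (sq_dist + epsilon)

w = jnp.array([1.0, 2.0, 3.0])
x_close = jnp.array([1.1, 1.9, 3.2])  # aligned with w and close to it
x_far = -x_close                      # same squared alignment, but far from w

print(yat_product(w, x_close))  # large response (~3.5e3): aligned and close
print(yat_product(w, x_far))    # small response (~3.6): cos^2 is unchanged, distance dominates
```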
Yat-Conv: $$ ⵟ^*(\mathbf{W}, \mathbf{X}) := \frac{\langle \mathbf{W}, \mathbf{X} \rangle^2}{\|\mathbf{W} - \mathbf{X}\|^2 + \epsilon} = \frac{\left(\sum_{i,j} w_{ij} x_{ij}\right)^2}{\sum_{i,j} (w_{ij} - x_{ij})^2 + \epsilon} $$
Where:
- $\mathbf{W}$ and $\mathbf{X}$ are local patches (e.g., kernel and input patch in a convolution)
- $w_{ij}$ and $x_{ij}$ are elements of the kernel and input patch, respectively
- $\epsilon$ is a small constant for numerical stability
This generalizes the Yat-product to convolutional (patch-wise) operations.
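As an illustration of the patch-wise form, the same ratio can be computed for one kernel/patch pair of matching shape. This is a sketch from the formula above, not the convolution kernel the library actually uses; in a real YatConv this quantity is evaluated at every spatial location.

```python
import jax.numpy as jnp

def yat_conv_patch(kernel, patch, epsilon=1e-5):
    """Yat-Conv response for a single kernel and input patch of identical shape."""
    dot = jnp.sum(kernel * patch)             # sum_ij w_ij * x_ij
    sq_dist = jnp.sum((kernel - patch) ** 2)  # sum_ij (w_ij - x_ij)^2
    return dot**2 / (sq_dist + epsilon)

kernel = jnp.ones((3, 3))
patch = jnp.full((3, 3), 0.9)         # nearly matches the kernel
print(yat_conv_patch(kernel, patch))  # high response (~7.3e2)
```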
The library provides several layer types across all supported frameworks:
| Layer Type | Description | Flax NNX | Flax Linen | Keras | PyTorch | TensorFlow |
|---|---|---|---|---|---|---|
| YatNMN | Dense/Linear layer | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConv | Convolutional layer | ✅ | ✅ | ✅ | ✅ | ✅ |
| YatConvTranspose | Transposed convolution | ✅ | 🚧 | 🚧 | 🚧 | 🚧 |
| YatAttention | Multi-head attention | ✅ | 🚧 | 🚧 | 🚧 | 🚧 |
Recurrent cells (Flax NNX):

| Layer Type | Description | Status |
|---|---|---|
| YatSimpleCell | Simple RNN cell | ✅ |
| YatLSTMCell | LSTM cell with Yat operations | ✅ |
| YatGRUCell | GRU cell with Yat operations | ✅ |
Custom squashing functions:

| Function | Description |
|---|---|
| softermax | Generalized softmax with a power parameter |
| softer_sigmoid | Smoother sigmoid variant |
| soft_tanh | Smoother tanh variant |
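These are exported from `nmn.nnx.squashers` (see the API quick reference below). A minimal usage sketch follows; it assumes each function accepts an array and that softermax's power parameter has a usable default, so check the module docstrings for the exact signatures:

```python
import jax
from nmn.nnx.squashers import softermax, softer_sigmoid, soft_tanh

logits = jax.random.normal(jax.random.key(0), (2, 5))

probs = softermax(logits)       # softmax-like normalization with smoother gradients
gates = softer_sigmoid(logits)  # smoother sigmoid variant
state = soft_tanh(logits)       # smoother tanh variant
```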
Additional features:

- ✅ DropConnect: Built-in weight dropout for regularization
- ✅ Alpha Scaling: Learnable output scaling parameter
- 🚧 Ternary Networks: Quantized weight versions (in development)
- ✅ Custom Initializers: Optimized for Yat-layer convergence

Legend: ✅ Implemented, 🚧 In Development
Install the base package from PyPI:

```bash
pip install nmn
```

For optimal performance and full feature access, install with framework-specific dependencies:

```bash
# For JAX/Flax (NNX and Linen)
pip install "nmn[nnx]"    # or "nmn[linen]"

# For PyTorch
pip install "nmn[torch]"

# For TensorFlow/Keras
pip install "nmn[keras]"  # or "nmn[tf]"

# For all frameworks
pip install "nmn[all]"

# For development and testing
pip install "nmn[dev]"    # Basic dev tools
pip install "nmn[test]"   # All dependencies for testing
```

To install from source for development:

```bash
git clone https://github.com/mlnomadpy/nmn.git
cd nmn
pip install -e ".[dev]"
```

Quick start with a Yat dense layer in Flax NNX:

```python
import jax
import jax.numpy as jnp
from flax import nnx
from nmn.nnx.nmn import YatNMN
# Create a simple Yat dense layer
model_key, param_key, drop_key, input_key = jax.random.split(jax.random.key(0), 4)
layer = YatNMN(
in_features=3,
out_features=4,
rngs=nnx.Rngs(params=param_key, dropout=drop_key)
)
# Forward pass
dummy_input = jax.random.normal(input_key, (2, 3)) # Batch size 2
output = layer(dummy_input)
print("YatNMN Output Shape:", output.shape)  # (2, 4)
```

Yat convolutional layer:

```python
from nmn.nnx.yatconv import YatConv
# Create a Yat convolutional layer
conv_layer = YatConv(
in_features=3,
out_features=8,
kernel_size=(3, 3),
strides=1,
padding='SAME',
rngs=nnx.Rngs(params=jax.random.key(1))
)
# Forward pass on image data
image = jax.random.normal(jax.random.key(2), (1, 28, 28, 3))
conv_output = conv_layer(image)
print("YatConv Output Shape:", conv_output.shape)  # (1, 28, 28, 8)
```

Multi-head attention with the Yat-Product:

```python
from nmn.nnx.yatattention import YatMultiHeadAttention
# Create multi-head attention with Yat-product
attention = YatMultiHeadAttention(
num_heads=8,
in_features=512,
qkv_features=512,
out_features=512,
use_softermax=True, # Use custom softermax activation
rngs=nnx.Rngs(params=jax.random.key(3))
)
# Forward pass
seq = jax.random.normal(jax.random.key(4), (2, 10, 512)) # (batch, seq_len, features)
attn_output = attention(seq)
print("Attention Output Shape:", attn_output.shape)  # (2, 10, 512)
```

Recurrent cells:

```python
from nmn.nnx.rnn import YatLSTMCell, YatGRUCell
# LSTM cell
lstm_cell = YatLSTMCell(
in_features=64,
hidden_features=128,
rngs=nnx.Rngs(params=jax.random.key(5))
)
# GRU cell
gru_cell = YatGRUCell(
in_features=64,
hidden_features=128,
rngs=nnx.Rngs(params=jax.random.key(6))
)
# Initialize carry state
batch_size = 2
carry = lstm_cell.initialize_carry(jax.random.key(7), (batch_size,))
# Process sequence
x = jax.random.normal(jax.random.key(8), (batch_size, 64))
new_carry, output = lstm_cell(carry, x)
```

DropConnect regularization:

```python
# Enable DropConnect for regularization
layer_with_dropout = YatNMN(
in_features=128,
out_features=64,
use_dropconnect=True,
drop_rate=0.2, # 20% dropout on weights
rngs=nnx.Rngs(params=jax.random.key(9), dropout=jax.random.key(10))
)
# Training mode (with dropout)
x_train = jax.random.normal(jax.random.key(11), (32, 128))
y_train = layer_with_dropout(x_train, deterministic=False)
# Inference mode (no dropout)
x_test = jax.random.normal(jax.random.key(12), (32, 128))
y_test = layer_with_dropout(x_test, deterministic=True)
```

Note: For examples in other frameworks (PyTorch, Keras, TensorFlow, Linen), see the examples/ directory.
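Because the non-linearity is built into the Yat-Product itself, YatNMN layers can be stacked directly with nothing in between. A minimal sketch of such a model (the `YatMLP` class and the layer sizes below are ours, for illustration only):

```python
import jax
from flax import nnx
from nmn.nnx.nmn import YatNMN

class YatMLP(nnx.Module):
    """Two YatNMN layers composed directly; no ReLU/tanh between them."""
    def __init__(self, rngs: nnx.Rngs):
        self.fc1 = YatNMN(in_features=16, out_features=32, rngs=rngs)
        self.fc2 = YatNMN(in_features=32, out_features=4, rngs=rngs)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(1))
model = YatMLP(rngs)
x = jax.random.normal(jax.random.key(2), (8, 16))
print(model(x).shape)  # (8, 4)
```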
The examples/ directory contains comprehensive, production-ready examples for all supported frameworks:
```
examples/
├── nnx/                          # Flax NNX (most feature-complete)
│   ├── vision/
│   │   └── cnn_cifar.py          # Complete CNN training on CIFAR-10
│   └── language/
│       └── mingpt.py             # GPT implementation with Yat layers
├── torch/                        # PyTorch examples
│   ├── yat_examples.py           # Basic usage patterns
│   ├── yat_cifar10.py            # CIFAR-10 training
│   └── vision/
│       └── resnet_training.py    # ResNet with Yat layers
├── keras/                        # Keras examples
│   ├── basic_usage.py            # Getting started
│   ├── vision_cifar10.py         # Image classification
│   └── language_imdb.py          # Text classification
├── tensorflow/                   # TensorFlow examples
│   ├── basic_usage.py
│   ├── vision_cifar10.py
│   └── language_imdb.py
├── linen/                        # Flax Linen examples
│   └── basic_usage.py
└── comparative/                  # Framework comparisons
    └── framework_comparison.py   # Side-by-side comparison
```
Check Framework Availability:

```bash
python examples/comparative/framework_comparison.py
```

Train Vision Model:

```bash
# PyTorch - CIFAR-10 with Yat convolutions
python examples/torch/yat_cifar10.py

# Keras - CIFAR-10 classification
python examples/keras/vision_cifar10.py

# Flax NNX - Full CNN with data augmentation
python examples/nnx/vision/cnn_cifar.py
```

Natural Language Processing:

```bash
# GPT-style language model (Flax NNX)
python examples/nnx/language/mingpt.py

# Sentiment analysis (Keras)
python examples/keras/language_imdb.py
```

Basic Usage Patterns:

```bash
# PyTorch introduction
python examples/torch/yat_examples.py

# Keras getting started
python examples/keras/basic_usage.py
```

Each example includes:
- Complete training loops with metrics tracking
- Data loading and preprocessing using standard datasets
- Model architecture definitions with Yat layers
- Hyperparameter configurations and best practices
- Evaluation and visualization code
- Documentation explaining design choices
The examples demonstrate:
- Computer Vision: Image classification, feature extraction, CNNs
- Natural Language Processing: Language modeling, sentiment analysis, attention mechanisms
- Regularization: DropConnect, weight decay, data augmentation
- Optimization: Learning rate schedules, gradient clipping, optimizer tuning
- Production Patterns: Model checkpointing, logging, reproducibility
The library includes a comprehensive test suite with high code coverage:
```bash
# Install test dependencies
pip install "nmn[test]"

# Run all tests
pytest tests/

# Run with coverage report
pytest tests/ --cov=nmn --cov-report=html --cov-report=term

# Run specific framework tests
pytest tests/test_nnx/      # Flax NNX tests
pytest tests/test_torch/    # PyTorch tests
pytest tests/test_keras/    # Keras tests
pytest tests/test_tf/       # TensorFlow tests
pytest tests/test_linen/    # Flax Linen tests

# Run integration tests
pytest tests/integration/

# Run with verbose output
pytest tests/ -v

# Run specific test file
pytest tests/test_nnx/test_basic.py -v
```

The test suite is organized as follows:

```
tests/
├── test_nnx/                 # NNX layer tests
│   └── test_basic.py         # YatNMN, YatConv, Attention, RNN
├── test_torch/               # PyTorch layer tests
│   └── test_basic.py
├── test_keras/               # Keras layer tests
│   └── test_keras_basic.py
├── test_tf/                  # TensorFlow layer tests
│   └── test_tf_basic.py
├── test_linen/               # Linen layer tests
│   └── test_basic.py
└── integration/              # Cross-framework compatibility
    └── test_compatibility.py
```
The project uses GitHub Actions for automated testing:
- ✅ Tests run on every push and pull request
- ✅ Coverage reports uploaded to Codecov
- ✅ Multi-framework compatibility verified
- ✅ Code quality checks (formatting, linting)
- Core Architecture: Yat-Product and Yat-Conv operations across all frameworks
- Flax NNX Features: Dense, Conv, ConvTranspose, Attention, RNN cells (Simple, LSTM, GRU)
- Custom Activations: Softermax, softer_sigmoid, soft_tanh
- Regularization: DropConnect support in NNX
- Multi-Framework: Full implementations for NNX, Linen, Keras, PyTorch, TensorFlow
- Examples: Comprehensive examples for vision and language tasks
- Testing: Complete test suite with CI/CD and code coverage
- Documentation: API documentation and usage examples
- Ternary Networks: Quantized weight versions of Yat layers
- Additional RNN Features: Bidirectional RNN wrappers, stacked RNNs
- Advanced Attention: Sparse attention, linear attention variants
- Performance: JIT compilation optimizations, memory efficiency improvements
- Documentation: API reference website, tutorials, benchmark results
- Graph Neural Networks: Yat-based message passing layers
- Normalization Layers: YatBatchNorm, YatLayerNorm variants
- Advanced Regularization: Spectral normalization, weight normalization
- Mixed Precision: FP16/BF16 training support
- Model Zoo: Pre-trained models for common tasks
- Benchmarks: Comprehensive performance comparisons with standard layers
- Interactive Demos: Colab notebooks, live web demos
We welcome contributions from the community! Here's how you can help:
- Fork and Clone:

  ```bash
  git clone https://github.com/yourusername/nmn.git
  cd nmn
  ```

- Install in Development Mode:

  ```bash
  pip install -e ".[dev]"
  # or install all dependencies for testing
  pip install -e ".[test]"
  ```

- Install Pre-commit Hooks (optional but recommended):

  ```bash
  pip install pre-commit
  pre-commit install
  ```
We maintain high code quality standards:
```bash
# Format code with Black
black src/ tests/ examples/

# Check code style
flake8 src/ tests/

# Sort imports
isort src/ tests/ examples/

# Type checking
mypy src/nmn
```

Ensure all tests pass before submitting:
```bash
# Run full test suite
pytest tests/ -v

# Run tests for specific framework
pytest tests/test_nnx/ -v

# Check code coverage
pytest tests/ --cov=nmn --cov-report=term-missing
```
- Bug Reports: Open an issue with:
  - Clear description of the problem
  - Minimal reproducible example
  - Expected vs actual behavior
  - System information (OS, Python version, framework versions)
- Feature Requests: Open an issue describing:
  - The proposed feature and its use case
  - Why it belongs in the library
  - Potential implementation approach
- Pull Requests:
  - Create a new branch for your feature/fix
  - Write tests for new functionality
  - Update documentation as needed
  - Ensure all tests pass and code is formatted
  - Reference related issues in the PR description
- Documentation: Help improve:
  - API documentation and docstrings
  - Usage examples and tutorials
  - README and guides
  - Example code and notebooks
- Bug Fixes: Check open issues
- New Features: Implement layers for additional frameworks
- Documentation: Improve guides, add examples
- Testing: Increase test coverage, add edge cases
- Performance: Optimize implementations, add benchmarks
- Examples: Add new use cases and applications
This project is licensed under the GNU Affero General Public License v3 (AGPL-3.0). See the LICENSE file for details.
- ✅ Free to use for personal, academic, and commercial projects
- ✅ Modify and distribute with attribution
- ⚠️ Network use is distribution: if you run modified code on a server accessible over a network, you must make the modified source code available
- ✅ Compatible with open-source research and academic use
For commercial applications where AGPL terms are not suitable, please contact us to discuss alternative licensing options.
Flax NNX (Most Feature-Complete):

```python
from nmn.nnx.nmn import YatNMN
from nmn.nnx.yatconv import YatConv, YatConvTranspose
from nmn.nnx.yatattention import YatMultiHeadAttention, YatSelfAttention
from nmn.nnx.rnn import YatSimpleCell, YatLSTMCell, YatGRUCell
from nmn.nnx.squashers import softermax, softer_sigmoid, soft_tanh
```

PyTorch:

```python
from nmn.torch.nmn import YatNMN
from nmn.torch.conv import YatConv
```

Keras/TensorFlow:

```python
from nmn.keras.nmn import YatNMN
from nmn.keras.conv import YatConv
```

Flax Linen:

```python
from nmn.linen.nmn import YatNMN
```

YatNMN (Dense Layer):

- `in_features`: Input dimension
- `out_features`: Output dimension
- `use_bias`: Whether to include a bias term (default: True)
- `use_alpha`: Whether to use learnable scaling (default: True)
- `use_dropconnect`: Enable DropConnect regularization (default: False)
- `drop_rate`: DropConnect probability (default: 0.0)
- `epsilon`: Numerical stability constant (default: 1e-5)

YatConv (Convolutional Layer):

- `in_features`: Number of input channels
- `out_features`: Number of output channels
- `kernel_size`: Convolution kernel size (int or tuple)
- `strides`: Stride of the convolution (default: 1)
- `padding`: Padding mode ('SAME', 'VALID', or tuple)
- `use_bias`, `use_alpha`, `epsilon`: Same as YatNMN

YatAttention (Attention Layer):

- `num_heads`: Number of attention heads
- `in_features`: Input feature dimension
- `qkv_features`: Query/Key/Value dimension
- `out_features`: Output dimension
- `use_softermax`: Use custom softermax instead of softmax
- `epsilon`: Numerical stability for the distance calculation
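For instance, the dense-layer arguments above map onto the NNX constructor roughly as follows. This is a sketch with the defaults spelled out; it assumes the keyword names match the parameter list above and the constructor pattern from the quick-start examples.

```python
import jax
from flax import nnx
from nmn.nnx.nmn import YatNMN

layer = YatNMN(
    in_features=64,
    out_features=32,
    use_bias=True,          # include a bias term
    use_alpha=True,         # learnable output scaling
    use_dropconnect=False,  # no DropConnect regularization
    drop_rate=0.0,
    epsilon=1e-5,           # numerical stability constant in the Yat denominator
    rngs=nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(1)),
)
```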
If you use nmn in your research or projects, please cite the foundational papers:
```bibtex
@article{bouhsine2024dl2,
  author = {Taha Bouhsine},
  title  = {Deep Learning 2.0: Artificial Neurons that Matter - Reject Correlation, Embrace Orthogonality},
  year   = {2024},
  note   = {Foundational paper for the Yat-Product operation}
}

@article{bouhsine2025dl21,
  author = {Taha Bouhsine},
  title  = {Deep Learning 2.1: Mind and Cosmos - Towards Cosmos-Inspired Interpretable Neural Networks},
  year   = {2025},
  note   = {Extended theoretical framework}
}
```

This library is inspired by research into alternative neural network architectures that challenge conventional wisdom about how artificial neurons should operate. Special thanks to the JAX, Flax, PyTorch, and TensorFlow communities for creating the foundational tools that make this work possible.
- Documentation: GitHub Wiki (coming soon)
- Bug Reports: GitHub Issues
- Discussions: GitHub Discussions
- Contact: [email protected]
- JAX - Composable transformations of Python+NumPy programs
- Flax - Neural network library for JAX
- PyTorch - Deep learning framework
- TensorFlow - End-to-end machine learning platform
Built with ❤️ by the ML Nomads team