From PyTorch to SNN: A Technical Conversion Walkthrough

Code transformation pipeline from PyTorch ANN model to SNN compiled binary

This walkthrough converts a real PyTorch keyword-spotting model to a deployable SNN binary using the Neurmorph SDK. The starting point is a 4-layer MLP trained on the Google Speech Commands dataset (35-class, spectrogram input, 80.3% test accuracy). The target hardware is BrainChip Akida AKD1500. By the end of this walkthrough, you'll have a compiled .nmc binary and a measured energy-per-inference number.

There are two paths to an SNN from a PyTorch model: ANN-to-SNN conversion (replace activations with LIF neurons, calibrate thresholds) and direct surrogate-gradient training with nrm.nn layers. This walkthrough covers ANN-to-SNN conversion, which is faster to iterate but typically produces higher firing rates (and thus higher energy) than native SNN training. The tradeoffs are discussed along the way.

Step 1: The source ANN model

import torch
import torch.nn as nn

class KWS_MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(40 * 101, 512),   # 40 MFCC coefficients × 101 frames, flattened
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 35),
        )

    def forward(self, x):
        return self.net(x.flatten(1))

model = KWS_MLP()
# Trained weights: 80.3% top-1 on GSC test set
model.load_state_dict(torch.load('kws_mlp.pth', map_location='cpu'))
model.eval()  # inference mode before calibration and conversion

This is the model we're converting. Note that it uses standard nn.ReLU activations — these will be replaced by LIF neurons during conversion. The input is a 40×101 MFCC spectrogram (40 mel-frequency cepstral coefficients over 101 time frames), flattened to a 4040-element vector.

Step 2: ANN-to-SNN conversion with threshold calibration

ANN-to-SNN conversion replaces each ReLU with a population of LIF neurons and sets a per-layer firing threshold V_th from calibration data. Calibration aims to keep each layer's average firing rate inside the target range (5–20%), so that layers neither saturate (>50%) nor fall essentially silent (<2%).
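To see why a rate-coded spiking neuron can stand in for a ReLU, here is a minimal sketch in plain Python. It is not the SDK's implementation: the function name lif_firing_rate, the reset-by-subtraction rule, and the default of no leak are assumptions made for illustration. For a constant input, the firing rate over T timesteps approximates the input divided by V_th, clipped to the range 0 to 1.

```python
def lif_firing_rate(drive, v_th, T=20, leak=1.0):
    """Fraction of T timesteps on which the neuron spikes for a constant input."""
    v = 0.0
    spikes = 0
    for _ in range(T):
        v = leak * v + drive      # integrate the input (leak=1.0 means no leak)
        if v >= v_th:
            spikes += 1
            v -= v_th             # reset by subtraction keeps the residual charge
    return spikes / T

# Negative inputs stay silent, mid-range inputs fire proportionally,
# and inputs above the threshold saturate at one spike per timestep.
rates = [lif_firing_rate(d, v_th=1.0) for d in (-0.5, 0.25, 0.75, 2.0)]
print(rates)  # [0.0, 0.25, 0.75, 1.0]
```

The saturation at the top of the range is why the calibrated threshold has to sit near the upper end of each layer's activation distribution.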

import nrm.convert as nmc_convert

# Load calibration data (a representative sample of training inputs)
calib_loader = torch.utils.data.DataLoader(
    calib_dataset,  # 1000 samples from training set
    batch_size=64,
    shuffle=False
)

converter = nmc_convert.ANN2SNN(
    model=model,
    calibration_data=calib_loader,
    T=20,                          # number of timesteps for rate coding
    target_firing_rate=(0.05, 0.20),  # (min, max) acceptable range
    threshold_mode='percentile',   # set V_th at the 99.9th percentile of layer activations
    encoding='rate',               # Poisson rate encoding of input
)

snn_model = converter.convert()
print(converter.calibration_report())
# Output:
# Layer 0 (Linear 4040→512): threshold=1.34, avg_rate=11.2%, max_rate=48.3%
# Layer 1 (Linear 512→256):  threshold=0.92, avg_rate=8.7%,  max_rate=31.5%
# Layer 2 (Linear 256→128):  threshold=0.85, avg_rate=9.1%,  max_rate=29.8%

The calibration report is the first diagnostic checkpoint. The average rates (8–11%) are in the target range, which predicts reasonable energy efficiency. The maximum rates (up to 48% for layer 0 on some inputs) show that some inputs drive much denser spiking than the average. That is expected, but it means the energy per inference will vary with input content.
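The 'percentile' threshold mode used in the converter call can be pictured with the sketch below. It is an illustration in plain Python, not the converter's code: the helper name percentile_threshold and the nearest-rank percentile definition are assumptions. The point it shows is that a high percentile ignores rare outlier activations that would otherwise inflate V_th and suppress firing for typical inputs.

```python
import math

def percentile_threshold(activations, q=99.9):
    """Nearest-rank q-th percentile of a flat list of layer activations."""
    ordered = sorted(activations)
    if not ordered:
        raise ValueError("no activations to calibrate on")
    rank = max(1, math.ceil(q / 100 * len(ordered)))
    return ordered[rank - 1]

# 1000 typical activations in (0, 1] plus one rare outlier at 50.0.
acts = [i / 1000 for i in range(1, 1001)] + [50.0]

v_th = percentile_threshold(acts, q=99.9)
print(v_th, max(acts))  # 1.0 50.0
```

Setting the threshold from the maximum (50.0 here) would leave almost every neuron far below threshold on typical inputs, while the percentile keeps the threshold at the edge of the bulk of the distribution.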

What to do when calibration produces high rates

If average rates come in above 25–30%, the threshold calibration failed to find good thresholds for the input distribution. Common causes: the calibration set is not representative of the deployment input distribution; the source ANN uses batch normalization (which skews the activation distribution); or the model has dead ReLU pathways that produce artificially low calibration activations.

For batch normalization layers: the converter needs to fold batch norm parameters into the preceding linear layer before calibration. The fold_batchnorm=True option handles this. It's enabled by default but worth checking in the report.
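Folding itself is a small algebraic rewrite. The sketch below assumes the standard batch-norm inference formula, gamma * (y - mean) / sqrt(var + eps) + beta, applied to the output of a linear layer. The function names are illustrative, and the SDK's fold_batchnorm option may differ in detail.

```python
import math

def fold_batchnorm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold per-output batch-norm statistics into a linear layer's rows and bias."""
    folded_w, folded_b = [], []
    for row, b, g, bt, mu, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)
        folded_w.append([w * scale for w in row])
        folded_b.append((b - mu) * scale + bt)
    return folded_w, folded_b

def linear(x, weight, bias):
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    return [g * (yi - mu) / math.sqrt(v + eps) + bt
            for yi, g, bt, mu, v in zip(y, gamma, beta, mean, var)]

# Toy 2-in, 2-out layer with arbitrary batch-norm statistics.
W, b = [[0.5, -1.0], [2.0, 0.25]], [0.1, -0.2]
gamma, beta, mean, var = [1.5, 0.8], [0.05, -0.1], [0.3, -0.4], [0.9, 1.6]
x = [1.0, 2.0]

reference = batchnorm(linear(x, W, b), gamma, beta, mean, var)
folded_w, folded_b = fold_batchnorm(W, b, gamma, beta, mean, var)
folded_out = linear(x, folded_w, folded_b)
print(max(abs(a - r) for a, r in zip(folded_out, reference)))  # differs only by rounding error
```

After folding, the layer has no separate normalization step, so threshold calibration sees the same activation distribution the hardware will see.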

Step 3: Accuracy validation in software simulation

Before compiling to hardware, validate the converted SNN's accuracy in software simulation. The simulation runs the T=20 timestep inference loop on CPU:

from nrm.sim import SNNSimulator

sim = SNNSimulator(snn_model, T=20, device='cpu')

correct = 0
total = 0
for inputs, labels in test_loader:
    # Rate encode inputs: Poisson spike trains over T timesteps
    spike_inputs = nmc_convert.rate_encode(inputs, T=20)  # shape: (T, B, 4040)
    outputs = sim.forward(spike_inputs)                   # output spikes per timestep: (T, B, 35)
    predicted = outputs.sum(dim=0).argmax(dim=-1)         # sum over T for spike counts, then pick the class
    correct += (predicted == labels).sum().item()
    total += labels.size(0)

print(f"SNN accuracy: {correct/total*100:.1f}%")
# Target: within 3 percentage points of the original 80.3% → 77.3% or above
# Actual: 78.1%

A drop of 2.2 percentage points from ANN-to-SNN conversion at T=20 is typical. If the drop is larger (more than 5 points), the options are: increase T (T=50 typically recovers most of the accuracy, at higher energy cost); switch to native SNN training with surrogate gradients (slower, but it produces lower firing rates and a better accuracy/energy tradeoff); or simplify the model into a form that converts more cleanly.
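The effect of T on accuracy comes from the resolution of rate coding. The sketch below is a plain-Python illustration that assumes inputs normalized to [0, 1] and independent Bernoulli spikes per timestep. The helpers are hypothetical and are not the SDK's nmc_convert.rate_encode.

```python
import math
import random

def bernoulli_rate_encode(values, T=20, seed=0):
    """Return T timesteps of 0/1 spikes; each value is its per-step spike probability."""
    rng = random.Random(seed)
    return [[1 if rng.random() < v else 0 for v in values] for _ in range(T)]

def decode_rates(spikes):
    """Estimate each value as the fraction of timesteps that carried a spike."""
    T = len(spikes)
    return [sum(step[i] for step in spikes) / T for i in range(len(spikes[0]))]

def rate_std(p, T):
    """Standard deviation of the decoded rate for spike probability p over T steps."""
    return math.sqrt(p * (1 - p) / T)

spikes = bernoulli_rate_encode([0.0, 1.0, 0.5], T=20)   # shape: (T, N)
decoded = decode_rates(spikes)                          # multiples of 1/T

print(f"{rate_std(0.5, 20):.3f} {rate_std(0.5, 50):.3f}")  # 0.112 0.071
```

Going from T=20 to T=50 shrinks the noise on each decoded value by a factor of about 1.6, at the cost of 2.5 times as many timesteps, which is the accuracy-versus-energy trade described above.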

Step 4: Target-aware weight quantization

Akida AKD1500 supports INT4 and binary weights. The converter has already produced float32 weights in the SNN model. The quantization pass fits the weights to INT4:

from nrm.quantize import WeightQuantizer

quantizer = WeightQuantizer(
    model=snn_model,
    target='akida-akd1500',
    weight_precision='int4',
    qat_finetune=True,          # brief quantization-aware fine-tuning
    finetune_epochs=5,
    finetune_lr=1e-4,
    calib_data=calib_loader,
)

quantized_model = quantizer.quantize()
print(quantizer.accuracy_report())
# Post-quantization (INT4, 5 epochs QAT fine-tuning): 77.4%
# Delta from float32 SNN: -0.7%
# Total delta from original ANN: -2.9%

The 5-epoch QAT fine-tuning recovers most of the quantization accuracy loss. Without fine-tuning, INT4 quantization typically drops an additional 1.5–2.5% on this class of model. The fine-tuning uses standard Adam with a small learning rate and trains only the quantization scale factors, not the full weight tensors — it's inexpensive (typically 15–20 minutes on a standard GPU).
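As a rough picture of what INT4 weights and their scale factors look like, here is a minimal sketch of symmetric per-tensor quantization in plain Python. The scheme (one scale per tensor, integer range -8 to 7, round to nearest) and the function names are assumptions for illustration; the WeightQuantizer pass may use a different scheme.

```python
def quantize_int4(weights):
    """Symmetric per-tensor INT4: integers in [-8, 7] plus one float scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    q = [max(-8, min(7, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Map INT4 codes back to approximate float weights."""
    return [qi * scale for qi in q]

weights = [0.7, -0.32, 0.11, 0.0, -0.7]
q, scale = quantize_int4(weights)
restored = dequantize(q, scale)

print(q)  # [7, -3, 1, 0, -7]
print(max(abs(w - r) for w, r in zip(weights, restored)))  # at most scale / 2
```

In this picture the scale is the only floating-point parameter left per tensor, which is why fine-tuning just the scale factors is cheap compared with retraining the weights.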

Step 5: Compilation to hardware target

import nrm.compiler as nmc

compile_config = nmc.CompileConfig(
    target='akida-akd1500',
    timesteps=20,
    input_encoding='rate',
    output_decoding='spike_count',
    optimization_level=2,         # runs all optimization passes
    sleep_annotation=True,        # enable wake-on-spike annotations
)

binary = nmc.compile(quantized_model, compile_config)
binary.save('kws_akida_t20.nmc')

# Compilation report
report = binary.compilation_report()
print(report.summary())
# Target: akida-akd1500
# Network: 3 inference layers (input encoding handled by hardware)
# Total neurons: 931 (512 + 256 + 128 + 35)
# Total synapses: 339,968 weight entries
# Core allocation: 4 cores used (of 80 available)
# Synapse memory: 168 KB (of 2 MB available)
# Dead neurons eliminated: 47 (5.0% of total)
# Estimated SOPs/inference: 312,000 (at 10% avg firing rate)
# Estimated energy/inference: 0.81 µJ (model: 0.8 pJ/SOP × 312K SOPs)
# Compile time: 4.2s

Step 6: Hardware validation and measured energy

Load the compiled binary to the Akida board and run the test set:

from nrm.runtime import NMCRuntime

rt = NMCRuntime(target='akida-akd1500', device_id=0)
net = rt.load('kws_akida_t20.nmc')

# Run test set with energy measurement
results = rt.run_dataset(
    network=net,
    dataset=test_dataset,
    measure_energy=True,
    batch_size=1,           # single-sample for per-inference measurement
)

print(f"Hardware accuracy: {results.accuracy*100:.1f}%")
print(f"Mean energy/inference: {results.energy_mean_uj:.3f} µJ")
print(f"P50 latency: {results.latency_p50_us:.0f} µs")
print(f"P99 latency: {results.latency_p99_us:.0f} µs")

# Hardware accuracy: 77.1%   (vs 77.4% in simulation — 0.3% hardware variation)
# Mean energy/inference: 0.88 µJ
# P50 latency: 312 µs
# P99 latency: 580 µs

The hardware energy (0.88 µJ) is 8.6% higher than the compilation estimate (0.81 µJ), primarily from routing overhead not captured in the per-SOP model. This is within expected compiler estimation accuracy for the AKD1500 routing topology.
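The 8.6% figure is the relative gap between the measured and estimated energy. A short check in plain Python, using only the two numbers reported above (the helper name relative_gap is illustrative):

```python
def relative_gap(measured, estimated):
    """Signed relative error of a measurement against its estimate."""
    return (measured - estimated) / estimated

gap = relative_gap(measured=0.88, estimated=0.81)  # both in µJ per inference
print(f"{gap:.1%}")  # 8.6%
```

The same helper can be applied per target to track how well the compiler's estimate holds up as models and routing layouts change.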

Comparison: ANN-to-SNN vs native SNN training

The conversion path produced 77.1% accuracy at 0.88 µJ/inference. For comparison, a natively trained SNN with the same architecture and target (same 3-layer LIF network, trained from scratch with BPTT + FastSigmoid surrogate gradients over 100 epochs) achieves:

  • Accuracy: 79.3% (2.2 percentage points better)
  • Average firing rate: 7.1% (vs 10.2% for ANN-to-SNN)
  • Energy/inference: 0.58 µJ (34% lower)
  • Training time: approximately 8 hours vs 30 minutes for ANN-to-SNN

The native SNN training path wins on both accuracy and energy efficiency. The ANN-to-SNN conversion path wins on development speed — useful when you have an existing trained ANN and want a quick accuracy/energy estimate before committing to a full SNN training pipeline. For production deployment, we recommend native SNN training once the model architecture is validated through the conversion path.
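For readers unfamiliar with surrogate gradients, the idea behind the fast-sigmoid variant can be sketched in a few lines of plain Python. The forward pass keeps the hard spike threshold, while the backward pass substitutes the derivative of the fast sigmoid x / (1 + slope * |x|). The function names and the slope value of 25 are illustrative assumptions, not settings taken from the training run above.

```python
def spike(u, v_th=1.0):
    """Forward pass: hard threshold on the membrane potential."""
    return 1.0 if u >= v_th else 0.0

def fast_sigmoid_grad(u, v_th=1.0, slope=25.0):
    """Backward pass: derivative of x / (1 + slope * |x|) at x = u - v_th."""
    x = u - v_th
    return 1.0 / (slope * abs(x) + 1.0) ** 2

# The surrogate gradient peaks at the threshold and decays on both sides,
# so neurons close to firing receive the largest weight updates.
for u in (0.5, 0.9, 1.0, 1.1, 1.5):
    print(u, spike(u), round(fast_sigmoid_grad(u), 4))
```

Because the true derivative of the step function is zero almost everywhere, this substitution is what lets backpropagation through time adjust weights in a spiking network at all.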

Common conversion failures and diagnostics

Several layer types don't convert cleanly. Batch normalization requires fold-and-recalibrate. Dropout layers must be removed before conversion (they are a no-op in inference mode, so nothing is lost). Skip connections (ResNet-style) require the converter to handle additive spike merging, which is supported but produces higher firing rates at merge points. Attention layers (as described in our temporal locality post) require the dedicated LocalSpikeAttn layer class rather than standard conversion.

The compiler emits ConversionWarning events for each of these patterns and suggests the appropriate fix. Running nmc.analyze(model) before conversion produces a pre-conversion analysis report that flags these issues before any computation occurs — useful for estimating conversion difficulty on a new model architecture.
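A rough version of this kind of pre-conversion check can be written by hand. The sketch below is not nmc.analyze: it is a hypothetical helper that matches module class names against the patterns listed above. The hint table and the stand-in classes are assumptions for illustration; with a real PyTorch model you would pass model.named_modules() instead.

```python
# Class names mapped to the fix described for each pattern.
CONVERSION_HINTS = {
    "BatchNorm1d": "fold into the preceding linear layer, then recalibrate",
    "BatchNorm2d": "fold into the preceding layer, then recalibrate",
    "Dropout": "remove before conversion",
    "MultiheadAttention": "use the dedicated LocalSpikeAttn layer class",
}

def scan_modules(named_modules):
    """Return (name, class_name, hint) for each module whose class needs attention."""
    findings = []
    for name, module in named_modules:
        cls = type(module).__name__
        if cls in CONVERSION_HINTS:
            findings.append((name, cls, CONVERSION_HINTS[cls]))
    return findings

# Stand-in classes so the example runs without PyTorch installed.
class Linear: pass
class BatchNorm1d: pass
class Dropout: pass

layers = [("net.0", Linear()), ("net.1", BatchNorm1d()), ("net.2", Dropout())]
findings = scan_modules(layers)
for name, cls, hint in findings:
    print(f"{name}: {cls} -> {hint}")
```

A name-based scan cannot see skip connections, since those live in the forward method rather than in the module list, so it complements the SDK's analysis rather than replacing it.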