Pruning: How to preserve the number of output channels of a particular layer? #5737
I tried adding the layer names to exclude_op_names in the config, but it does not seem to respect it. I even tried adding the entire list below, and it still behaves as if this property is not there.

```python
sparsity_ratio = 0.5
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
    'exclude_op_names': [
        'conv1.weight',
        'conv1.bias',
        'bn1.weight',
        'bn1.bias',
        'conv2.weight',
        'conv2.bias',
        'bn2.weight',
        'bn2.bias'
    ],
}]
```
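For reference, here is a minimal sketch of the kind of setup this refers to. The ConvNet definition mirrors the layer names used above, but the channel counts, the choice of L1NormPruner, and the NNI import paths are illustrative assumptions (they vary between NNI releases), not the exact original script:

```python
import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """Toy model matching the layer names above; channel counts are illustrative."""

    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(3, 49, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(49)
        self.conv2 = nn.Conv2d(49, 80, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(80)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(x))


model = ConvNet()

# Typical NNI 3.x pruning + speedup flow (import paths and pruner choice are
# assumptions): the pruner only computes masks, and ModelSpeedup then
# physically replaces the masked modules with smaller ones.
from nni.contrib.compression.pruning import L1NormPruner
from nni.compression.pytorch.speedup.v2 import ModelSpeedup

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()
ModelSpeedup(model, torch.rand(1, 3, 32, 32), masks).speedup_model()
```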

Even if it works, I am not entirely sure whether the pipeline will skip those weights entirely when they are listed. For example, in the model above I would like to preserve the output channels of conv2.


I tried adding conv2* to exclude_op_names_re, and then I get WARNING: no multi-dimension masks found. followed by an IndexError: list index out of range. It is understandable that it is trying to skip the layer, but I only want it to preserve dim0 of the weight shape in conv2. I also tried adding bn2*, but it is simply ignored.


What is the correct way to preserve or freeze output channels of a particular layer in the model?

","upvoteCount":1,"answerCount":1,"acceptedAnswer":{"@type":"Answer","text":"

Found out that the op names do not and should not include weight or bias. So with that, adding the last layer's name to exclude_op_names just works:

```python
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
    'exclude_op_names': [
        'conv2',
    ]
}]
```
Log:

```
Ouput shape: torch.Size([1, 80, 32, 32])
[2024-01-18 16:35:00] Start to speedup the model...
[2024-01-18 16:35:00] Resolve the mask conflict before mask propagate...
[2024-01-18 16:35:00] dim0 sparsity: 0.489796
[2024-01-18 16:35:00] dim1 sparsity: 0.000000
0 Filter
[2024-01-18 16:35:00] dim0 sparsity: 0.489796
[2024-01-18 16:35:00] dim1 sparsity: 0.000000
[2024-01-18 16:35:00] Infer module masks...
[2024-01-18 16:35:00] Propagate original variables
[2024-01-18 16:35:00] Propagate variables for placeholder: x, output mask:  0.0000
[2024-01-18 16:35:00] Propagate variables for call_module: conv1, weight:  0.4898 bias:  0.4898 , output mask:  0.0000
[2024-01-18 16:35:00] Propagate variables for call_module: bn1, , output mask:  0.0000
[2024-01-18 16:35:00] Propagate variables for call_module: relu, , output mask:  0.0000
[2024-01-18 16:35:01] Propagate variables for call_module: conv2, , output mask:  0.0000
[2024-01-18 16:35:01] Propagate variables for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:35:01] Propagate variables for output: output, output mask:  0.0000
[2024-01-18 16:35:01] Update direct sparsity...
[2024-01-18 16:35:01] Update direct mask for placeholder: x, output mask:  0.0000
[2024-01-18 16:35:01] Update direct mask for call_module: conv1, weight:  0.4898 bias:  0.4898 , output mask:  0.4898
[2024-01-18 16:35:02] Update direct mask for call_module: bn1, , output mask:  0.4898
[2024-01-18 16:35:02] Update direct mask for call_module: relu, , output mask:  0.4898
[2024-01-18 16:35:02] Update direct mask for call_module: conv2, , output mask:  0.0000
[2024-01-18 16:35:02] Update direct mask for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:35:02] Update direct mask for output: output, output mask:  0.0000
[2024-01-18 16:35:02] Update indirect sparsity...
[2024-01-18 16:35:02] Update indirect mask for output: output, output mask:  0.0000
[2024-01-18 16:35:03] Update indirect mask for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:35:03] Update indirect mask for call_module: conv2, , output mask:  0.0000
[2024-01-18 16:35:04] Update indirect mask for call_module: relu, , output mask:  0.4898
[2024-01-18 16:35:04] Update indirect mask for call_module: bn1, , output mask:  0.4898
[2024-01-18 16:35:04] Update indirect mask for call_module: conv1, weight:  0.4898 bias:  0.4898 , output mask:  0.4898
[2024-01-18 16:35:04] Update indirect mask for placeholder: x, output mask:  0.0000
[2024-01-18 16:35:04] Resolve the mask conflict after mask propagate...
[2024-01-18 16:35:04] dim0 sparsity: 0.489796
[2024-01-18 16:35:04] dim1 sparsity: 0.000000
0 Filter
[2024-01-18 16:35:04] dim0 sparsity: 0.489796
[2024-01-18 16:35:04] dim1 sparsity: 0.000000
[2024-01-18 16:35:04] Replace compressed modules...
[2024-01-18 16:35:04] replace module (name: conv1, op_type: Conv2d)
[2024-01-18 16:35:04] replace conv2d with in_channels: 3, out_channels: 25
[2024-01-18 16:35:04] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-18 16:35:04] replace batchnorm2d with num_features: 25
[2024-01-18 16:35:04] replace module (name: relu, op_type: ReLU)
[2024-01-18 16:35:04] replace module (name: conv2, op_type: Conv2d)
[2024-01-18 16:35:04] replace conv2d with in_channels: 25, out_channels: 80
[2024-01-18 16:35:04] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-18 16:35:04] replace batchnorm2d with num_features: 80
[2024-01-18 16:35:04] Speedup done.
Number of parameters before pruning: 36990
Number of parameters after pruning: 18990
Number of parameters pruned: 18000
Parameter ratio: 51.34%
ConvNet(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(25, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
```
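The parameter numbers at the end of the log are consistent with counting only trainable parameters; a helper along these lines (an assumption, since the original counting code is not shown) reproduces them:

```python
def count_params(model):
    # Trainable parameters only; BatchNorm running statistics are not counted.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Matches the log above: 36990 before pruning and 18990 after speedup,
# where conv1 keeps 25 of its original output channels (0.4898 sparsity)
# and conv2 keeps all 80 output channels.
print(count_params(model))
```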

However, using a regex via 'exclude_op_names_re': ['conv2*'] (or con*2, for example) does not work.
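For completeness, the regex variant that does not take effect would look roughly like this (a sketch reconstructed from the description above, not the exact original config):

```python
# Reconstructed sketch: regex-based exclusion instead of an exact name match.
# As described above, this does not work as hoped: it tries to skip conv2
# entirely rather than preserving only its output channels (dim0).
config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': sparsity_ratio,
    'exclude_op_names_re': ['conv2*'],
}]
```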

Alternatively, for a more direct way to configure pruning on only the input channels, there is a granularity property. It must be specified along with op_names and sparse_ratio in a separate config added to the config list:

```python
config_list = [
    {
        'op_names': ['conv2'],
        'sparse_ratio': sparsity_ratio,
        'granularity': 'in_channel',
    },
    {
        'op_types': ['Conv2d'],
        'sparse_ratio': sparsity_ratio,
    },
]
```

There are a few warnings when running it with this config:

Log:

```
Ouput shape: torch.Size([1, 40, 32, 32])
[2024-01-18 16:32:58] WARNING: bias have already configured, the new config will be ignored.
[2024-01-18 16:32:58] WARNING: weight have already configured, the new config will be ignored.
[2024-01-18 16:32:58] Start to speedup the model...
[2024-01-18 16:32:58] Resolve the mask conflict before mask propagate...
[2024-01-18 16:32:58] dim0 sparsity: 0.166667
[2024-01-18 16:32:58] dim1 sparsity: 0.434783
[2024-01-18 16:32:58] WARNING: both dim0 and dim1 masks found.
1 Filter
[2024-01-18 16:32:58] dim0 sparsity: 0.166667
[2024-01-18 16:32:58] dim1 sparsity: 0.434783
[2024-01-18 16:32:58] WARNING: both dim0 and dim1 masks found.
[2024-01-18 16:32:58] Infer module masks...
[2024-01-18 16:32:58] Propagate original variables
[2024-01-18 16:32:58] Propagate variables for placeholder: x, output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for call_module: bn1, , output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for call_module: relu, , output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for call_module: conv2, weight:  0.5000 bias:  0.0000 , output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:32:59] Propagate variables for output: output, output mask:  0.0000
[2024-01-18 16:32:59] Update direct sparsity...
[2024-01-18 16:33:00] Update direct mask for placeholder: x, output mask:  0.0000
[2024-01-18 16:33:00] Update direct mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.5000
[2024-01-18 16:33:00] Update direct mask for call_module: bn1, , output mask:  0.5000
[2024-01-18 16:33:00] Update direct mask for call_module: relu, , output mask:  0.5000
[2024-01-18 16:33:00] Update direct mask for call_module: conv2, weight:  0.5000 bias:  0.0000 , output mask:  0.0000
[2024-01-18 16:33:00] Update direct mask for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:33:00] Update direct mask for output: output, output mask:  0.0000
[2024-01-18 16:33:00] Update indirect sparsity...
[2024-01-18 16:33:01] Update indirect mask for output: output, output mask:  0.0000
[2024-01-18 16:33:01] Update indirect mask for call_module: bn2, , output mask:  0.0000
[2024-01-18 16:33:01] Update indirect mask for call_module: conv2, weight:  0.7000 bias:  0.0000 , output mask:  0.0000
[2024-01-18 16:33:02] Update indirect mask for call_module: relu, , output mask:  0.7000
[2024-01-18 16:33:02] Update indirect mask for call_module: bn1, , output mask:  0.7000
[2024-01-18 16:33:02] Update indirect mask for call_module: conv1, weight:  0.5000 bias:  0.5000 , output mask:  0.7000
[2024-01-18 16:33:02] Update indirect mask for placeholder: x, output mask:  0.0000
[2024-01-18 16:33:02] Resolve the mask conflict after mask propagate...
[2024-01-18 16:33:02] dim0 sparsity: 0.166667
[2024-01-18 16:33:02] dim1 sparsity: 0.608696
[2024-01-18 16:33:02] WARNING: both dim0 and dim1 masks found.
1 Filter
[2024-01-18 16:33:02] dim0 sparsity: 0.166667
[2024-01-18 16:33:02] dim1 sparsity: 0.608696
[2024-01-18 16:33:02] WARNING: both dim0 and dim1 masks found.
[2024-01-18 16:33:02] Replace compressed modules...
[2024-01-18 16:33:02] replace module (name: conv1, op_type: Conv2d)
[2024-01-18 16:33:02] replace conv2d with in_channels: 3, out_channels: 6
[2024-01-18 16:33:02] replace module (name: bn1, op_type: BatchNorm2d)
[2024-01-18 16:33:03] replace batchnorm2d with num_features: 6
[2024-01-18 16:33:03] replace module (name: relu, op_type: ReLU)
[2024-01-18 16:33:03] replace module (name: conv2, op_type: Conv2d)
[2024-01-18 16:33:03] replace conv2d with in_channels: 6, out_channels: 40
[2024-01-18 16:33:03] replace module (name: bn2, op_type: BatchNorm2d)
[2024-01-18 16:33:03] replace batchnorm2d with num_features: 40
[2024-01-18 16:33:03] Speedup done.
Number of parameters before pruning: 7920
Number of parameters after pruning: 2460
Number of parameters pruned: 5460
Parameter ratio: 31.06%
ConvNet(
  (relu): ReLU(inplace=True)
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(6, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
```

The warnings mention that weights and biases configured more than once are ignored; the ones from the first matching config are kept. When the order is swapped, the granularity config is simply ignored for the same reason.


That said, it would be more intuitive to define the base case first and override it with a special case.
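In other words, the layer-specific entry currently has to come first. With the swapped order below (a sketch of the case described above), conv2 is already covered by the broad Conv2d entry, so its granularity setting is dropped:

```python
# Swapped order (does NOT behave as intended): conv2's weight and bias are
# already configured by the first, broader entry, so by the warnings above
# the second entry's 'granularity': 'in_channel' is ignored.
config_list = [
    {
        'op_types': ['Conv2d'],
        'sparse_ratio': sparsity_ratio,
    },
    {
        'op_names': ['conv2'],
        'sparse_ratio': sparsity_ratio,
        'granularity': 'in_channel',
    },
]
```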

","upvoteCount":1,"url":"https://github.com/microsoft/nni/discussions/5737#discussioncomment-8171172"}}}