Skip to content

Commit

Permalink
Merge pull request #121 from cog-imperial/add_activations_onnx
Browse files Browse the repository at this point in the history
Add activations onnx
  • Loading branch information
rmisener authored Sep 26, 2023
2 parents c0f6c0b + 7d9814f commit 1033972
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/omlt/io/onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from omlt.neuralnet.network_definition import NetworkDefinition

_ACTIVATION_OP_TYPES = ["Relu", "Sigmoid", "LogSoftmax"]
_ACTIVATION_OP_TYPES = ["Relu", "Sigmoid", "LogSoftmax", "Tanh", "Softplus"]
_POOLING_OP_TYPES = ["MaxPool"]


Expand Down Expand Up @@ -232,12 +232,12 @@ def _consume_gemm_dense_nodes(self, node, next_nodes):
attr = _collect_attributes(node)
alpha = attr["alpha"]
beta = attr["beta"]
assert attr["transB"] == 1
[in_0, in_1, in_2] = list(node.input)
input_layer, transformer = self._node_input_and_transformer(in_0)
weights = self._initializers[in_1]
# transpose B
weights = np.transpose(weights)
if attr["transB"] == 1:
weights = np.transpose(weights)
biases = self._initializers[in_2]

input_output_size = _get_input_output_size(input_layer, transformer)
Expand Down
16 changes: 16 additions & 0 deletions tests/io/test_onnx_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ def test_gemm(datadir):
assert layers[3].activation == "logsoftmax"


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_gemm_transB(datadir):
    # Parse the same dense network exported once without and once with the
    # Gemm transB attribute set; the parser must produce layers whose
    # weight matrices agree up to a transpose.
    plain_net = load_onnx_neural_network(onnx.load(datadir.file("gemm_not_transB.onnx")))
    transB_net = load_onnx_neural_network(onnx.load(datadir.file("gemm_transB.onnx")))
    plain_layers = list(plain_net.layers)
    transB_layers = list(transB_net.layers)
    assert len(plain_layers) == len(transB_layers)
    w = plain_layers[1].weights
    w_t = transB_layers[1].weights
    assert w.shape == w_t.shape
    # Element-wise: w[i][j] must match the transposed entry w_t[j][i].
    for i in range(2):
        for j in range(2):
            assert abs(w[i][j] - w_t[j][i]) < 1e-05


@pytest.mark.skipif(not onnx_available, reason="Need ONNX for this test")
def test_conv(datadir):
model = onnx.load(datadir.file("convx1_gemmx1.onnx"))
Expand Down
Binary file added tests/models/gemm_not_transB.onnx
Binary file not shown.
Binary file added tests/models/gemm_transB.onnx
Binary file not shown.
6 changes: 4 additions & 2 deletions tests/neuralnet/test_nn_formulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def test_invalid_layer_type():

def _maxpool_conv_network(inputs):
input_size = [1, 8, 6]
input_bounds = np.empty(input_size, dtype="i,i")
input_bounds.fill((-10, 10))
input_bounds = {}
for i in range(input_size[1]):
for j in range(input_size[2]):
input_bounds[(0, i, j)] = (-10.0, 10.0)
net = NetworkDefinition(scaled_input_bounds=input_bounds)

input_layer = InputLayer(input_size)
Expand Down

0 comments on commit 1033972

Please sign in to comment.