diff --git a/Tests/Backend/CNTK/CNTKBackendTest.cs b/Tests/Backend/CNTK/CNTKBackendTest.cs index cc5c30b..d45fbfb 100644 --- a/Tests/Backend/CNTK/CNTKBackendTest.cs +++ b/Tests/Backend/CNTK/CNTKBackendTest.cs @@ -299,6 +299,61 @@ public void cntk_sum_test() } } + [Test] + public void cntk_sum_test_direct_api() + { + double[][] r; + + r = sum(null); // first, a sanity check to verify that values are being read correctly + double[,] d = new double[2, 3]; // result will be { 0, 1, 2, 3, 4, 5, 6 } + Buffer.BlockCopy(r[0], 0, d, 0, sizeof(double) * d.Length); + Assert.AreEqual(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }, d); // ok + + r = sum(0, 1); // sum over all axes + double a = r[0][0]; // result will be { 21 } + Assert.AreEqual(21, a); // ok + + r = sum(0); // sum over first axis + double[] b = r[0]; // result will be { 3, 7, 11 } + Assert.AreEqual(new[] { 5.0, 7.0, 9.0 }, b); // fails + + r = sum(1); // sum over second axis + double[] c = r[0]; // result will be { 9, 12 } + Assert.AreEqual(new[] { 6.0, 15.0 }, b); // fails + } + + private static double[][] sum(params int[] axes) + { + var arr = new[] + { + /* total: + /* */ 1.0, 2.0, 3.0, /* 6.0 */ + /* */ 4.0, 5.0, 6.0, /* 15.0 */ + /* total: 5.0, 7.0, 9.0 21.0 */ + }; + + var shape = NDShape.CreateNDShape(new[] { 2, 3 }); + Value vx = Value.CreateBatch(shape, arr, DeviceDescriptor.CPUDevice, readOnly: true); + Variable x = Variable.InputVariable(shape, CNTK.DataType.Double, name: "input"); + + CNTK.Function f; + if (axes == null) + { + f = CNTKLib.Alias(x); + } + else + { + var axisVector = new AxisVector(axes.Select(ax => new Axis(ax)).ToArray()); + f = CNTKLib.ReduceSum(x, axis: axisVector); + } + + var inputs = new Dictionary() { { x, vx } }; + var outputs = new Dictionary() { { f, null } }; + f.Evaluate(inputs, outputs, DeviceDescriptor.CPUDevice); + var r = outputs[f].GetDenseData((Variable)f); + return r.Select(ri => ri.ToArray()).ToArray(); + } + [Test] public void cntk_mean_test() {