diff --git a/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala b/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala index 42d212a..2964c9f 100644 --- a/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala +++ b/relate/src/main/scala/com/lucidchart/relate/SqlResult.scala @@ -50,12 +50,8 @@ class SqlResult(val resultSet: java.sql.ResultSet) extends CollectionsSqlResult } def asScalar[A](): A = asScalarOption.get - def asScalarOption[A](): Option[A] = { - if (resultSet.next()) { - Some(resultSet.getObject(1).asInstanceOf[A]) - } else { - None - } + def asScalarOption[A](): Option[A] = withResultSet { resultSet => + Option.when(resultSet.next())(resultSet.getObject(1).asInstanceOf[A]) } /** diff --git a/relate/src/test/scala/SqlResultSpec.scala b/relate/src/test/scala/SqlResultSpec.scala index 206fe5f..da07ab6 100644 --- a/relate/src/test/scala/SqlResultSpec.scala +++ b/relate/src/test/scala/SqlResultSpec.scala @@ -256,6 +256,27 @@ class SqlResultSpec extends Specification with Mockito { result.asScalarOption[Long] must_== None } + + "close the ResultSet" in { + val (rs, _, result) = getMocks + + rs.next returns true + rs.getObject(1) returns (2: java.lang.Long) + + result.asScalar[Long] mustEqual 2L + + there was one(rs).close() + } + + "close the ResultSet if it was empty" in { + val (rs, _, result) = getMocks + + rs.next returns false + + result.asScalarOption[Long] must beNone + + there was one(rs).close() + } } "extractOption" should {