Skip to content

Commit

Permalink
fix(snowflake): fix new array operations; remove ArrayRemove operation
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Apr 11, 2023
1 parent 5208801 commit 772668b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 13 deletions.
19 changes: 13 additions & 6 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,19 @@ def _format_in_memory_table(self, op, ref_op, translator):
if _NATIVE_ARROW:
return super()._format_in_memory_table(op, ref_op, translator)

columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
rows = list(ref_op.data.to_frame().itertuples(index=False))
pos_columns = [
sa.column(f"${idx}") for idx in range(1, len(ref_op.schema.names) + 1)
]
return sa.select(*pos_columns).select_from(sa.values(*columns).data(rows))
import ibis

schema = ref_op.schema
selects = (
sa.select(
*(
translator.translate(ibis.literal(col, typ).op()).label(name)
for col, (name, typ) in zip(row, schema.items())
)
)
for row in ref_op.data.to_frame().itertuples(index=False)
)
return sa.union_all(*selects)


class SnowflakeCompiler(AlchemyCompiler):
Expand Down
15 changes: 9 additions & 6 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,17 @@ def _group_concat(t, op):
sa.func.ifnull(arg, sa.func.parse_json("null")), type_=ARRAY
)
),
ops.ArrayContains: fixed_arity(sa.func.array_contains, 2),
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.array_position(lst, el) - 1, 2
ops.ArrayContains: fixed_arity(
lambda arr, el: sa.func.array_contains(sa.func.to_variant(el), arr), 2
),
ops.ArrayDistinct: fixed_arity(sa.func.array_distinct, 1),
ops.ArrayRemove: fixed_arity(
lambda lst, el: sa.func.array_except(lst, sa.func.array_construct(el)),
ops.ArrayPosition: fixed_arity(
# snowflake is zero-based here, so we don't need to substract 1 from the result
lambda lst, el: sa.func.coalesce(
sa.func.array_position(sa.func.to_variant(el), lst), -1
),
2,
),
ops.ArrayDistinct: fixed_arity(sa.func.array_distinct, 1),
ops.ArrayUnion: fixed_arity(
lambda left, right: sa.func.array_distinct(sa.func.array_cat(left, right)),
2,
Expand Down Expand Up @@ -396,6 +398,7 @@ def _group_concat(t, op):
ops.CumulativeOp,
ops.NTile,
# ibis.expr.operations.array
ops.ArrayRemove,
ops.ArrayRepeat,
ops.ArraySort,
# ibis.expr.operations.reductions
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def test_array_position(backend, con):
"pandas",
"polars",
"postgres",
"snowflake",
],
raises=com.OperationNotDefinedError,
)
Expand Down Expand Up @@ -629,6 +630,7 @@ def test_array_unique(backend, con):
"pandas",
"polars",
"postgres",
"snowflake",
],
raises=com.OperationNotDefinedError,
)
Expand Down Expand Up @@ -659,7 +661,7 @@ def test_array_sort(backend, con):
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
["trino", "pyspark"],
["snowflake", "trino", "pyspark"],
raises=AssertionError,
reason="array_distinct([NULL]) seems to differ from other backends",
)
Expand Down

0 comments on commit 772668b

Please sign in to comment.