diff --git a/protocol/message_persistence.go b/protocol/message_persistence.go index 10fc15a5b5..017ded4683 100644 --- a/protocol/message_persistence.go +++ b/protocol/message_persistence.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "sort" "strings" @@ -1768,17 +1769,41 @@ func (db sqlitePersistence) SavePinMessage(message *common.PinMessage) (inserted func (db sqlitePersistence) DeleteMessage(id string) error { _, err := db.db.Exec(`DELETE FROM user_messages WHERE id = ?`, id) + + if err != nil { + return err + } + + _, err = db.db.Exec("DELETE FROM pin_messages WHERE message_id = ?", id) + return err } -func (db sqlitePersistence) DeleteMessages(ids []string) error { +func (db sqlitePersistence) DeleteMessages(ids []string) (err error) { idsArgs := make([]interface{}, 0, len(ids)) for _, id := range ids { idsArgs = append(idsArgs, id) } inVector := strings.Repeat("?, ", len(ids)-1) + "?" - _, err := db.db.Exec("DELETE FROM user_messages WHERE id IN ("+inVector+")", idsArgs...) // nolint: gosec + tx, err := db.db.BeginTx(context.Background(), &sql.TxOptions{}) + if err != nil { + return err + } + defer func() { + if err == nil { + err = tx.Commit() + return + } + err = errors.Join(err, tx.Rollback()) + }() + + _, err = tx.Exec("DELETE FROM user_messages WHERE id IN ("+inVector+")", idsArgs...) // nolint: gosec + if err != nil { + return err + } + + _, err = tx.Exec("DELETE FROM pin_messages WHERE message_id IN ("+inVector+")", idsArgs...) // nolint: gosec return err }