Skip to content

Commit

Permalink
add optional_features
Browse files Browse the repository at this point in the history
  • Loading branch information
iuu6 committed Aug 13, 2024
1 parent 58a16fd commit 8db102f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 18 deletions.
2 changes: 1 addition & 1 deletion functions/ai_chat/ai_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func generateUserMessage(userInput string) string {
// createAIRequest initializes the AI request payload.
func createAIRequest(content string) AIRequest {
return AIRequest{
Model: "llama-3-70b",
Model: "mixtral-8x7b",
Messages: []struct {
Role string `json:"role"`
Content string `json:"content"`
Expand Down
51 changes: 39 additions & 12 deletions functions/setting/setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@ import (
"fmt"
"log"
"regexp"
"strconv"
"strings"

tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
)

// Add the list of valid features
var availableFeatures = []string{
"help", "play", "ask", "getid", "status", "admins",
"num", "string", "curconv", "color", "setting",
"num", "string", "curconv", "color", "setting", "ai_chat",
}

// Features that cannot be disabled
var nonDisablableFeatures = []string{"setting"}

func IsFeatureEnabled(db *sql.DB, groupID int64, featureName string) bool {
// Check if the feature is non-disablable
for _, feature := range nonDisablableFeatures {
if feature == featureName {
return true
Expand All @@ -31,7 +30,7 @@ func IsFeatureEnabled(db *sql.DB, groupID int64, featureName string) bool {
err := db.QueryRow("SELECT feature_off FROM group_setting WHERE groupid = ?", groupID).Scan(&featureOffList)
if err != nil && err != sql.ErrNoRows {
log.Printf("Error querying feature status: %v", err)
return true // Default to enabled if error occurs
return true
}

if featureOffList == "" {
Expand All @@ -56,8 +55,8 @@ func HandleSettingCommand(db *sql.DB, message *tgbotapi.Message, bot *tgbotapi.B
}

args := strings.Fields(message.CommandArguments())
if len(args) != 2 {
msg := tgbotapi.NewMessage(message.Chat.ID, "用法: /setting <enable/disable> <feature_name>")
if len(args) < 2 || (len(args) == 3 && args[0] != "enable") {
msg := tgbotapi.NewMessage(message.Chat.ID, "用法: /setting <enable/disable> <feature_name> [value]")
bot.Send(msg)
return
}
Expand Down Expand Up @@ -85,25 +84,55 @@ func HandleSettingCommand(db *sql.DB, message *tgbotapi.Message, bot *tgbotapi.B
return
}

var featureOffList string
err := db.QueryRow("SELECT feature_off FROM group_setting WHERE groupid = ?", message.Chat.ID).Scan(&featureOffList)
var featureOffList, optionalFeatures string
var currentValue *int
err := db.QueryRow("SELECT feature_off, optional_features, value_ai_chat FROM group_setting WHERE groupid = ?", message.Chat.ID).Scan(&featureOffList, &optionalFeatures, &currentValue)
if err != nil && err != sql.ErrNoRows {
log.Printf("Error querying feature status: %v", err)
return
}

disabledFeatures := strings.Split(featureOffList, ",")
enabledOptionalFeatures := strings.Split(optionalFeatures, ",")

if action == "enable" {
disabledFeatures = remove(disabledFeatures, feature)
if feature == "ai_chat" && len(args) == 3 {
value, err := strconv.Atoi(args[2])
if err != nil || value < 0 || value > 100 {
msg := tgbotapi.NewMessage(message.Chat.ID, "无效的值,请输入0到100之间的整数")
bot.Send(msg)
return
}
log.Printf("Setting AI chat trigger value to %d for groupID %d", value, message.Chat.ID)

// 使用 INSERT OR REPLACE INTO 确保值被正确写入
_, err = db.Exec(`
INSERT INTO group_setting (groupid, feature_off, optional_features, value_ai_chat)
VALUES (?, ?, ?, ?)
ON CONFLICT(groupid) DO UPDATE SET value_ai_chat = excluded.value_ai_chat`, message.Chat.ID, featureOffList, optionalFeatures, value)

if err != nil {
log.Printf("Error updating AI chat trigger value: %v", err)
return
}
if !contains(enabledOptionalFeatures, feature) {
enabledOptionalFeatures = append(enabledOptionalFeatures, feature)
}
}
} else if action == "disable" {
if !contains(disabledFeatures, feature) {
disabledFeatures = append(disabledFeatures, feature)
}
if contains(enabledOptionalFeatures, feature) {
enabledOptionalFeatures = remove(enabledOptionalFeatures, feature)
}
}

featureOffList = strings.Join(disabledFeatures, ",")
_, err = db.Exec("REPLACE INTO group_setting (groupid, feature_off) VALUES (?, ?)", message.Chat.ID, featureOffList)
optionalFeatures = strings.Join(enabledOptionalFeatures, ",")

_, err = db.Exec("INSERT INTO group_setting (groupid, feature_off, optional_features) VALUES (?, ?, ?) ON CONFLICT(groupid) DO UPDATE SET feature_off = excluded.feature_off, optional_features = excluded.optional_features", message.Chat.ID, featureOffList, optionalFeatures)
if err != nil {
log.Printf("Error updating feature status: %v", err)
return
Expand Down Expand Up @@ -161,13 +190,11 @@ func remove(slice []string, item string) []string {

// validateInput ensures that the action and feature are safe from SQL injection
func validateInput(action, feature string) bool {
// Allowed actions are only "enable" and "disable"
validActions := []string{"enable", "disable"}
if !contains(validActions, action) {
return false
}

// Ensure feature name only contains alphanumeric characters to prevent SQL injection
re := regexp.MustCompile("^[a-zA-Z0-9]+$")
re := regexp.MustCompile("^[a-zA-Z0-9_]+$")
return re.MatchString(feature)
}
22 changes: 17 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ func initDatabase() {
log.Fatalf("Error opening database: %v", err)
}

// 创建 group_setting 表格
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS group_setting (
groupid INTEGER PRIMARY KEY,
feature_off TEXT
feature_off TEXT,
optional_features TEXT,
value_ai_chat INTEGER
)`)
if err != nil {
log.Fatalf("Error creating table: %v", err)
Expand Down Expand Up @@ -160,7 +161,7 @@ func processMessage(message *tgbotapi.Message, bot *tgbotapi.BotAPI) {
} else if command == "setting" {
setting.HandleSettingCommand(db, message, bot, config.SuperAdmins)
}
} else if (message.Chat.IsGroup() || message.Chat.IsSuperGroup()) && isReplyToBot(message) && shouldTriggerResponse() {
} else if (message.Chat.IsGroup() || message.Chat.IsSuperGroup()) && isReplyToBot(message) && shouldTriggerResponse(message.Chat.ID) {
ai_chat.HandleAIChat(message, bot)
}
}
Expand All @@ -180,10 +181,21 @@ func isReplyToBot(message *tgbotapi.Message) bool {
return false
}

func shouldTriggerResponse() bool {
func shouldTriggerResponse(groupID int64) bool {
var triggerValue int
err := db.QueryRow("SELECT value_ai_chat FROM group_setting WHERE groupid = ?", groupID).Scan(&triggerValue)
if err != nil {
log.Printf("Error fetching AI chat trigger value: %v", err)
triggerValue = 0 // Default value
}

log.Printf("AI Chat Trigger Value: %d", triggerValue)
rand.Seed(time.Now().UnixNano())
randomValue := rand.Intn(100) + 1
return randomValue > 30
log.Printf("Random value generated: %d", randomValue)

// Adjust this logic based on your expectation
return randomValue <= triggerValue
}

func setBotCommands(bot *tgbotapi.BotAPI) {
Expand Down

0 comments on commit 8db102f

Please sign in to comment.