Skip to content

Commit

Permalink
add selfupdate
Browse files Browse the repository at this point in the history
  • Loading branch information
QuintenQVD0 committed Dec 2, 2024
1 parent 5858856 commit 4fdc52e
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 11 deletions.
23 changes: 12 additions & 11 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,24 @@ func init() {
rootCommand.AddCommand(versionCommand)
rootCommand.AddCommand(configureCmd)
rootCommand.AddCommand(newDiagnosticsCommand())
rootCommand.AddCommand(newSelfupdateCommand())
}

func isDockerSnap() bool {
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
log.Fatalf("Unable to initialize Docker client: %s", err)
}
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
log.Fatalf("Unable to initialize Docker client: %s", err)
}

defer cli.Close() // Close the client when the function returns (should not be needed, but just to be safe)
defer cli.Close() // Close the client when the function returns (should not be needed, but just to be safe)

info, err := cli.Info(context.Background())
if err != nil {
log.Fatalf("Unable to get Docker info: %s", err)
}
info, err := cli.Info(context.Background())
if err != nil {
log.Fatalf("Unable to get Docker info: %s", err)
}

// Check if Docker root directory contains '/var/snap/docker'
return strings.Contains(info.DockerRootDir, "/var/snap/docker")
// Check if Docker root directory contains '/var/snap/docker'
return strings.Contains(info.DockerRootDir, "/var/snap/docker")
}

func rootCmdRun(cmd *cobra.Command, _ []string) {
Expand Down
257 changes: 257 additions & 0 deletions cmd/selfupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package cmd

import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"runtime"
"strings"
"time"

"github.com/pelican-dev/wings/system"
"github.com/spf13/cobra"
)

var updateArgs struct {
repoOwner string
repoName string
}

func newSelfupdateCommand() *cobra.Command {
command := &cobra.Command{
Use: "update",
Short: "Update the wings to the latest version",
Run: selfupdateCmdRun,
}

command.Flags().StringVar(&updateArgs.repoOwner, "repo-owner", "pelican-dev", "GitHub username or organization that owns the repository containing the updates")
command.Flags().StringVar(&updateArgs.repoName, "repo-name", "wings", "The name of the GitHub repository to fetch updates from")

return command
}

func selfupdateCmdRun(*cobra.Command, []string) {
currentVersion := system.Version
if currentVersion == "" {
fmt.Println("Error: Current version is not defined")
return
}

if currentVersion == "develop" {
fmt.Println("Running in development mode. Skipping update.")
return
}

fmt.Println("Current version:", currentVersion)

// Fetch the latest release tag from GitHub API
latestVersionTag, err := fetchLatestGitHubRelease()
if err != nil {
fmt.Println("Failed to fetch the latest version:", err)
return
}

currentVersionTag := "v" + currentVersion
if latestVersionTag == currentVersionTag {
fmt.Println("You are running the latest version:", currentVersion)
return
}

fmt.Printf("A new version is available: %s (current: %s)\n", latestVersionTag, currentVersionTag)

binaryName := determineBinaryName()
if binaryName == "" {
fmt.Println("Unsupported architecture")
return
}

downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", updateArgs.repoOwner, updateArgs.repoName, latestVersionTag, binaryName)
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/checksums.txt", updateArgs.repoOwner, updateArgs.repoName, latestVersionTag)

fmt.Println("Downloading checksums.txt...")
checksumFile, err := downloadFile(checksumURL, "checksums.txt")
if err != nil {
fmt.Println("Failed to download checksum file:", err)
return
}
defer os.Remove(checksumFile)

fmt.Println("Downloading", binaryName, "...")
binaryFile, err := downloadFile(downloadURL, binaryName)
if err != nil {
fmt.Println("Failed to download binary file:", err)
return
}
defer os.Remove(binaryFile)

if err := verifyChecksum(binaryFile, checksumFile, binaryName); err != nil {
fmt.Println("Checksum verification failed:", err)
return
}
fmt.Println("\nChecksum verification successful.")

currentExecutable, err := os.Executable()
if err != nil {
fmt.Println("Failed to locate current executable:", err)
return
}

if err := os.Chmod(binaryFile, 0755); err != nil {
fmt.Println("Failed to set executable permissions on the new binary:", err)
return
}

if err := replaceBinary(currentExecutable, binaryFile); err != nil {
fmt.Println("Failed to replace executable:", err)
return
}

fmt.Println("Restarting service...")

if err := restartService(); err != nil {
fmt.Println("Error restarting the wings service:", err)
} else {
fmt.Println("Service restarted successfully.")
}
}

func fetchLatestGitHubRelease() (string, error) {
apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", updateArgs.repoOwner, updateArgs.repoName)

resp, err := http.Get(apiURL)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var releaseData struct {
TagName string `json:"tag_name"`
}
if err := json.NewDecoder(resp.Body).Decode(&releaseData); err != nil {
return "", err
}

return releaseData.TagName, nil
}

func determineBinaryName() string {
switch runtime.GOARCH {
case "amd64":
return "wings_linux_amd64"
case "arm64":
return "wings_linux_arm64"
default:
return ""
}
}

func downloadFile(url, fileName string) (string, error) {
tmpFile, err := os.CreateTemp("", fileName)
if err != nil {
return "", err
}
defer tmpFile.Close()

resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status: %s", resp.Status)
}

fmt.Printf("Downloading %s (%.2f MB)...\n", fileName, float64(resp.ContentLength)/1024/1024)
progressWriter := &progressWriter{Writer: tmpFile, Total: resp.ContentLength}
if _, err := io.Copy(progressWriter, resp.Body); err != nil {
return "", err
}

fmt.Println() // Ensure a newline after download progress
return tmpFile.Name(), nil
}

func verifyChecksum(binaryPath, checksumPath, binaryName string) error {
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return err
}

var expectedChecksum string
for _, line := range strings.Split(string(checksumData), "\n") {
if strings.HasSuffix(line, binaryName) {
parts := strings.Fields(line)
if len(parts) > 0 {
expectedChecksum = parts[0]
}
break
}
}
if expectedChecksum == "" {
return fmt.Errorf("checksum not found for %s", binaryName)
}

file, err := os.Open(binaryPath)
if err != nil {
return err
}
defer file.Close()

hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return err
}
actualChecksum := fmt.Sprintf("%x", hasher.Sum(nil))

if actualChecksum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
}

return nil
}

func replaceBinary(currentPath, newPath string) error {
return os.Rename(newPath, currentPath)
}

type progressWriter struct {
io.Writer
Total int64
Written int64
StartTime time.Time
}

func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.Writer.Write(p)
pw.Written += int64(n)

if pw.Total > 0 {
percent := float64(pw.Written) / float64(pw.Total) * 100
fmt.Printf("\rProgress: %.2f%%", percent)
}

return n, err
}

func restartService() error {
// Try to run the systemctl restart command
cmd := exec.Command("systemctl", "restart", "wings")
cmdOutput, err := cmd.CombinedOutput()

if err != nil {
// If systemctl command fails, return the error with output
return fmt.Errorf("failed to restart service: %s\n%s", err.Error(), string(cmdOutput))
}

// If successful, return nil
return nil
}

0 comments on commit 4fdc52e

Please sign in to comment.