Files
cocommit/src/cmd/update.go
T

265 lines
6.2 KiB
Go

package cmd
import (
"archive/tar"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"
"github.com/spf13/cobra"
)
// github_release holds the part of GitHub's "latest release" API response
// that the update command needs. Only the release tag is decoded; it is
// compared against Coco_Version to decide whether an update is required.
type github_release struct {
	// TagName is decoded from the "tag_name" JSON field.
	TagName string `json:"tag_name"`
}
// updateCmd represents the update command.
//
// With --check it only reports, based on the package-level `update` value
// (set elsewhere in the package), whether a newer release is available and
// exits. Otherwise it fetches the latest GitHub release tag and, unless it
// equals Coco_Version, updates either via `go get` (--go-get) or by
// downloading and swapping the binary (updateScript). If the latest release
// cannot be determined, the user is asked whether to update anyway.
var updateCmd = &cobra.Command{
	Use:   "update",
	Short: "Tries to update the cocommit cli tool by either running the update script or by running the go get command if the -g flag is set",
	Long:  `This command will try to update the cocommit cli tool by either running the update script or by running the go get Command if the -g flag is set.`,
	Run: func(cmd *cobra.Command, args []string) {
		gflag, _ := cmd.Flags().GetBool("go-get")
		cflag, _ := cmd.Flags().GetBool("check")

		if cflag {
			fmt.Println("Checking if Cocommit is up to date")
			if update {
				update_msg()
			} else {
				fmt.Println("Cocommit is up to date")
			}
			os.Exit(0)
		}

		// Look up the latest published release. The lookup is wrapped so the
		// response body is closed on every path. A failed request, a non-200
		// reply (e.g. API rate limiting) or an undecodable body all take the
		// same "could not determine version" route. None of them dereferences
		// a nil response or panics.
		latestTag, err := func() (string, error) {
			resp, err := http.Get("https://api.github.com/repos/Slug-Boi/cocommit/releases/latest")
			if err != nil {
				return "", err
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				return "", fmt.Errorf("unexpected status %s", resp.Status)
			}
			var release github_release
			if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
				return "", fmt.Errorf("decoding release json: %w", err)
			}
			return release.TagName, nil
		}()

		if err != nil {
			fmt.Println("Error getting latest release version:", err)
			fmt.Println("Would you still like to update? (y/n)")
			var input string
			fmt.Scanln(&input)
			if input != "y" && input != "Y" && input != "yes" {
				fmt.Println("Update cancelled")
				return
			}
			// No version to compare against: proceed with the update.
			fmt.Println("Running update script to update cocommit cli tool")
		} else if latestTag == Coco_Version {
			fmt.Println("Cocommit cli tool is already up to date")
			return
		}

		if gflag {
			fmt.Println("Running go get command to update cocommit cli tool")
			// NOTE(review): since Go 1.18 `go get` only edits go.mod and no
			// longer builds/installs binaries; `go install <module>@latest` is
			// probably intended here. Confirm the module path before switching.
			goGet := exec.Command("go", "get", "-u", "github.com/Slug-Boi/cocommit")
			// Surface the go tool's own output so failures are diagnosable.
			goGet.Stdout = os.Stdout
			goGet.Stderr = os.Stderr
			if err := goGet.Run(); err != nil {
				fmt.Println("Error running go get command:", err)
				return
			}
			fmt.Println("Cocommit cli tool updated successfully")
			return
		}

		fmt.Println("Running binary replace to update cocommit cli tool")
		updateScript()
	},
}
func cleanup() {
fmt.Println("Cleaning up")
os.Remove("cocommit.tar.gz")
}
// updateScript replaces the running cocommit binary with the latest release
// published on GitHub.
//
// Steps:
//  1. Download the release archive for the current OS/architecture into
//     "cocommit.tar.gz" in the working directory.
//  2. Extract it into the working directory with unzipper.
//  3. Hand the extracted binary to swapper.
//
// The archive is removed by cleanup when the function returns. Any failure is
// reported and aborts the update before the success message is printed.
func updateScript() {
	exec_path, err := os.Executable()
	if err != nil {
		fmt.Println("Error getting executable path")
		return
	}
	// A binary named "main" ("main.exe" on Windows) is treated as a
	// source-code run (see message) and is not replaced.
	if strings.TrimSuffix(filepath.Base(exec_path), ".exe") == "main" {
		fmt.Println("Cancelling update running as source code")
		return
	}
	exec_path, err = filepath.EvalSymlinks(exec_path)
	if err != nil {
		log.Fatal(err)
	}

	archive, err := os.Create("cocommit.tar.gz")
	if err != nil {
		fmt.Println("Error creating file")
		return
	}
	// Deferred in this order so the handle is closed before cleanup deletes
	// the file (open files cannot be deleted on Windows).
	defer cleanup()
	defer archive.Close()

	var asset string
	switch runtime.GOOS {
	case "darwin":
		fmt.Println("Downloading mac version")
		if runtime.GOARCH == "amd64" {
			asset = "cocommit-darwin-x86_64.tar.gz"
		} else {
			asset = "cocommit-darwin-aarch64.tar.gz"
		}
	case "windows":
		fmt.Println("Downloading windows version")
		asset = "cocommit-win.tar.gz"
	default:
		fmt.Println("Downloading linux version")
		asset = "cocommit-linux.tar.gz"
	}

	resp, err := http.Get("https://github.com/Slug-Boi/cocommit/releases/latest/download/" + asset)
	if err != nil {
		fmt.Println("Error downloading file")
		return
	}
	defer resp.Body.Close()
	// Without this check an error page would be written to the archive and
	// only fail later, confusingly, during decompression.
	if resp.StatusCode != http.StatusOK {
		fmt.Println("Error downloading file - " + resp.Status)
		return
	}

	if _, err := io.Copy(archive, resp.Body); err != nil {
		fmt.Println("Error copying file")
		return
	}
	// Read the archive back through the same handle instead of opening a
	// second one, so there is exactly one handle and it is always closed.
	if _, err := archive.Seek(0, io.SeekStart); err != nil {
		fmt.Println("Error opening file")
		return
	}
	if err := unzipper("./", archive); err != nil {
		fmt.Println("Error unzipping file - " + err.Error())
		return
	}

	swapper(exec_path)
	fmt.Println(update_style.Render("Cocommit cli tool updated successfully"))
}
// swapper replaces the executable at exec_path with the freshly extracted
// release binary.
//
// updateScript extracts the release archive into the working directory, so
// only the top level of "./" is searched. The first regular file whose name
// starts with "cocommit-" is taken as the new binary. The walk is deliberately
// not recursive: the working directory is typically the user's repository,
// and a nested match could not be renamed by its bare name anyway. The
// process exits via log.Fatal if no candidate is found or the rename fails.
//
// NOTE(review): os.Rename fails across filesystems and, on Windows, over a
// running executable; a copy/rename-aside fallback may be needed.
func swapper(exec_path string) {
	binaryName := regexp.MustCompile(`^cocommit-.+`)

	entries, err := os.ReadDir("./")
	if err != nil {
		log.Fatal(err)
	}

	var newBinary string
	for _, entry := range entries {
		if entry.Type().IsRegular() && binaryName.MatchString(entry.Name()) {
			newBinary = entry.Name()
			break
		}
	}
	// Fail loudly instead of returning silently; otherwise the caller would
	// report a successful update that never happened.
	if newBinary == "" {
		log.Fatal("no cocommit binary found in the downloaded archive")
	}
	// No-op after a successful rename; otherwise drops the extracted file.
	defer os.Remove(newBinary)

	if err := os.Rename(newBinary, exec_path); err != nil {
		log.Fatal(err)
	}
}
// unzipper extracts a gzip-compressed tar stream read from r into dst.
//
// Only directory and regular-file entries are materialized; other entry types
// are skipped. Every entry must resolve inside dst (dst itself is allowed, so
// a "./" entry is fine); anything that would escape it aborts extraction with
// an error. Existing files are truncated and overwritten.
//
// Returns nil on a clean end of archive, or the first gzip/tar/filesystem
// error encountered.
func unzipper(dst string, r io.Reader) error {
	gzr, err := gzip.NewReader(r)
	if err != nil {
		return err
	}
	defer gzr.Close()

	// Resolve the destination once; each entry is checked against it.
	cleanDst, err := filepath.Abs(dst)
	if err != nil {
		return fmt.Errorf("failed to get absolute path: %w", err)
	}
	dstPrefix := cleanDst + string(os.PathSeparator)

	tr := tar.NewReader(gzr)
	for {
		header, err := tr.Next()
		switch {
		// no more entries
		case err == io.EOF:
			return nil
		// any other read error is fatal
		case err != nil:
			return err
		// defensive: skip a nil header
		case header == nil:
			continue
		}

		// the target location where the dir/file should be created
		target := filepath.Join(dst, header.Name)
		cleanTarget, err := filepath.Abs(target)
		if err != nil {
			return fmt.Errorf("failed to get absolute path: %w", err)
		}
		// Reject path traversal ("../x", absolute names).
		if cleanTarget != cleanDst && !strings.HasPrefix(cleanTarget, dstPrefix) {
			return fmt.Errorf("illegal file path: %s\nExpected: %s", cleanTarget, dstPrefix)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			// MkdirAll is a no-op for directories that already exist.
			if err := os.MkdirAll(target, 0755); err != nil {
				return err
			}
		case tar.TypeReg:
			// Archives are not required to list parent directories first.
			if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
				return err
			}
			// O_TRUNC: a leftover file from an earlier run must not keep
			// stale bytes past the end of the new contents.
			f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode))
			if err != nil {
				return err
			}
			if _, err := io.Copy(f, tr); err != nil {
				f.Close()
				return err
			}
			// Close per entry (not deferred, which would hold every handle
			// until the loop ends) and surface close errors, which can
			// report failed writes.
			if err := f.Close(); err != nil {
				return err
			}
		}
	}
}
// init registers updateCmd under the root command and declares its flags:
// --go-get/-g selects the `go get` update path, and --check/-c only reports
// whether an update is available. Both default to false.
func init() {
	rootCmd.AddCommand(updateCmd)

	flags := updateCmd.Flags()
	flags.BoolP("go-get", "g", false, "Use the go get command to update the cocommit cli tool")
	flags.BoolP("check", "c", false, "Check if the cocommit cli tool is up to date")
}