package main import ( // The standard stuff "errors" "fmt" "io" "os" // A cute logging system "github.com/charmbracelet/log" // Encryption "crypto/rand" "encoding/json" "math/big" "net/http" "github.com/golang-jwt/jwt" "golang.org/x/crypto/bcrypt" // SQL databasing "database/sql" _ "github.com/mattn/go-sqlite3" ) // Define the configuration directories var userConfigDirectory, err = os.UserConfigDir() var serverConfigDirectory = fmt.Sprintf("%v/ambition/server", userConfigDirectory) var jwtPath = fmt.Sprintf("%v/jwt_secret", serverConfigDirectory) // Define the user request struct type UserRequest struct { Name string `json:"name"` Password string `json:"password"` } func (ur *UserRequest) Parse(req *http.Request) error { // Can't unmarshal the actual req.Body so must read first body, err := io.ReadAll(req.Body) if err != nil { return err } if err := json.Unmarshal(body, &ur); err != nil { return err } return nil } // A function to write a randomly-generated cryptographically secure 24-character string to a file func makeSecret() { const characters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-" ret := make([]byte, 24) for i := 0; i < 24; i++ { num, err := rand.Int(rand.Reader, big.NewInt(int64(len(characters)))) if err != nil { log.Fatal(err) } ret[i] = characters[num.Int64()] } // Check if the Ambition server config folder exists, otherwise make it _, err2 := os.Stat(serverConfigDirectory) if os.IsNotExist(err2) { log.Info("Ambition backend config folder does not exist, creating...") os.MkdirAll(serverConfigDirectory, 0755) log.Info("Made Ambition backend config folder!") } // Write the secret to the file os.WriteFile(jwtPath, ret, 0755) } // Define the user handler struct type UserHandler struct { db *sql.DB jwt_secret []byte } // Define the function to create user handlers func NewUserHandler() (*UserHandler, error) { // Initialise the database using the database file db, err := sql.Open("sqlite3", "users.db") if err != nil { return nil, err } // Get the JSON web token jwt_secret_bytes, err := os.ReadFile(jwtPath) jwt_secret_str := string(jwt_secret_bytes) // Return any errors if jwt_secret_str == "" { return nil, errors.New("no JWT_SECRET provided in .env") } jwt_secret := []byte(jwt_secret_str) // Return the user handler struct return &UserHandler{ db: db, jwt_secret: jwt_secret, }, nil } // JWT Utilities func (h *UserHandler) ParseUserToken(token_string string) (*jwt.Token, error) { token, err := jwt.Parse(token_string, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, errors.New("jwt: incorrect token signing method") } return h.jwt_secret, nil }) if err != nil { return nil, err } return token, nil } func (h *UserHandler) GenerateUserToken(name, pwdhash string) (string, error) { token := jwt.New(jwt.SigningMethodHS256) claims := token.Claims.(jwt.MapClaims) claims["name"] = name claims["pwdhash"] = pwdhash token_string, err := token.SignedString(h.jwt_secret) if err != nil { return "", err } return token_string, nil } func (h *UserHandler) Handle(res http.ResponseWriter, req *http.Request) { switch req.Method { case "POST": h.createUser(res, req) case "PUT": h.updateUser(res, req) case "DELETE": h.deleteUser(res, req) // Return an error message should an invalid method be used default: http.Error(res, "Only POST, PUT, and DELETE are valid methods", http.StatusMethodNotAllowed) } } // NOTE(midnadimple): This function could be considered to do too much stuff, but // I think this is the best implementation func (h *UserHandler) createUser(res http.ResponseWriter, req *http.Request) { user_request := new(UserRequest) if err := user_request.Parse(req); err != nil { http.Error(res, fmt.Sprintf("user: failed to parse request (%s)", err), http.StatusBadRequest) return } name := user_request.Name password := []byte(user_request.Password) // Password checks row := h.db.QueryRow("SELECT pwdhash FROM users WHERE name=?", name) var db_pwdhash string if err := row.Scan(&db_pwdhash); err != nil { // If no user is found with the requested name, create the user if errors.Is(err, sql.ErrNoRows) { pwdhash_bytes, err := bcrypt.GenerateFromPassword(password, 12) // Log any errors if err != nil { http.Error(res, fmt.Sprintf("user: failed to generate password hash (%s)", err), http.StatusInternalServerError) return } db_pwdhash = string(pwdhash_bytes) _, err = h.db.Exec("INSERT INTO users VALUES (?,?)", name, db_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("db: failed to create user (%s)", err), http.StatusInternalServerError) return } } else { http.Error(res, fmt.Sprintf("db: failed to query row (%s)", err), http.StatusInternalServerError) return } } else if bcrypt.CompareHashAndPassword([]byte(db_pwdhash), password) != nil { http.Error(res, "User exists, but invalid password", http.StatusForbidden) return } // JWT Generation token_string, err := h.GenerateUserToken(name, db_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("jwt: failed to generate token (%s)", err), http.StatusInternalServerError) return } fmt.Fprintf(res, "%s", token_string) } func (h *UserHandler) updateUser(res http.ResponseWriter, req *http.Request) { user_request := new(UserRequest) if err := user_request.Parse(req); err != nil { http.Error(res, fmt.Sprintf("user: failed to parse request (%s)", err), http.StatusBadRequest) return } req_name := user_request.Name req_password := []byte(user_request.Password) if req.Header["Authorization"] == nil { http.Error(res, "jwt: missing token", http.StatusUnauthorized) return } token, err := h.ParseUserToken(req.Header["Token"][0]) if err != nil { http.Error(res, fmt.Sprintf("jwt: error during parsing (%s)", err), http.StatusInternalServerError) return } if !token.Valid { http.Error(res, "jwt: invalid token", http.StatusUnauthorized) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { http.Error(res, "jwt: failed to get claims", http.StatusInternalServerError) return } claim_name := claims["name"].(string) claim_pwdhash := claims["pwdhash"].(string) var db_name, db_pwdhash string row := h.db.QueryRow("SELECT * FROM users WHERE name=?", claim_name) if err := row.Scan(&db_name, &db_pwdhash); err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(res, "user: authorized user doesn't exist", http.StatusBadRequest) return } else { http.Error(res, fmt.Sprintf("db: failed to find user (%s)", err), http.StatusInternalServerError) return } } if claim_pwdhash != db_pwdhash { http.Error(res, "user: invalid password", http.StatusForbidden) return } if req_name == claim_name && bcrypt.CompareHashAndPassword([]byte(claim_pwdhash), req_password) == nil { http.Error(res, "user: requested credentials are the same as current credentials", http.StatusBadRequest) return } req_pwdhash_bytes, err := bcrypt.GenerateFromPassword(req_password, 12) if err != nil { http.Error(res, fmt.Sprintf("user: failed to generate password hash (%s)", err), http.StatusInternalServerError) return } req_pwdhash := string(req_pwdhash_bytes) _, err = h.db.Exec("UPDATE users SET name=?, pwdhash=? WHERE name=? AND pwdhash=?", req_name, req_pwdhash, claim_name, claim_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("db: failed to update user (%s)", err), http.StatusInternalServerError) } token_string, err := h.GenerateUserToken(req_name, req_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("jwt: failed to generate token (%s)", err), http.StatusInternalServerError) return } fmt.Fprintf(res, "%s", token_string) } // TODO(midnadimple): Implement: func (h *UserHandler) deleteUser(res http.ResponseWriter, req *http.Request) { token, err := h.ParseUserToken(req.Header["Token"][0]) if err != nil { http.Error(res, fmt.Sprintf("jwt: error during parsing (%s)", err), http.StatusInternalServerError) return } if !token.Valid { http.Error(res, "jwt: invalid token", http.StatusUnauthorized) return } claims, ok := token.Claims.(jwt.MapClaims) if !ok { http.Error(res, "jwt: failed to get claims", http.StatusInternalServerError) return } claim_name := claims["name"].(string) claim_pwdhash := claims["pwdhash"].(string) var db_name, db_pwdhash string row := h.db.QueryRow("SELECT * FROM users WHERE name=?", claim_name) if err := row.Scan(&db_name, &db_pwdhash); err != nil { if errors.Is(err, sql.ErrNoRows) { http.Error(res, "user: authorized user doesn't exist", http.StatusBadRequest) return } else { http.Error(res, fmt.Sprintf("db: failed to find user (%s)", err), http.StatusInternalServerError) return } } if claim_pwdhash != db_pwdhash { http.Error(res, "user: invalid password", http.StatusForbidden) return } if _, err := h.db.Exec("DELETE FROM users WHERE name=? AND pwdhash=?", db_name, db_pwdhash); err != nil { http.Error(res, fmt.Sprintf("db: failed to delete user (%s)", err), http.StatusInternalServerError) return } }