diff --git a/backend/README.md b/backend/README.md index 9faf76e..9eeeb64 100644 --- a/backend/README.md +++ b/backend/README.md @@ -8,5 +8,9 @@ and initialize it with the `sql/init.sql` script: $ cat sql/init.sql | sqlite3 users.db ``` -You can optionally provide the `PORT` environment variable to override the -default port of `7741` +You also need to create a `.env` file with the following variables: + +- `JWT_SECRET`: Required. A cryptographically secure string used to encode +tokens. +- `PORT`: Optional. Overrides the default port of `7741` + diff --git a/backend/main.go b/backend/main.go index 787b9d0..b7187a6 100644 --- a/backend/main.go +++ b/backend/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + _ "github.com/joho/godotenv/autoload" "log" "net/http" "os" diff --git a/backend/user.go b/backend/user.go index 8cc1ea0..2a63882 100644 --- a/backend/user.go +++ b/backend/user.go @@ -5,11 +5,9 @@ import ( "errors" "fmt" "io" + "os" // Encryption - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" "encoding/json" "net/http" @@ -21,18 +19,30 @@ import ( _ "github.com/mattn/go-sqlite3" ) -// Define the user handler struct -type UserHandler struct { - db *sql.DB - jwt_secret *ecdsa.PrivateKey -} - // Define the user request struct type UserRequest struct { Name string `json:"name"` Password string `json:"password"` } +func (ur *UserRequest) Parse(req *http.Request) error { + // Can't unmarshal the actual req.Body so must read first + body, err := io.ReadAll(req.Body) + if err != nil { + return err + } + if err := json.Unmarshal(body, &ur); err != nil { + return err + } + return nil +} + +// Define the user handler struct +type UserHandler struct { + db *sql.DB + jwt_secret []byte +} + // Define the function to create user handlers func NewUserHandler() (*UserHandler, error) { // Initialise the database using the database file @@ -42,11 +52,12 @@ func NewUserHandler() (*UserHandler, error) { } // Define the JSON web token - jwt_secret, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + jwt_secret_str := os.Getenv("JWT_SECRET") // Return any errors - if err != nil { + if jwt_secret_str == "" { return nil, err } + jwt_secret := []byte(jwt_secret_str) // Return the user handler struct return &UserHandler{ @@ -55,6 +66,33 @@ func NewUserHandler() (*UserHandler, error) { }, nil } +// JWT Utilities +func (h *UserHandler) ParseUserToken(token_string string) (*jwt.Token, error) { + token, err := jwt.Parse(token_string, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("jwt: incorrect token signing method") + } + return h.jwt_secret, nil + }) + if err != nil { + return nil, err + } + return token, nil +} + +func (h *UserHandler) GenerateUserToken(name, pwdhash string) (string, error) { + token := jwt.New(jwt.SigningMethodHS256) + claims := token.Claims.(jwt.MapClaims) + claims["name"] = name + claims["pwdhash"] = pwdhash + + token_string, err := token.SignedString(h.jwt_secret) + if err != nil { + return "", err + } + return token_string, nil +} + func (h *UserHandler) Handle(res http.ResponseWriter, req *http.Request) { switch req.Method { case "POST": @@ -72,19 +110,11 @@ func (h *UserHandler) Handle(res http.ResponseWriter, req *http.Request) { // NOTE(midnadimple): This function could be considered to do too much stuff, but // I think this is the best implementation func (h *UserHandler) createUser(res http.ResponseWriter, req *http.Request) { - // Can't unmarshal the actual req.Body so must read first - body, err := io.ReadAll(req.Body) - if err != nil { - http.Error(res, fmt.Sprintf("user: failed to read request (%s)", err), http.StatusBadRequest) + user_request := new(UserRequest) + if err := user_request.Parse(req); err != nil { + http.Error(res, fmt.Sprintf("user: failed to parse request (%s)", err), http.StatusBadRequest) return } - - var user_request UserRequest - if err := json.Unmarshal(body, &user_request); err != nil { - http.Error(res, "user: json request body doesn't match schema", http.StatusBadRequest) - return - } - name := user_request.Name password := []byte(user_request.Password) @@ -92,7 +122,7 @@ func (h *UserHandler) createUser(res http.ResponseWriter, req *http.Request) { row := h.db.QueryRow("SELECT pwdhash FROM users WHERE name=?", name) var db_pwdhash string - if err = row.Scan(&db_pwdhash); err != nil { + if err := row.Scan(&db_pwdhash); err != nil { // If no user is found with the requested name, create the user if errors.Is(err, sql.ErrNoRows) { pwdhash_bytes, err := bcrypt.GenerateFromPassword(password, 12) @@ -101,9 +131,9 @@ func (h *UserHandler) createUser(res http.ResponseWriter, req *http.Request) { http.Error(res, fmt.Sprintf("user: failed to generate password hash (%s)", err), http.StatusInternalServerError) return } - pwdhash := string(pwdhash_bytes) + db_pwdhash = string(pwdhash_bytes) - _, err = h.db.Exec("INSERT INTO users VALUES (?,?)", name, pwdhash) + _, err = h.db.Exec("INSERT INTO users VALUES (?,?)", name, db_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("db: failed to create user (%s)", err), http.StatusInternalServerError) return @@ -117,21 +147,132 @@ func (h *UserHandler) createUser(res http.ResponseWriter, req *http.Request) { return } - // JWT generation - token := jwt.New(jwt.SigningMethodES256) - claims := token.Claims.(jwt.MapClaims) - claims["name"] = name - claims["pwdhash"] = db_pwdhash - - token_string, err := token.SignedString(h.jwt_secret) + // JWT Generation + token_string, err := h.GenerateUserToken(name, db_pwdhash) if err != nil { http.Error(res, fmt.Sprintf("jwt: failed to generate token (%s)", err), http.StatusInternalServerError) return } - fmt.Fprintf(res, "%s", token_string) } -// TODO(midnadimple): implement: -func (h *UserHandler) updateUser(res http.ResponseWriter, req *http.Request) {} -func (h *UserHandler) deleteUser(res http.ResponseWriter, req *http.Request) {} +func (h *UserHandler) updateUser(res http.ResponseWriter, req *http.Request) { + user_request := new(UserRequest) + if err := user_request.Parse(req); err != nil { + http.Error(res, fmt.Sprintf("user: failed to parse request (%s)", err), http.StatusBadRequest) + return + } + req_name := user_request.Name + req_password := []byte(user_request.Password) + + if req.Header["Authorization"] == nil { + http.Error(res, "jwt: missing token", http.StatusUnauthorized) + return + } + + token, err := h.ParseUserToken(req.Header["Token"][0]) + if err != nil { + http.Error(res, fmt.Sprintf("jwt: error during parsing (%s)", err), http.StatusInternalServerError) + return + } + if !token.Valid { + http.Error(res, "jwt: invalid token", http.StatusUnauthorized) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + http.Error(res, "jwt: failed to get claims", http.StatusInternalServerError) + return + } + + claim_name := claims["name"].(string) + claim_pwdhash := claims["pwdhash"].(string) + + var db_name, db_pwdhash string + row := h.db.QueryRow("SELECT * FROM users WHERE name=?", claim_name) + if err := row.Scan(&db_name, &db_pwdhash); err != nil { + if errors.Is(err, sql.ErrNoRows) { + http.Error(res, "user: authorized user doesn't exist", http.StatusBadRequest) + return + } else { + http.Error(res, fmt.Sprintf("db: failed to find user (%s)", err), http.StatusInternalServerError) + return + } + } + + if claim_pwdhash != db_pwdhash { + http.Error(res, "user: invalid password", http.StatusForbidden) + return + } + + if req_name == claim_name && bcrypt.CompareHashAndPassword([]byte(claim_pwdhash), req_password) == nil { + http.Error(res, "user: requested credentials are the same as current credentials", http.StatusBadRequest) + return + } + + req_pwdhash_bytes, err := bcrypt.GenerateFromPassword(req_password, 12) + if err != nil { + http.Error(res, fmt.Sprintf("user: failed to generate password hash (%s)", err), http.StatusInternalServerError) + return + } + req_pwdhash := string(req_pwdhash_bytes) + + _, err = h.db.Exec("UPDATE users SET name=?, pwdhash=? WHERE name=? AND pwdhash=?", + req_name, req_pwdhash, claim_name, claim_pwdhash) + if err != nil { + http.Error(res, fmt.Sprintf("db: failed to update user (%s)", err), http.StatusInternalServerError) + } + + token_string, err := h.GenerateUserToken(req_name, req_pwdhash) + if err != nil { + http.Error(res, fmt.Sprintf("jwt: failed to generate token (%s)", err), http.StatusInternalServerError) + return + } + fmt.Fprintf(res, "%s", token_string) +} + +// TODO(midnadimple): Implement: +func (h *UserHandler) deleteUser(res http.ResponseWriter, req *http.Request) { + token, err := h.ParseUserToken(req.Header["Token"][0]) + if err != nil { + http.Error(res, fmt.Sprintf("jwt: error during parsing (%s)", err), http.StatusInternalServerError) + return + } + if !token.Valid { + http.Error(res, "jwt: invalid token", http.StatusUnauthorized) + return + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + http.Error(res, "jwt: failed to get claims", http.StatusInternalServerError) + return + } + + claim_name := claims["name"].(string) + claim_pwdhash := claims["pwdhash"].(string) + + var db_name, db_pwdhash string + row := h.db.QueryRow("SELECT * FROM users WHERE name=?", claim_name) + if err := row.Scan(&db_name, &db_pwdhash); err != nil { + if errors.Is(err, sql.ErrNoRows) { + http.Error(res, "user: authorized user doesn't exist", http.StatusBadRequest) + return + } else { + http.Error(res, fmt.Sprintf("db: failed to find user (%s)", err), http.StatusInternalServerError) + return + } + } + + if claim_pwdhash != db_pwdhash { + http.Error(res, "user: invalid password", http.StatusForbidden) + return + } + + if _, err := h.db.Exec("DELETE FROM users WHERE name=? AND pwdhash=?", db_name, db_pwdhash); err != nil { + http.Error(res, fmt.Sprintf("db: failed to delete user (%s)", err), http.StatusInternalServerError) + return + } + +}