completed auth + pair validation

This commit is contained in:
Vitaliy Pavlov 2024-05-16 04:16:05 +07:00
parent c4912af8fb
commit 37c411f4cb
12 changed files with 173 additions and 71 deletions

9
.air.toml Normal file
View File

@ -0,0 +1,9 @@
root = "."
tmp_dir = "tmp"
[build]
cmd = "go build -o ./tmp/main ."
bin = "./tmp/main"
delay = 1000 # ms
exclude_dir = ["assets", "tmp", "vendor"]
include_ext = ["go", "tpl", "tmpl", "html"]
exclude_regex = ["_test\\.go"]

3
.gitignore vendored
View File

@ -1 +1,2 @@
.env .env
tmp

View File

@ -0,0 +1,60 @@
package middlewares
import (
"net/http"
"strings"
"system-trace/core/app/constants"
"system-trace/core/auth"
"system-trace/core/utils"
"github.com/gofiber/fiber/v2"
)
func ValidateSession(c *fiber.Ctx) error {
p := new(auth.PairTokens)
if err := c.CookieParser(p); err != nil {
return c.Status(http.StatusBadRequest).JSON(fiber.Map{
"error": err.Error(),
})
}
if !validatePair(c, p) {
return c.Status(http.StatusForbidden).JSON(fiber.Map{
"error": constants.UNAUTHORIZED,
})
}
return c.Next()
}
func validatePair(c *fiber.Ctx, p *auth.PairTokens) bool {
if len(p.AccessToken) <= 0 || len(p.RefreshToken) <= 0 {
return false
}
claims, err := utils.ValidateJWT(p.AccessToken)
if (err != nil && strings.Contains(err.Error(), "token is expired")) || claims["iss"] != constants.JWT_APP_ISS {
rclaims, rerr := utils.ValidateJWT(p.RefreshToken)
if rerr != nil || (rerr != nil && strings.Contains(rerr.Error(), "token is expired")) || rclaims["sub"] != p.AccessToken {
return false
}
pt, err := auth.GetPair(p)
if err != nil {
return false
}
err = auth.RevokePair(p)
if err != nil {
return false
}
err = auth.GeneratePairAndSetCookie(c, pt.UserID)
if err != nil {
return false
}
}
// c.Locals("userId", id)
return true
}

View File

@ -4,7 +4,7 @@ import (
"os" "os"
"system-trace/core/app/constants" "system-trace/core/app/constants"
"system-trace/core/app/router" "system-trace/core/app/router"
"system-trace/core/auth/middlewares" "system-trace/core/app/router/middlewares"
"system-trace/core/environment" "system-trace/core/environment"
"github.com/goccy/go-json" "github.com/goccy/go-json"

View File

@ -42,7 +42,7 @@ func ReqTokens(c *fiber.Ctx) error {
}) })
} }
if u != nil { if u != nil {
p, err := genPair(u) err = GeneratePairAndSetCookie(c, u.ID)
if err != nil { if err != nil {
return c. return c.
Status(fiber.StatusBadRequest). Status(fiber.StatusBadRequest).
@ -50,13 +50,23 @@ func ReqTokens(c *fiber.Ctx) error {
"error": err.Error(), "error": err.Error(),
}) })
} }
setCookie(c, p) users.Login(u)
return c.SendStatus(fiber.StatusOK) return c.SendStatus(fiber.StatusOK)
} }
return errors.New(constants.AUTH_FAILED) return errors.New(constants.AUTH_FAILED)
} }
func GeneratePairAndSetCookie(c *fiber.Ctx, id int32) error {
p, err := genPair(id)
if err != nil {
return err
}
setCookie(c, p)
return nil
}
func setCookie(c *fiber.Ctx, p *PairTokens) { func setCookie(c *fiber.Ctx, p *PairTokens) {
// Access token // Access token
atc := new(fiber.Cookie) atc := new(fiber.Cookie)

View File

@ -1,40 +0,0 @@
package middlewares
import (
"fmt"
"net/http"
"strings"
"system-trace/core/app/constants"
"system-trace/core/utils"
"github.com/gofiber/fiber/v2"
)
func ValidateSession(c *fiber.Ctx) error {
header := c.GetReqHeaders()[http.CanonicalHeaderKey("Authorization")]
if len(header) <= 0 || len(header[0]) <= 0 || !validateToken(c, header[0]) {
return c.Status(http.StatusForbidden).JSON(fiber.Map{
"error": constants.UNAUTHORIZED,
})
}
return c.Next()
}
func validateToken(c *fiber.Ctx, hash string) bool {
splitted := strings.Split(hash, " ")
if len(splitted) <= 1 {
return false
}
claims, err := utils.ValidateJWT(splitted[1])
fmt.Println(claims, err)
// id, ok := claims["ID"].(string)
// TODO validate date and check refresh token
if err != nil || claims["iss"] != constants.JWT_APP_ISS {
return false
}
// c.Locals("userId", id)
return true
}

53
auth/sql.go Normal file
View File

@ -0,0 +1,53 @@
package auth
import (
"context"
"system-trace/core/database"
"system-trace/core/database/entities"
)
func GetPair(p *PairTokens) (*entities.AuthToken, error) {
aut := new(entities.AuthToken)
ctx := context.Background()
err := database.PG.NewSelect().
Model(aut).
Where("access_token = ?", p.AccessToken).
Where("refresh_token = ?", p.RefreshToken).
Where("is_revoked = ?", false).
Scan(ctx)
return aut, err
}
func RevokePair(p *PairTokens) error {
aut := entities.AuthToken{
IsRevoked: true,
}
ctx := context.Background()
_, err := database.PG.NewUpdate().
Model(&aut).
Column("is_revoked").
Where("access_token = ?", p.AccessToken).
Where("refresh_token = ?", p.RefreshToken).
Exec(ctx)
return err
}
func insertPair(id int32, at, rt string) error {
p := entities.AuthToken{
UserID: id,
AccessToken: at,
RefreshToken: rt,
}
ctx := context.Background()
_, err := database.PG.NewInsert().
Model(&p).
Returning("NULL").
Exec(ctx)
return err
}

View File

@ -1,11 +1,8 @@
package auth package auth
import ( import (
"context"
"fmt" "fmt"
"system-trace/core/app/constants" "system-trace/core/app/constants"
"system-trace/core/database"
"system-trace/core/database/entities"
"system-trace/core/utils" "system-trace/core/utils"
"time" "time"
@ -17,10 +14,13 @@ const (
RefreshTokenLifetime int8 = 24 RefreshTokenLifetime int8 = 24
) )
func genPair(u *entities.User) (*PairTokens, error) { func genPair(id int32) (*PairTokens, error) {
at, rt, err := genTokens(u) at, rt, err := genTokens(id)
if err != nil {
return nil, err
}
err = insertPair(u.ID, at, rt) err = insertPair(id, at, rt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -33,8 +33,8 @@ func genPair(u *entities.User) (*PairTokens, error) {
return &p, nil return &p, nil
} }
func genTokens(u *entities.User) (string, string, error) { func genTokens(id int32) (string, string, error) {
at, err := genToken(fmt.Sprintf("%d", u.ID), AccessTokenLifetime) at, err := genToken(fmt.Sprintf("%d", id), AccessTokenLifetime)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -48,7 +48,6 @@ func genTokens(u *entities.User) (string, string, error) {
} }
func genToken(sub string, hours int8) (string, error) { func genToken(sub string, hours int8) (string, error) {
fmt.Println(sub, hours)
c := jwt.MapClaims{ c := jwt.MapClaims{
"iss": constants.JWT_APP_ISS, "iss": constants.JWT_APP_ISS,
"sub": sub, "sub": sub,
@ -59,18 +58,3 @@ func genToken(sub string, hours int8) (string, error) {
return a, err return a, err
} }
func insertPair(id int32, at, rt string) error {
aut := entities.AuthToken{
UserID: id,
AccessToken: at,
RefreshToken: rt,
}
ctx := context.Background()
_, err := database.PG.NewInsert().
Model(&aut).
Exec(ctx)
return err
}

View File

@ -1,8 +1,8 @@
package auth package auth
type PairTokens struct { type PairTokens struct {
AccessToken string `json:"accessToken"` AccessToken string `json:"accessToken" cookie:"accessToken"`
RefreshToken string `json:"refreshToken"` RefreshToken string `json:"refreshToken" cookie:"refreshToken"`
} }
type AuthBody struct { type AuthBody struct {

View File

@ -15,8 +15,18 @@ func FindByEmailAndPassword(email, password string) (*entities.User, error) {
Model(u). Model(u).
Where("email = ?", email). Where("email = ?", email).
Where("password_hash = ?", passwordHash). Where("password_hash = ?", passwordHash).
Column("id").
Scan(ctx) Scan(ctx)
return u, err return u, err
} }
func UpdateUser(u *entities.User, cols []string) error {
ctx := context.Background()
_, err := database.PG.NewUpdate().
Model(u).
Column(cols...).
WherePK().
Exec(ctx)
return err
}

View File

@ -1,9 +1,17 @@
package users package users
import ( import (
"system-trace/core/database/entities"
"time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
func Get(c *fiber.Ctx) error { func Get(c *fiber.Ctx) error {
return nil return nil
} }
func Login(u *entities.User) error {
u.LastLogin = time.Now()
return UpdateUser(u, []string{"last_login"})
}

7
utils/date.go Normal file
View File

@ -0,0 +1,7 @@
package utils
import "time"
func DateExpired(t time.Time) bool {
return time.Now().Before(t)
}