completed auth + pair validation
This commit is contained in:
parent
c4912af8fb
commit
37c411f4cb
9
.air.toml
Normal file
9
.air.toml
Normal file
@ -0,0 +1,9 @@
|
||||
root = "."
|
||||
tmp_dir = "tmp"
|
||||
[build]
|
||||
cmd = "go build -o ./tmp/main ."
|
||||
bin = "./tmp/main"
|
||||
delay = 1000 # ms
|
||||
exclude_dir = ["assets", "tmp", "vendor"]
|
||||
include_ext = ["go", "tpl", "tmpl", "html"]
|
||||
exclude_regex = ["_test\\.go"]
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
||||
.env
|
||||
tmp
|
||||
60
app/router/middlewares/session.go
Normal file
60
app/router/middlewares/session.go
Normal file
@ -0,0 +1,60 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"system-trace/core/app/constants"
|
||||
"system-trace/core/auth"
|
||||
"system-trace/core/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func ValidateSession(c *fiber.Ctx) error {
|
||||
p := new(auth.PairTokens)
|
||||
if err := c.CookieParser(p); err != nil {
|
||||
return c.Status(http.StatusBadRequest).JSON(fiber.Map{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
if !validatePair(c, p) {
|
||||
return c.Status(http.StatusForbidden).JSON(fiber.Map{
|
||||
"error": constants.UNAUTHORIZED,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func validatePair(c *fiber.Ctx, p *auth.PairTokens) bool {
|
||||
if len(p.AccessToken) <= 0 || len(p.RefreshToken) <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
claims, err := utils.ValidateJWT(p.AccessToken)
|
||||
if (err != nil && strings.Contains(err.Error(), "token is expired")) || claims["iss"] != constants.JWT_APP_ISS {
|
||||
rclaims, rerr := utils.ValidateJWT(p.RefreshToken)
|
||||
if rerr != nil || (rerr != nil && strings.Contains(rerr.Error(), "token is expired")) || rclaims["sub"] != p.AccessToken {
|
||||
return false
|
||||
}
|
||||
|
||||
pt, err := auth.GetPair(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err = auth.RevokePair(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
err = auth.GeneratePairAndSetCookie(c, pt.UserID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// c.Locals("userId", id)
|
||||
|
||||
return true
|
||||
}
|
||||
@ -4,7 +4,7 @@ import (
|
||||
"os"
|
||||
"system-trace/core/app/constants"
|
||||
"system-trace/core/app/router"
|
||||
"system-trace/core/auth/middlewares"
|
||||
"system-trace/core/app/router/middlewares"
|
||||
"system-trace/core/environment"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
|
||||
14
auth/auth.go
14
auth/auth.go
@ -42,7 +42,7 @@ func ReqTokens(c *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
if u != nil {
|
||||
p, err := genPair(u)
|
||||
err = GeneratePairAndSetCookie(c, u.ID)
|
||||
if err != nil {
|
||||
return c.
|
||||
Status(fiber.StatusBadRequest).
|
||||
@ -50,13 +50,23 @@ func ReqTokens(c *fiber.Ctx) error {
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
setCookie(c, p)
|
||||
users.Login(u)
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
return errors.New(constants.AUTH_FAILED)
|
||||
}
|
||||
|
||||
func GeneratePairAndSetCookie(c *fiber.Ctx, id int32) error {
|
||||
p, err := genPair(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setCookie(c, p)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func setCookie(c *fiber.Ctx, p *PairTokens) {
|
||||
// Access token
|
||||
atc := new(fiber.Cookie)
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"system-trace/core/app/constants"
|
||||
"system-trace/core/utils"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func ValidateSession(c *fiber.Ctx) error {
|
||||
header := c.GetReqHeaders()[http.CanonicalHeaderKey("Authorization")]
|
||||
if len(header) <= 0 || len(header[0]) <= 0 || !validateToken(c, header[0]) {
|
||||
return c.Status(http.StatusForbidden).JSON(fiber.Map{
|
||||
"error": constants.UNAUTHORIZED,
|
||||
})
|
||||
}
|
||||
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
func validateToken(c *fiber.Ctx, hash string) bool {
|
||||
splitted := strings.Split(hash, " ")
|
||||
if len(splitted) <= 1 {
|
||||
return false
|
||||
}
|
||||
claims, err := utils.ValidateJWT(splitted[1])
|
||||
fmt.Println(claims, err)
|
||||
// id, ok := claims["ID"].(string)
|
||||
// TODO validate date and check refresh token
|
||||
if err != nil || claims["iss"] != constants.JWT_APP_ISS {
|
||||
return false
|
||||
}
|
||||
|
||||
// c.Locals("userId", id)
|
||||
|
||||
return true
|
||||
}
|
||||
53
auth/sql.go
Normal file
53
auth/sql.go
Normal file
@ -0,0 +1,53 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"system-trace/core/database"
|
||||
"system-trace/core/database/entities"
|
||||
)
|
||||
|
||||
func GetPair(p *PairTokens) (*entities.AuthToken, error) {
|
||||
aut := new(entities.AuthToken)
|
||||
|
||||
ctx := context.Background()
|
||||
err := database.PG.NewSelect().
|
||||
Model(aut).
|
||||
Where("access_token = ?", p.AccessToken).
|
||||
Where("refresh_token = ?", p.RefreshToken).
|
||||
Where("is_revoked = ?", false).
|
||||
Scan(ctx)
|
||||
|
||||
return aut, err
|
||||
}
|
||||
|
||||
func RevokePair(p *PairTokens) error {
|
||||
aut := entities.AuthToken{
|
||||
IsRevoked: true,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := database.PG.NewUpdate().
|
||||
Model(&aut).
|
||||
Column("is_revoked").
|
||||
Where("access_token = ?", p.AccessToken).
|
||||
Where("refresh_token = ?", p.RefreshToken).
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func insertPair(id int32, at, rt string) error {
|
||||
p := entities.AuthToken{
|
||||
UserID: id,
|
||||
AccessToken: at,
|
||||
RefreshToken: rt,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := database.PG.NewInsert().
|
||||
Model(&p).
|
||||
Returning("NULL").
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
@ -1,11 +1,8 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"system-trace/core/app/constants"
|
||||
"system-trace/core/database"
|
||||
"system-trace/core/database/entities"
|
||||
"system-trace/core/utils"
|
||||
"time"
|
||||
|
||||
@ -17,10 +14,13 @@ const (
|
||||
RefreshTokenLifetime int8 = 24
|
||||
)
|
||||
|
||||
func genPair(u *entities.User) (*PairTokens, error) {
|
||||
at, rt, err := genTokens(u)
|
||||
func genPair(id int32) (*PairTokens, error) {
|
||||
at, rt, err := genTokens(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = insertPair(u.ID, at, rt)
|
||||
err = insertPair(id, at, rt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -33,8 +33,8 @@ func genPair(u *entities.User) (*PairTokens, error) {
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func genTokens(u *entities.User) (string, string, error) {
|
||||
at, err := genToken(fmt.Sprintf("%d", u.ID), AccessTokenLifetime)
|
||||
func genTokens(id int32) (string, string, error) {
|
||||
at, err := genToken(fmt.Sprintf("%d", id), AccessTokenLifetime)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@ -48,7 +48,6 @@ func genTokens(u *entities.User) (string, string, error) {
|
||||
}
|
||||
|
||||
func genToken(sub string, hours int8) (string, error) {
|
||||
fmt.Println(sub, hours)
|
||||
c := jwt.MapClaims{
|
||||
"iss": constants.JWT_APP_ISS,
|
||||
"sub": sub,
|
||||
@ -59,18 +58,3 @@ func genToken(sub string, hours int8) (string, error) {
|
||||
|
||||
return a, err
|
||||
}
|
||||
|
||||
func insertPair(id int32, at, rt string) error {
|
||||
aut := entities.AuthToken{
|
||||
UserID: id,
|
||||
AccessToken: at,
|
||||
RefreshToken: rt,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := database.PG.NewInsert().
|
||||
Model(&aut).
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
package auth
|
||||
|
||||
type PairTokens struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
AccessToken string `json:"accessToken" cookie:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken" cookie:"refreshToken"`
|
||||
}
|
||||
|
||||
type AuthBody struct {
|
||||
|
||||
12
users/sql.go
12
users/sql.go
@ -15,8 +15,18 @@ func FindByEmailAndPassword(email, password string) (*entities.User, error) {
|
||||
Model(u).
|
||||
Where("email = ?", email).
|
||||
Where("password_hash = ?", passwordHash).
|
||||
Column("id").
|
||||
Scan(ctx)
|
||||
|
||||
return u, err
|
||||
}
|
||||
|
||||
func UpdateUser(u *entities.User, cols []string) error {
|
||||
ctx := context.Background()
|
||||
_, err := database.PG.NewUpdate().
|
||||
Model(u).
|
||||
Column(cols...).
|
||||
WherePK().
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"system-trace/core/database/entities"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
func Get(c *fiber.Ctx) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Login(u *entities.User) error {
|
||||
u.LastLogin = time.Now()
|
||||
return UpdateUser(u, []string{"last_login"})
|
||||
}
|
||||
|
||||
7
utils/date.go
Normal file
7
utils/date.go
Normal file
@ -0,0 +1,7 @@
|
||||
package utils
|
||||
|
||||
import "time"
|
||||
|
||||
func DateExpired(t time.Time) bool {
|
||||
return time.Now().Before(t)
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user