completed auth + pair validation
This commit is contained in:
parent
c4912af8fb
commit
37c411f4cb
9
.air.toml
Normal file
9
.air.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
root = "."
|
||||||
|
tmp_dir = "tmp"
|
||||||
|
[build]
|
||||||
|
cmd = "go build -o ./tmp/main ."
|
||||||
|
bin = "./tmp/main"
|
||||||
|
delay = 1000 # ms
|
||||||
|
exclude_dir = ["assets", "tmp", "vendor"]
|
||||||
|
include_ext = ["go", "tpl", "tmpl", "html"]
|
||||||
|
exclude_regex = ["_test\\.go"]
|
||||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
.env
|
.env
|
||||||
|
tmp
|
||||||
60
app/router/middlewares/session.go
Normal file
60
app/router/middlewares/session.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"system-trace/core/app/constants"
|
||||||
|
"system-trace/core/auth"
|
||||||
|
"system-trace/core/utils"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ValidateSession(c *fiber.Ctx) error {
|
||||||
|
p := new(auth.PairTokens)
|
||||||
|
if err := c.CookieParser(p); err != nil {
|
||||||
|
return c.Status(http.StatusBadRequest).JSON(fiber.Map{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if !validatePair(c, p) {
|
||||||
|
return c.Status(http.StatusForbidden).JSON(fiber.Map{
|
||||||
|
"error": constants.UNAUTHORIZED,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Next()
|
||||||
|
}
|
||||||
|
|
||||||
|
func validatePair(c *fiber.Ctx, p *auth.PairTokens) bool {
|
||||||
|
if len(p.AccessToken) <= 0 || len(p.RefreshToken) <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := utils.ValidateJWT(p.AccessToken)
|
||||||
|
if (err != nil && strings.Contains(err.Error(), "token is expired")) || claims["iss"] != constants.JWT_APP_ISS {
|
||||||
|
rclaims, rerr := utils.ValidateJWT(p.RefreshToken)
|
||||||
|
if rerr != nil || (rerr != nil && strings.Contains(rerr.Error(), "token is expired")) || rclaims["sub"] != p.AccessToken {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pt, err := auth.GetPair(p)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
err = auth.RevokePair(p)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
err = auth.GeneratePairAndSetCookie(c, pt.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// c.Locals("userId", id)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@ -4,7 +4,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"system-trace/core/app/constants"
|
"system-trace/core/app/constants"
|
||||||
"system-trace/core/app/router"
|
"system-trace/core/app/router"
|
||||||
"system-trace/core/auth/middlewares"
|
"system-trace/core/app/router/middlewares"
|
||||||
"system-trace/core/environment"
|
"system-trace/core/environment"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
|
|||||||
14
auth/auth.go
14
auth/auth.go
@ -42,7 +42,7 @@ func ReqTokens(c *fiber.Ctx) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
if u != nil {
|
if u != nil {
|
||||||
p, err := genPair(u)
|
err = GeneratePairAndSetCookie(c, u.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.
|
return c.
|
||||||
Status(fiber.StatusBadRequest).
|
Status(fiber.StatusBadRequest).
|
||||||
@ -50,13 +50,23 @@ func ReqTokens(c *fiber.Ctx) error {
|
|||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
setCookie(c, p)
|
users.Login(u)
|
||||||
return c.SendStatus(fiber.StatusOK)
|
return c.SendStatus(fiber.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.New(constants.AUTH_FAILED)
|
return errors.New(constants.AUTH_FAILED)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GeneratePairAndSetCookie(c *fiber.Ctx, id int32) error {
|
||||||
|
p, err := genPair(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
setCookie(c, p)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func setCookie(c *fiber.Ctx, p *PairTokens) {
|
func setCookie(c *fiber.Ctx, p *PairTokens) {
|
||||||
// Access token
|
// Access token
|
||||||
atc := new(fiber.Cookie)
|
atc := new(fiber.Cookie)
|
||||||
|
|||||||
@ -1,40 +0,0 @@
|
|||||||
package middlewares
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"system-trace/core/app/constants"
|
|
||||||
"system-trace/core/utils"
|
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ValidateSession(c *fiber.Ctx) error {
|
|
||||||
header := c.GetReqHeaders()[http.CanonicalHeaderKey("Authorization")]
|
|
||||||
if len(header) <= 0 || len(header[0]) <= 0 || !validateToken(c, header[0]) {
|
|
||||||
return c.Status(http.StatusForbidden).JSON(fiber.Map{
|
|
||||||
"error": constants.UNAUTHORIZED,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateToken(c *fiber.Ctx, hash string) bool {
|
|
||||||
splitted := strings.Split(hash, " ")
|
|
||||||
if len(splitted) <= 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
claims, err := utils.ValidateJWT(splitted[1])
|
|
||||||
fmt.Println(claims, err)
|
|
||||||
// id, ok := claims["ID"].(string)
|
|
||||||
// TODO validate date and check refresh token
|
|
||||||
if err != nil || claims["iss"] != constants.JWT_APP_ISS {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// c.Locals("userId", id)
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
53
auth/sql.go
Normal file
53
auth/sql.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"system-trace/core/database"
|
||||||
|
"system-trace/core/database/entities"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetPair(p *PairTokens) (*entities.AuthToken, error) {
|
||||||
|
aut := new(entities.AuthToken)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := database.PG.NewSelect().
|
||||||
|
Model(aut).
|
||||||
|
Where("access_token = ?", p.AccessToken).
|
||||||
|
Where("refresh_token = ?", p.RefreshToken).
|
||||||
|
Where("is_revoked = ?", false).
|
||||||
|
Scan(ctx)
|
||||||
|
|
||||||
|
return aut, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func RevokePair(p *PairTokens) error {
|
||||||
|
aut := entities.AuthToken{
|
||||||
|
IsRevoked: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := database.PG.NewUpdate().
|
||||||
|
Model(&aut).
|
||||||
|
Column("is_revoked").
|
||||||
|
Where("access_token = ?", p.AccessToken).
|
||||||
|
Where("refresh_token = ?", p.RefreshToken).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertPair(id int32, at, rt string) error {
|
||||||
|
p := entities.AuthToken{
|
||||||
|
UserID: id,
|
||||||
|
AccessToken: at,
|
||||||
|
RefreshToken: rt,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := database.PG.NewInsert().
|
||||||
|
Model(&p).
|
||||||
|
Returning("NULL").
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
@ -1,11 +1,8 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"system-trace/core/app/constants"
|
"system-trace/core/app/constants"
|
||||||
"system-trace/core/database"
|
|
||||||
"system-trace/core/database/entities"
|
|
||||||
"system-trace/core/utils"
|
"system-trace/core/utils"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -17,10 +14,13 @@ const (
|
|||||||
RefreshTokenLifetime int8 = 24
|
RefreshTokenLifetime int8 = 24
|
||||||
)
|
)
|
||||||
|
|
||||||
func genPair(u *entities.User) (*PairTokens, error) {
|
func genPair(id int32) (*PairTokens, error) {
|
||||||
at, rt, err := genTokens(u)
|
at, rt, err := genTokens(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
err = insertPair(u.ID, at, rt)
|
err = insertPair(id, at, rt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -33,8 +33,8 @@ func genPair(u *entities.User) (*PairTokens, error) {
|
|||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func genTokens(u *entities.User) (string, string, error) {
|
func genTokens(id int32) (string, string, error) {
|
||||||
at, err := genToken(fmt.Sprintf("%d", u.ID), AccessTokenLifetime)
|
at, err := genToken(fmt.Sprintf("%d", id), AccessTokenLifetime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
@ -48,7 +48,6 @@ func genTokens(u *entities.User) (string, string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func genToken(sub string, hours int8) (string, error) {
|
func genToken(sub string, hours int8) (string, error) {
|
||||||
fmt.Println(sub, hours)
|
|
||||||
c := jwt.MapClaims{
|
c := jwt.MapClaims{
|
||||||
"iss": constants.JWT_APP_ISS,
|
"iss": constants.JWT_APP_ISS,
|
||||||
"sub": sub,
|
"sub": sub,
|
||||||
@ -59,18 +58,3 @@ func genToken(sub string, hours int8) (string, error) {
|
|||||||
|
|
||||||
return a, err
|
return a, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertPair(id int32, at, rt string) error {
|
|
||||||
aut := entities.AuthToken{
|
|
||||||
UserID: id,
|
|
||||||
AccessToken: at,
|
|
||||||
RefreshToken: rt,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
_, err := database.PG.NewInsert().
|
|
||||||
Model(&aut).
|
|
||||||
Exec(ctx)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
type PairTokens struct {
|
type PairTokens struct {
|
||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken" cookie:"accessToken"`
|
||||||
RefreshToken string `json:"refreshToken"`
|
RefreshToken string `json:"refreshToken" cookie:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthBody struct {
|
type AuthBody struct {
|
||||||
|
|||||||
12
users/sql.go
12
users/sql.go
@ -15,8 +15,18 @@ func FindByEmailAndPassword(email, password string) (*entities.User, error) {
|
|||||||
Model(u).
|
Model(u).
|
||||||
Where("email = ?", email).
|
Where("email = ?", email).
|
||||||
Where("password_hash = ?", passwordHash).
|
Where("password_hash = ?", passwordHash).
|
||||||
Column("id").
|
|
||||||
Scan(ctx)
|
Scan(ctx)
|
||||||
|
|
||||||
return u, err
|
return u, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func UpdateUser(u *entities.User, cols []string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := database.PG.NewUpdate().
|
||||||
|
Model(u).
|
||||||
|
Column(cols...).
|
||||||
|
WherePK().
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
package users
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"system-trace/core/database/entities"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Get(c *fiber.Ctx) error {
|
func Get(c *fiber.Ctx) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Login(u *entities.User) error {
|
||||||
|
u.LastLogin = time.Now()
|
||||||
|
return UpdateUser(u, []string{"last_login"})
|
||||||
|
}
|
||||||
|
|||||||
7
utils/date.go
Normal file
7
utils/date.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
func DateExpired(t time.Time) bool {
|
||||||
|
return time.Now().Before(t)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user