changed bun to gorm in auth plugin
This commit is contained in:
parent
6028c5d60d
commit
6055fc6426
12 changed files with 109 additions and 97 deletions
|
@ -5,21 +5,25 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/anthdm/superkit/db"
|
||||
"github.com/anthdm/superkit/kit"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/extra/bundebug"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// I could not came up with a better naming for this.
|
||||
// Ideally, app should export a global variable called "DB"
|
||||
// but this will cause imports cycles for plugins.
|
||||
var Query *bun.DB
|
||||
// By default this is a pre-configured Gorm DB instance.
|
||||
// Change this type based on the database package of your likings.
|
||||
var dbInstance *gorm.DB
|
||||
|
||||
// Get returns the instantiated DB instance.
|
||||
func Get() *gorm.DB {
|
||||
return dbInstance
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Create a default *sql.DB exposed by the superkit/db package
|
||||
// based on the given configuration.
|
||||
config := db.Config{
|
||||
Driver: os.Getenv("DB_DRIVER"),
|
||||
Name: os.Getenv("DB_NAME"),
|
||||
|
@ -27,12 +31,30 @@ func init() {
|
|||
User: os.Getenv("DB_USER"),
|
||||
Host: os.Getenv("DB_HOST"),
|
||||
}
|
||||
db, err := db.New(config)
|
||||
dbinst, err := db.NewSQL(config)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
Query = bun.NewDB(db, sqlitedialect.New())
|
||||
if kit.IsDevelopment() {
|
||||
Query.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
|
||||
// Based on the driver create the corresponding DB instance.
|
||||
// By default, the SuperKit boilerplate comes with a pre-configured
|
||||
// ORM called Gorm. https://gorm.io.
|
||||
//
|
||||
// You can change this to any other DB interaction tool
|
||||
// of your liking. EG:
|
||||
// - uptrace bun -> https://bun.uptrace.dev
|
||||
// - SQLC -> https://github.com/sqlc-dev/sqlc
|
||||
// - gojet -> https://github.com/go-jet/jet
|
||||
switch config.Driver {
|
||||
case db.DriverSqlite3:
|
||||
dbInstance, err = gorm.Open(sqlite.New(sqlite.Config{
|
||||
Conn: dbinst,
|
||||
}))
|
||||
case db.DriverMysql:
|
||||
// ...
|
||||
default:
|
||||
log.Fatal("invalid driver:", config.Driver)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,9 +5,10 @@ create table if not exists users(
|
|||
password_hash text not null,
|
||||
first_name text not null,
|
||||
last_name text not null,
|
||||
email_verified_at timestamp with time zone,
|
||||
created_at timestamp with time zone not null,
|
||||
updated_at timestamp with time zone not null
|
||||
email_verified_at datetime,
|
||||
created_at datetime not null,
|
||||
updated_at datetime not null,
|
||||
deleted_at datetime
|
||||
);
|
||||
|
||||
-- +goose Down
|
||||
|
|
|
@ -5,9 +5,10 @@ create table if not exists sessions(
|
|||
user_id integer not null references users,
|
||||
ip_address text,
|
||||
user_agent text,
|
||||
expires_at timestamp with time zone not null,
|
||||
last_login_at timestamp with time zone,
|
||||
created_at timestamp with time zone not null
|
||||
expires_at datetime not null,
|
||||
created_at datetime not null,
|
||||
updated_at datetime not null,
|
||||
deleted_at datetime
|
||||
);
|
||||
|
||||
-- +goose Down
|
||||
|
|
|
@ -2,7 +2,7 @@ package types
|
|||
|
||||
// AuthUser represents an user that might be authenticated.
|
||||
type AuthUser struct {
|
||||
ID int
|
||||
ID uint
|
||||
Email string
|
||||
LoggedIn bool
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ module AABBCCDD
|
|||
go 1.22.4
|
||||
|
||||
// uncomment for local development on the superkit core.
|
||||
// replace github.com/anthdm/superkit => ../
|
||||
replace github.com/anthdm/superkit => ../
|
||||
|
||||
require (
|
||||
github.com/a-h/templ v0.2.707
|
||||
|
@ -24,10 +24,13 @@ require (
|
|||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/gorilla/sessions v1.3.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
gorm.io/driver/sqlite v1.5.6 // indirect
|
||||
gorm.io/gorm v1.25.10 // indirect
|
||||
)
|
||||
|
|
|
@ -18,6 +18,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
|||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
|
@ -45,3 +47,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
|
||||
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
||||
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
|
||||
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -40,12 +41,9 @@ func HandleLoginCreate(kit *kit.Kit) error {
|
|||
}
|
||||
|
||||
var user User
|
||||
err := db.Query.NewSelect().
|
||||
Model(&user).
|
||||
Where("user.email = ?", values.Email).
|
||||
Scan(kit.Request.Context())
|
||||
err := db.Get().Find(&user, "email = ?", values.Email).Error
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
errors.Add("credentials", "invalid credentials")
|
||||
return kit.Render(LoginForm(values, errors))
|
||||
}
|
||||
|
@ -59,7 +57,7 @@ func HandleLoginCreate(kit *kit.Kit) error {
|
|||
|
||||
skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false")
|
||||
if skipVerify != "true" {
|
||||
if user.EmailVerifiedAt.Equal(time.Time{}) {
|
||||
if !user.EmailVerifiedAt.Valid {
|
||||
errors.Add("verified", "please verify your email")
|
||||
return kit.Render(LoginForm(values, errors))
|
||||
}
|
||||
|
@ -71,24 +69,17 @@ func HandleLoginCreate(kit *kit.Kit) error {
|
|||
sessionExpiry = 48
|
||||
}
|
||||
session := Session{
|
||||
UserID: user.ID,
|
||||
Token: uuid.New().String(),
|
||||
CreatedAt: time.Now(),
|
||||
LastLoginAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
|
||||
UserID: user.ID,
|
||||
Token: uuid.New().String(),
|
||||
ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
|
||||
}
|
||||
_, err = db.Query.NewInsert().
|
||||
Model(&session).
|
||||
Exec(kit.Request.Context())
|
||||
if err != nil {
|
||||
if err = db.Get().Create(&session).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO change this with kit.Getenv
|
||||
sess := kit.GetSession(userSessionName)
|
||||
sess.Values["sessionToken"] = session.Token
|
||||
sess.Save(kit.Request, kit.Response)
|
||||
|
||||
redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile")
|
||||
|
||||
return kit.Redirect(http.StatusSeeOther, redirectURL)
|
||||
|
@ -100,10 +91,7 @@ func HandleLoginDelete(kit *kit.Kit) error {
|
|||
sess.Values = map[any]any{}
|
||||
sess.Save(kit.Request, kit.Response)
|
||||
}()
|
||||
_, err := db.Query.NewDelete().
|
||||
Model((*Session)(nil)).
|
||||
Where("token = ?", sess.Values["sessionToken"]).
|
||||
Exec(kit.Request.Context())
|
||||
err := db.Get().Delete(&Session{}, "token = ?", sess.Values["sessionToken"]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -121,7 +109,7 @@ func HandleEmailVerify(kit *kit.Kit) error {
|
|||
return []byte(os.Getenv("SUPERKIT_SECRET")), nil
|
||||
}, jwt.WithLeeway(5*time.Second))
|
||||
if err != nil {
|
||||
return err
|
||||
return kit.Render(EmailVerificationError("invalid verification token"))
|
||||
}
|
||||
if !token.Valid {
|
||||
return kit.Render(EmailVerificationError("invalid verification token"))
|
||||
|
@ -141,23 +129,18 @@ func HandleEmailVerify(kit *kit.Kit) error {
|
|||
}
|
||||
|
||||
var user User
|
||||
err = db.Query.NewSelect().
|
||||
Model(&user).
|
||||
Where("id = ?", userID).
|
||||
Scan(kit.Request.Context())
|
||||
err = db.Get().First(&user, userID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.EmailVerifiedAt.After(time.Time{}) {
|
||||
if user.EmailVerifiedAt.Time.After(time.Time{}) {
|
||||
return kit.Render(EmailVerificationError("Email already verified"))
|
||||
}
|
||||
|
||||
user.EmailVerifiedAt = time.Now()
|
||||
_, err = db.Query.NewUpdate().
|
||||
Model(&user).
|
||||
WherePK().
|
||||
Exec(kit.Request.Context())
|
||||
now := sql.NullTime{Time: time.Now(), Valid: true}
|
||||
user.EmailVerifiedAt = now
|
||||
err = db.Get().Save(&user).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -174,16 +157,14 @@ func AuthenticateUser(kit *kit.Kit) (kit.Auth, error) {
|
|||
}
|
||||
|
||||
var session Session
|
||||
err := db.Query.NewSelect().
|
||||
Model(&session).
|
||||
Relation("User").
|
||||
Where("session.token = ? AND session.expires_at > ?", token, time.Now()).
|
||||
Scan(kit.Request.Context())
|
||||
if err != nil {
|
||||
err := db.Get().
|
||||
Preload("User").
|
||||
Find(&session, "token = ? AND expires_at > ?", token, time.Now()).Error
|
||||
if err != nil || session.ID == 0 {
|
||||
return auth, nil
|
||||
}
|
||||
return Auth{
|
||||
|
||||
return Auth{
|
||||
LoggedIn: true,
|
||||
UserID: session.User.ID,
|
||||
Email: session.User.Email,
|
||||
|
|
|
@ -14,7 +14,7 @@ var profileSchema = v.Schema{
|
|||
}
|
||||
|
||||
type ProfileFormValues struct {
|
||||
ID int `form:"id"`
|
||||
ID uint `form:"id"`
|
||||
FirstName string `form:"firstName"`
|
||||
LastName string `form:"lastName"`
|
||||
Email string
|
||||
|
@ -25,11 +25,7 @@ func HandleProfileShow(kit *kit.Kit) error {
|
|||
auth := kit.Auth().(Auth)
|
||||
|
||||
var user User
|
||||
err := db.Query.NewSelect().
|
||||
Model(&user).
|
||||
Where("id = ?", auth.UserID).
|
||||
Scan(kit.Request.Context())
|
||||
if err != nil {
|
||||
if err := db.Get().First(&user, auth.UserID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -54,12 +50,12 @@ func HandleProfileUpdate(kit *kit.Kit) error {
|
|||
if auth.UserID != values.ID {
|
||||
return fmt.Errorf("unauthorized request for profile %d", values.ID)
|
||||
}
|
||||
_, err := db.Query.NewUpdate().
|
||||
Model((*User)(nil)).
|
||||
Set("first_name = ?", values.FirstName).
|
||||
Set("last_name = ?", values.LastName).
|
||||
err := db.Get().Model(&User{}).
|
||||
Where("id = ?", auth.UserID).
|
||||
Exec(kit.Request.Context())
|
||||
Updates(&User{
|
||||
FirstName: values.FirstName,
|
||||
LastName: values.LastName,
|
||||
}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -63,19 +63,15 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
|
|||
}
|
||||
|
||||
var user User
|
||||
err = db.Query.NewSelect().
|
||||
Model(&user).
|
||||
Where("id = ?", id).
|
||||
Scan(kit.Request.Context())
|
||||
if err != nil {
|
||||
if err = db.Get().First(&user, id).Error; err != nil {
|
||||
return kit.Text(http.StatusOK, "An unexpected error occured")
|
||||
}
|
||||
|
||||
if user.EmailVerifiedAt.After(time.Time{}) {
|
||||
if user.EmailVerifiedAt.Time.After(time.Time{}) {
|
||||
return kit.Text(http.StatusOK, "Email already verified!")
|
||||
}
|
||||
|
||||
token, err := createVerificationToken(id)
|
||||
token, err := createVerificationToken(uint(id))
|
||||
if err != nil {
|
||||
return kit.Text(http.StatusOK, "An unexpected error occured")
|
||||
}
|
||||
|
@ -90,7 +86,7 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
|
|||
return kit.Text(http.StatusOK, msg)
|
||||
}
|
||||
|
||||
func createVerificationToken(userID int) (string, error) {
|
||||
func createVerificationToken(userID uint) (string, error) {
|
||||
expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1")
|
||||
expiry, err := strconv.Atoi(expiryStr)
|
||||
if err != nil {
|
||||
|
|
|
@ -2,10 +2,11 @@ package auth
|
|||
|
||||
import (
|
||||
"AABBCCDD/app/db"
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Event name constants
|
||||
|
@ -22,7 +23,7 @@ type UserWithVerificationToken struct {
|
|||
}
|
||||
|
||||
type Auth struct {
|
||||
UserID int
|
||||
UserID uint
|
||||
Email string
|
||||
LoggedIn bool
|
||||
}
|
||||
|
@ -32,12 +33,13 @@ func (auth Auth) Check() bool {
|
|||
}
|
||||
|
||||
type User struct {
|
||||
ID int `bun:"id,pk,autoincrement"`
|
||||
gorm.Model
|
||||
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
PasswordHash string
|
||||
EmailVerifiedAt time.Time
|
||||
EmailVerifiedAt sql.NullTime
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
@ -52,22 +54,19 @@ func createUserFromFormValues(values SignupFormValues) (User, error) {
|
|||
FirstName: values.FirstName,
|
||||
LastName: values.LastName,
|
||||
PasswordHash: string(hash),
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
_, err = db.Query.NewInsert().Model(&user).Exec(context.Background())
|
||||
return user, err
|
||||
result := db.Get().Create(&user)
|
||||
return user, result.Error
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID int `bun:"id,pk,autoincrement"`
|
||||
UserID int
|
||||
Token string
|
||||
IPAddress string
|
||||
UserAgent string
|
||||
ExpiresAt time.Time
|
||||
LastLoginAt time.Time
|
||||
CreatedAt time.Time
|
||||
gorm.Model
|
||||
|
||||
User User `bun:"rel:belongs-to,join:user_id=id"`
|
||||
UserID uint
|
||||
Token string
|
||||
IPAddress string
|
||||
UserAgent string
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
User User
|
||||
}
|
||||
|
|
3
db/db.go
3
db/db.go
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
const (
|
||||
DriverSqlite3 = "sqlite3"
|
||||
DriverMysql = "mysql"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
@ -17,7 +18,7 @@ type Config struct {
|
|||
Password string
|
||||
}
|
||||
|
||||
func New(cfg Config) (*sql.DB, error) {
|
||||
func NewSQL(cfg Config) (*sql.DB, error) {
|
||||
switch cfg.Driver {
|
||||
case DriverSqlite3:
|
||||
name := cfg.Name
|
||||
|
|
|
@ -153,12 +153,18 @@ func parseRequest(r *http.Request, v any) error {
|
|||
}
|
||||
case reflect.String:
|
||||
fieldVal.SetString(formValue)
|
||||
case reflect.Int:
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
intVal, err := strconv.Atoi(formValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse int: %v", err)
|
||||
}
|
||||
fieldVal.SetInt(int64(intVal))
|
||||
case reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
intVal, err := strconv.Atoi(formValue)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse int: %v", err)
|
||||
}
|
||||
fieldVal.SetUint(uint64(intVal))
|
||||
case reflect.Float64:
|
||||
floatVal, err := strconv.ParseFloat(formValue, 64)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue