From 6055fc642660bae2edfe37d1508f8e46640625a7 Mon Sep 17 00:00:00 2001 From: anthdm Date: Sun, 23 Jun 2024 10:53:56 +0200 Subject: [PATCH] changed bun to gorm in auth plugin --- bootstrap/app/db/db.go | 46 +++++++++++---- .../20240610161057_create_users_table.sql | 7 ++- .../20240610163918_add_sessions_table.sql | 7 ++- bootstrap/app/types/auth.go | 2 +- bootstrap/go.mod | 5 +- bootstrap/go.sum | 6 ++ bootstrap/plugins/auth/auth_handler.go | 59 +++++++------------ bootstrap/plugins/auth/profile_handler.go | 18 +++--- bootstrap/plugins/auth/signup_handler.go | 12 ++-- bootstrap/plugins/auth/types.go | 33 +++++------ db/db.go | 3 +- validate/validate.go | 8 ++- 12 files changed, 109 insertions(+), 97 deletions(-) diff --git a/bootstrap/app/db/db.go b/bootstrap/app/db/db.go index df39e95..dae7888 100644 --- a/bootstrap/app/db/db.go +++ b/bootstrap/app/db/db.go @@ -5,21 +5,25 @@ import ( "os" "github.com/anthdm/superkit/db" - "github.com/anthdm/superkit/kit" _ "github.com/mattn/go-sqlite3" - "github.com/uptrace/bun" - "github.com/uptrace/bun/dialect/sqlitedialect" - "github.com/uptrace/bun/extra/bundebug" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) -// I could not came up with a better naming for this. -// Ideally, app should export a global variable called "DB" -// but this will cause imports cycles for plugins. -var Query *bun.DB +// By default this is a pre-configured Gorm DB instance. +// Change this type based on the database package of your likings. +var dbInstance *gorm.DB + +// Get returns the instantiated DB instance. +func Get() *gorm.DB { + return dbInstance +} func init() { + // Create a default *sql.DB exposed by the superkit/db package + // based on the given configuration. config := db.Config{ Driver: os.Getenv("DB_DRIVER"), Name: os.Getenv("DB_NAME"), @@ -27,12 +31,30 @@ func init() { User: os.Getenv("DB_USER"), Host: os.Getenv("DB_HOST"), } - db, err := db.New(config) + dbinst, err := db.NewSQL(config) if err != nil { log.Fatal(err) } - Query = bun.NewDB(db, sqlitedialect.New()) - if kit.IsDevelopment() { - Query.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) + // Based on the driver create the corresponding DB instance. + // By default, the SuperKit boilerplate comes with a pre-configured + // ORM called Gorm. https://gorm.io. + // + // You can change this to any other DB interaction tool + // of your liking. EG: + // - uptrace bun -> https://bun.uptrace.dev + // - SQLC -> https://github.com/sqlc-dev/sqlc + // - gojet -> https://github.com/go-jet/jet + switch config.Driver { + case db.DriverSqlite3: + dbInstance, err = gorm.Open(sqlite.New(sqlite.Config{ + Conn: dbinst, + })) + case db.DriverMysql: + // ... + default: + log.Fatal("invalid driver:", config.Driver) + } + if err != nil { + log.Fatal(err) } } diff --git a/bootstrap/app/db/migrations/20240610161057_create_users_table.sql b/bootstrap/app/db/migrations/20240610161057_create_users_table.sql index cd84d68..7b0193e 100644 --- a/bootstrap/app/db/migrations/20240610161057_create_users_table.sql +++ b/bootstrap/app/db/migrations/20240610161057_create_users_table.sql @@ -5,9 +5,10 @@ create table if not exists users( password_hash text not null, first_name text not null, last_name text not null, - email_verified_at timestamp with time zone, - created_at timestamp with time zone not null, - updated_at timestamp with time zone not null + email_verified_at datetime, + created_at datetime not null, + updated_at datetime not null, + deleted_at datetime ); -- +goose Down diff --git a/bootstrap/app/db/migrations/20240610163918_add_sessions_table.sql b/bootstrap/app/db/migrations/20240610163918_add_sessions_table.sql index ae082c8..7a36e68 100644 --- a/bootstrap/app/db/migrations/20240610163918_add_sessions_table.sql +++ b/bootstrap/app/db/migrations/20240610163918_add_sessions_table.sql @@ -5,9 +5,10 @@ create table if not exists sessions( user_id integer not null references users, ip_address text, user_agent text, - expires_at timestamp with time zone not null, - last_login_at timestamp with time zone, - created_at timestamp with time zone not null + expires_at datetime not null, + created_at datetime not null, + updated_at datetime not null, + deleted_at datetime ); -- +goose Down diff --git a/bootstrap/app/types/auth.go b/bootstrap/app/types/auth.go index ee4eb51..c46d9b7 100644 --- a/bootstrap/app/types/auth.go +++ b/bootstrap/app/types/auth.go @@ -2,7 +2,7 @@ package types // AuthUser represents an user that might be authenticated. type AuthUser struct { - ID int + ID uint Email string LoggedIn bool } diff --git a/bootstrap/go.mod b/bootstrap/go.mod index 8e62dfb..16cbd41 100644 --- a/bootstrap/go.mod +++ b/bootstrap/go.mod @@ -3,7 +3,7 @@ module AABBCCDD go 1.22.4 // uncomment for local development on the superkit core. -// replace github.com/anthdm/superkit => ../ +replace github.com/anthdm/superkit => ../ require ( github.com/a-h/templ v0.2.707 @@ -24,10 +24,13 @@ require ( github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/sessions v1.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect golang.org/x/sys v0.21.0 // indirect + gorm.io/driver/sqlite v1.5.6 // indirect + gorm.io/gorm v1.25.10 // indirect ) diff --git a/bootstrap/go.sum b/bootstrap/go.sum index 7b7ec23..933a4c0 100644 --- a/bootstrap/go.sum +++ b/bootstrap/go.sum @@ -18,6 +18,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -45,3 +47,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= +gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s= +gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/bootstrap/plugins/auth/auth_handler.go b/bootstrap/plugins/auth/auth_handler.go index 3c5d386..9df0f9d 100644 --- a/bootstrap/plugins/auth/auth_handler.go +++ b/bootstrap/plugins/auth/auth_handler.go @@ -13,6 +13,7 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) const ( @@ -40,12 +41,9 @@ func HandleLoginCreate(kit *kit.Kit) error { } var user User - err := db.Query.NewSelect(). - Model(&user). - Where("user.email = ?", values.Email). - Scan(kit.Request.Context()) + err := db.Get().Find(&user, "email = ?", values.Email).Error if err != nil { - if err == sql.ErrNoRows { + if err == gorm.ErrRecordNotFound { errors.Add("credentials", "invalid credentials") return kit.Render(LoginForm(values, errors)) } @@ -59,7 +57,7 @@ func HandleLoginCreate(kit *kit.Kit) error { skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false") if skipVerify != "true" { - if user.EmailVerifiedAt.Equal(time.Time{}) { + if !user.EmailVerifiedAt.Valid { errors.Add("verified", "please verify your email") return kit.Render(LoginForm(values, errors)) } @@ -71,24 +69,17 @@ func HandleLoginCreate(kit *kit.Kit) error { sessionExpiry = 48 } session := Session{ - UserID: user.ID, - Token: uuid.New().String(), - CreatedAt: time.Now(), - LastLoginAt: time.Now(), - ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)), + UserID: user.ID, + Token: uuid.New().String(), + ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)), } - _, err = db.Query.NewInsert(). - Model(&session). - Exec(kit.Request.Context()) - if err != nil { + if err = db.Get().Create(&session).Error; err != nil { return err } - // TODO change this with kit.Getenv sess := kit.GetSession(userSessionName) sess.Values["sessionToken"] = session.Token sess.Save(kit.Request, kit.Response) - redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile") return kit.Redirect(http.StatusSeeOther, redirectURL) @@ -100,10 +91,7 @@ func HandleLoginDelete(kit *kit.Kit) error { sess.Values = map[any]any{} sess.Save(kit.Request, kit.Response) }() - _, err := db.Query.NewDelete(). - Model((*Session)(nil)). - Where("token = ?", sess.Values["sessionToken"]). - Exec(kit.Request.Context()) + err := db.Get().Delete(&Session{}, "token = ?", sess.Values["sessionToken"]).Error if err != nil { return err } @@ -121,7 +109,7 @@ func HandleEmailVerify(kit *kit.Kit) error { return []byte(os.Getenv("SUPERKIT_SECRET")), nil }, jwt.WithLeeway(5*time.Second)) if err != nil { - return err + return kit.Render(EmailVerificationError("invalid verification token")) } if !token.Valid { return kit.Render(EmailVerificationError("invalid verification token")) @@ -141,23 +129,18 @@ func HandleEmailVerify(kit *kit.Kit) error { } var user User - err = db.Query.NewSelect(). - Model(&user). - Where("id = ?", userID). - Scan(kit.Request.Context()) + err = db.Get().First(&user, userID).Error if err != nil { return err } - if user.EmailVerifiedAt.After(time.Time{}) { + if user.EmailVerifiedAt.Time.After(time.Time{}) { return kit.Render(EmailVerificationError("Email already verified")) } - user.EmailVerifiedAt = time.Now() - _, err = db.Query.NewUpdate(). - Model(&user). - WherePK(). - Exec(kit.Request.Context()) + now := sql.NullTime{Time: time.Now(), Valid: true} + user.EmailVerifiedAt = now + err = db.Get().Save(&user).Error if err != nil { return err } @@ -174,16 +157,14 @@ func AuthenticateUser(kit *kit.Kit) (kit.Auth, error) { } var session Session - err := db.Query.NewSelect(). - Model(&session). - Relation("User"). - Where("session.token = ? AND session.expires_at > ?", token, time.Now()). - Scan(kit.Request.Context()) - if err != nil { + err := db.Get(). + Preload("User"). + Find(&session, "token = ? AND expires_at > ?", token, time.Now()).Error + if err != nil || session.ID == 0 { return auth, nil } - return Auth{ + return Auth{ LoggedIn: true, UserID: session.User.ID, Email: session.User.Email, diff --git a/bootstrap/plugins/auth/profile_handler.go b/bootstrap/plugins/auth/profile_handler.go index c90dae7..88e586d 100644 --- a/bootstrap/plugins/auth/profile_handler.go +++ b/bootstrap/plugins/auth/profile_handler.go @@ -14,7 +14,7 @@ var profileSchema = v.Schema{ } type ProfileFormValues struct { - ID int `form:"id"` + ID uint `form:"id"` FirstName string `form:"firstName"` LastName string `form:"lastName"` Email string @@ -25,11 +25,7 @@ func HandleProfileShow(kit *kit.Kit) error { auth := kit.Auth().(Auth) var user User - err := db.Query.NewSelect(). - Model(&user). - Where("id = ?", auth.UserID). - Scan(kit.Request.Context()) - if err != nil { + if err := db.Get().First(&user, auth.UserID).Error; err != nil { return err } @@ -54,12 +50,12 @@ func HandleProfileUpdate(kit *kit.Kit) error { if auth.UserID != values.ID { return fmt.Errorf("unauthorized request for profile %d", values.ID) } - _, err := db.Query.NewUpdate(). - Model((*User)(nil)). - Set("first_name = ?", values.FirstName). - Set("last_name = ?", values.LastName). + err := db.Get().Model(&User{}). Where("id = ?", auth.UserID). - Exec(kit.Request.Context()) + Updates(&User{ + FirstName: values.FirstName, + LastName: values.LastName, + }).Error if err != nil { return err } diff --git a/bootstrap/plugins/auth/signup_handler.go b/bootstrap/plugins/auth/signup_handler.go index b625267..cb74cdf 100644 --- a/bootstrap/plugins/auth/signup_handler.go +++ b/bootstrap/plugins/auth/signup_handler.go @@ -63,19 +63,15 @@ func HandleResendVerificationCode(kit *kit.Kit) error { } var user User - err = db.Query.NewSelect(). - Model(&user). - Where("id = ?", id). - Scan(kit.Request.Context()) - if err != nil { + if err = db.Get().First(&user, id).Error; err != nil { return kit.Text(http.StatusOK, "An unexpected error occured") } - if user.EmailVerifiedAt.After(time.Time{}) { + if user.EmailVerifiedAt.Time.After(time.Time{}) { return kit.Text(http.StatusOK, "Email already verified!") } - token, err := createVerificationToken(id) + token, err := createVerificationToken(uint(id)) if err != nil { return kit.Text(http.StatusOK, "An unexpected error occured") } @@ -90,7 +86,7 @@ func HandleResendVerificationCode(kit *kit.Kit) error { return kit.Text(http.StatusOK, msg) } -func createVerificationToken(userID int) (string, error) { +func createVerificationToken(userID uint) (string, error) { expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1") expiry, err := strconv.Atoi(expiryStr) if err != nil { diff --git a/bootstrap/plugins/auth/types.go b/bootstrap/plugins/auth/types.go index 763e665..150cabd 100644 --- a/bootstrap/plugins/auth/types.go +++ b/bootstrap/plugins/auth/types.go @@ -2,10 +2,11 @@ package auth import ( "AABBCCDD/app/db" - "context" + "database/sql" "time" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) // Event name constants @@ -22,7 +23,7 @@ type UserWithVerificationToken struct { } type Auth struct { - UserID int + UserID uint Email string LoggedIn bool } @@ -32,12 +33,13 @@ func (auth Auth) Check() bool { } type User struct { - ID int `bun:"id,pk,autoincrement"` + gorm.Model + Email string FirstName string LastName string PasswordHash string - EmailVerifiedAt time.Time + EmailVerifiedAt sql.NullTime CreatedAt time.Time UpdatedAt time.Time } @@ -52,22 +54,19 @@ func createUserFromFormValues(values SignupFormValues) (User, error) { FirstName: values.FirstName, LastName: values.LastName, PasswordHash: string(hash), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), } - _, err = db.Query.NewInsert().Model(&user).Exec(context.Background()) - return user, err + result := db.Get().Create(&user) + return user, result.Error } type Session struct { - ID int `bun:"id,pk,autoincrement"` - UserID int - Token string - IPAddress string - UserAgent string - ExpiresAt time.Time - LastLoginAt time.Time - CreatedAt time.Time + gorm.Model - User User `bun:"rel:belongs-to,join:user_id=id"` + UserID uint + Token string + IPAddress string + UserAgent string + ExpiresAt time.Time + CreatedAt time.Time + User User } diff --git a/db/db.go b/db/db.go index 6d611ad..02a0df3 100644 --- a/db/db.go +++ b/db/db.go @@ -7,6 +7,7 @@ import ( const ( DriverSqlite3 = "sqlite3" + DriverMysql = "mysql" ) type Config struct { @@ -17,7 +18,7 @@ type Config struct { Password string } -func New(cfg Config) (*sql.DB, error) { +func NewSQL(cfg Config) (*sql.DB, error) { switch cfg.Driver { case DriverSqlite3: name := cfg.Name diff --git a/validate/validate.go b/validate/validate.go index 01ac30d..71c3131 100644 --- a/validate/validate.go +++ b/validate/validate.go @@ -153,12 +153,18 @@ func parseRequest(r *http.Request, v any) error { } case reflect.String: fieldVal.SetString(formValue) - case reflect.Int: + case reflect.Int, reflect.Int32, reflect.Int64: intVal, err := strconv.Atoi(formValue) if err != nil { return fmt.Errorf("failed to parse int: %v", err) } fieldVal.SetInt(int64(intVal)) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + intVal, err := strconv.Atoi(formValue) + if err != nil { + return fmt.Errorf("failed to parse int: %v", err) + } + fieldVal.SetUint(uint64(intVal)) case reflect.Float64: floatVal, err := strconv.ParseFloat(formValue, 64) if err != nil {