changed bun to gorm in auth plugin

This commit is contained in:
anthdm 2024-06-23 10:53:56 +02:00
parent 6028c5d60d
commit 6055fc6426
12 changed files with 109 additions and 97 deletions

View file

@ -5,21 +5,25 @@ import (
"os" "os"
"github.com/anthdm/superkit/db" "github.com/anthdm/superkit/db"
"github.com/anthdm/superkit/kit"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/uptrace/bun" "gorm.io/driver/sqlite"
"github.com/uptrace/bun/dialect/sqlitedialect" "gorm.io/gorm"
"github.com/uptrace/bun/extra/bundebug"
) )
// I could not came up with a better naming for this. // By default this is a pre-configured Gorm DB instance.
// Ideally, app should export a global variable called "DB" // Change this type based on the database package of your likings.
// but this will cause imports cycles for plugins. var dbInstance *gorm.DB
var Query *bun.DB
// Get returns the instantiated DB instance.
func Get() *gorm.DB {
return dbInstance
}
func init() { func init() {
// Create a default *sql.DB exposed by the superkit/db package
// based on the given configuration.
config := db.Config{ config := db.Config{
Driver: os.Getenv("DB_DRIVER"), Driver: os.Getenv("DB_DRIVER"),
Name: os.Getenv("DB_NAME"), Name: os.Getenv("DB_NAME"),
@ -27,12 +31,30 @@ func init() {
User: os.Getenv("DB_USER"), User: os.Getenv("DB_USER"),
Host: os.Getenv("DB_HOST"), Host: os.Getenv("DB_HOST"),
} }
db, err := db.New(config) dbinst, err := db.NewSQL(config)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
Query = bun.NewDB(db, sqlitedialect.New()) // Based on the driver create the corresponding DB instance.
if kit.IsDevelopment() { // By default, the SuperKit boilerplate comes with a pre-configured
Query.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) // ORM called Gorm. https://gorm.io.
//
// You can change this to any other DB interaction tool
// of your liking. EG:
// - uptrace bun -> https://bun.uptrace.dev
// - SQLC -> https://github.com/sqlc-dev/sqlc
// - gojet -> https://github.com/go-jet/jet
switch config.Driver {
case db.DriverSqlite3:
dbInstance, err = gorm.Open(sqlite.New(sqlite.Config{
Conn: dbinst,
}))
case db.DriverMysql:
// ...
default:
log.Fatal("invalid driver:", config.Driver)
}
if err != nil {
log.Fatal(err)
} }
} }

View file

@ -5,9 +5,10 @@ create table if not exists users(
password_hash text not null, password_hash text not null,
first_name text not null, first_name text not null,
last_name text not null, last_name text not null,
email_verified_at timestamp with time zone, email_verified_at datetime,
created_at timestamp with time zone not null, created_at datetime not null,
updated_at timestamp with time zone not null updated_at datetime not null,
deleted_at datetime
); );
-- +goose Down -- +goose Down

View file

@ -5,9 +5,10 @@ create table if not exists sessions(
user_id integer not null references users, user_id integer not null references users,
ip_address text, ip_address text,
user_agent text, user_agent text,
expires_at timestamp with time zone not null, expires_at datetime not null,
last_login_at timestamp with time zone, created_at datetime not null,
created_at timestamp with time zone not null updated_at datetime not null,
deleted_at datetime
); );
-- +goose Down -- +goose Down

View file

@ -2,7 +2,7 @@ package types
// AuthUser represents an user that might be authenticated. // AuthUser represents an user that might be authenticated.
type AuthUser struct { type AuthUser struct {
ID int ID uint
Email string Email string
LoggedIn bool LoggedIn bool
} }

View file

@ -3,7 +3,7 @@ module AABBCCDD
go 1.22.4 go 1.22.4
// uncomment for local development on the superkit core. // uncomment for local development on the superkit core.
// replace github.com/anthdm/superkit => ../ replace github.com/anthdm/superkit => ../
require ( require (
github.com/a-h/templ v0.2.707 github.com/a-h/templ v0.2.707
@ -24,10 +24,13 @@ require (
github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/sessions v1.3.0 // indirect github.com/gorilla/sessions v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.21.0 // indirect golang.org/x/sys v0.21.0 // indirect
gorm.io/driver/sqlite v1.5.6 // indirect
gorm.io/gorm v1.25.10 // indirect
) )

View file

@ -18,6 +18,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
@ -45,3 +47,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View file

@ -13,6 +13,7 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
const ( const (
@ -40,12 +41,9 @@ func HandleLoginCreate(kit *kit.Kit) error {
} }
var user User var user User
err := db.Query.NewSelect(). err := db.Get().Find(&user, "email = ?", values.Email).Error
Model(&user).
Where("user.email = ?", values.Email).
Scan(kit.Request.Context())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == gorm.ErrRecordNotFound {
errors.Add("credentials", "invalid credentials") errors.Add("credentials", "invalid credentials")
return kit.Render(LoginForm(values, errors)) return kit.Render(LoginForm(values, errors))
} }
@ -59,7 +57,7 @@ func HandleLoginCreate(kit *kit.Kit) error {
skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false") skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false")
if skipVerify != "true" { if skipVerify != "true" {
if user.EmailVerifiedAt.Equal(time.Time{}) { if !user.EmailVerifiedAt.Valid {
errors.Add("verified", "please verify your email") errors.Add("verified", "please verify your email")
return kit.Render(LoginForm(values, errors)) return kit.Render(LoginForm(values, errors))
} }
@ -71,24 +69,17 @@ func HandleLoginCreate(kit *kit.Kit) error {
sessionExpiry = 48 sessionExpiry = 48
} }
session := Session{ session := Session{
UserID: user.ID, UserID: user.ID,
Token: uuid.New().String(), Token: uuid.New().String(),
CreatedAt: time.Now(), ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
LastLoginAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
} }
_, err = db.Query.NewInsert(). if err = db.Get().Create(&session).Error; err != nil {
Model(&session).
Exec(kit.Request.Context())
if err != nil {
return err return err
} }
// TODO change this with kit.Getenv
sess := kit.GetSession(userSessionName) sess := kit.GetSession(userSessionName)
sess.Values["sessionToken"] = session.Token sess.Values["sessionToken"] = session.Token
sess.Save(kit.Request, kit.Response) sess.Save(kit.Request, kit.Response)
redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile") redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile")
return kit.Redirect(http.StatusSeeOther, redirectURL) return kit.Redirect(http.StatusSeeOther, redirectURL)
@ -100,10 +91,7 @@ func HandleLoginDelete(kit *kit.Kit) error {
sess.Values = map[any]any{} sess.Values = map[any]any{}
sess.Save(kit.Request, kit.Response) sess.Save(kit.Request, kit.Response)
}() }()
_, err := db.Query.NewDelete(). err := db.Get().Delete(&Session{}, "token = ?", sess.Values["sessionToken"]).Error
Model((*Session)(nil)).
Where("token = ?", sess.Values["sessionToken"]).
Exec(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
@ -121,7 +109,7 @@ func HandleEmailVerify(kit *kit.Kit) error {
return []byte(os.Getenv("SUPERKIT_SECRET")), nil return []byte(os.Getenv("SUPERKIT_SECRET")), nil
}, jwt.WithLeeway(5*time.Second)) }, jwt.WithLeeway(5*time.Second))
if err != nil { if err != nil {
return err return kit.Render(EmailVerificationError("invalid verification token"))
} }
if !token.Valid { if !token.Valid {
return kit.Render(EmailVerificationError("invalid verification token")) return kit.Render(EmailVerificationError("invalid verification token"))
@ -141,23 +129,18 @@ func HandleEmailVerify(kit *kit.Kit) error {
} }
var user User var user User
err = db.Query.NewSelect(). err = db.Get().First(&user, userID).Error
Model(&user).
Where("id = ?", userID).
Scan(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
if user.EmailVerifiedAt.After(time.Time{}) { if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Render(EmailVerificationError("Email already verified")) return kit.Render(EmailVerificationError("Email already verified"))
} }
user.EmailVerifiedAt = time.Now() now := sql.NullTime{Time: time.Now(), Valid: true}
_, err = db.Query.NewUpdate(). user.EmailVerifiedAt = now
Model(&user). err = db.Get().Save(&user).Error
WherePK().
Exec(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
@ -174,16 +157,14 @@ func AuthenticateUser(kit *kit.Kit) (kit.Auth, error) {
} }
var session Session var session Session
err := db.Query.NewSelect(). err := db.Get().
Model(&session). Preload("User").
Relation("User"). Find(&session, "token = ? AND expires_at > ?", token, time.Now()).Error
Where("session.token = ? AND session.expires_at > ?", token, time.Now()). if err != nil || session.ID == 0 {
Scan(kit.Request.Context())
if err != nil {
return auth, nil return auth, nil
} }
return Auth{
return Auth{
LoggedIn: true, LoggedIn: true,
UserID: session.User.ID, UserID: session.User.ID,
Email: session.User.Email, Email: session.User.Email,

View file

@ -14,7 +14,7 @@ var profileSchema = v.Schema{
} }
type ProfileFormValues struct { type ProfileFormValues struct {
ID int `form:"id"` ID uint `form:"id"`
FirstName string `form:"firstName"` FirstName string `form:"firstName"`
LastName string `form:"lastName"` LastName string `form:"lastName"`
Email string Email string
@ -25,11 +25,7 @@ func HandleProfileShow(kit *kit.Kit) error {
auth := kit.Auth().(Auth) auth := kit.Auth().(Auth)
var user User var user User
err := db.Query.NewSelect(). if err := db.Get().First(&user, auth.UserID).Error; err != nil {
Model(&user).
Where("id = ?", auth.UserID).
Scan(kit.Request.Context())
if err != nil {
return err return err
} }
@ -54,12 +50,12 @@ func HandleProfileUpdate(kit *kit.Kit) error {
if auth.UserID != values.ID { if auth.UserID != values.ID {
return fmt.Errorf("unauthorized request for profile %d", values.ID) return fmt.Errorf("unauthorized request for profile %d", values.ID)
} }
_, err := db.Query.NewUpdate(). err := db.Get().Model(&User{}).
Model((*User)(nil)).
Set("first_name = ?", values.FirstName).
Set("last_name = ?", values.LastName).
Where("id = ?", auth.UserID). Where("id = ?", auth.UserID).
Exec(kit.Request.Context()) Updates(&User{
FirstName: values.FirstName,
LastName: values.LastName,
}).Error
if err != nil { if err != nil {
return err return err
} }

View file

@ -63,19 +63,15 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
} }
var user User var user User
err = db.Query.NewSelect(). if err = db.Get().First(&user, id).Error; err != nil {
Model(&user).
Where("id = ?", id).
Scan(kit.Request.Context())
if err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured") return kit.Text(http.StatusOK, "An unexpected error occured")
} }
if user.EmailVerifiedAt.After(time.Time{}) { if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Text(http.StatusOK, "Email already verified!") return kit.Text(http.StatusOK, "Email already verified!")
} }
token, err := createVerificationToken(id) token, err := createVerificationToken(uint(id))
if err != nil { if err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured") return kit.Text(http.StatusOK, "An unexpected error occured")
} }
@ -90,7 +86,7 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
return kit.Text(http.StatusOK, msg) return kit.Text(http.StatusOK, msg)
} }
func createVerificationToken(userID int) (string, error) { func createVerificationToken(userID uint) (string, error) {
expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1") expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1")
expiry, err := strconv.Atoi(expiryStr) expiry, err := strconv.Atoi(expiryStr)
if err != nil { if err != nil {

View file

@ -2,10 +2,11 @@ package auth
import ( import (
"AABBCCDD/app/db" "AABBCCDD/app/db"
"context" "database/sql"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
// Event name constants // Event name constants
@ -22,7 +23,7 @@ type UserWithVerificationToken struct {
} }
type Auth struct { type Auth struct {
UserID int UserID uint
Email string Email string
LoggedIn bool LoggedIn bool
} }
@ -32,12 +33,13 @@ func (auth Auth) Check() bool {
} }
type User struct { type User struct {
ID int `bun:"id,pk,autoincrement"` gorm.Model
Email string Email string
FirstName string FirstName string
LastName string LastName string
PasswordHash string PasswordHash string
EmailVerifiedAt time.Time EmailVerifiedAt sql.NullTime
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
@ -52,22 +54,19 @@ func createUserFromFormValues(values SignupFormValues) (User, error) {
FirstName: values.FirstName, FirstName: values.FirstName,
LastName: values.LastName, LastName: values.LastName,
PasswordHash: string(hash), PasswordHash: string(hash),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
} }
_, err = db.Query.NewInsert().Model(&user).Exec(context.Background()) result := db.Get().Create(&user)
return user, err return user, result.Error
} }
type Session struct { type Session struct {
ID int `bun:"id,pk,autoincrement"` gorm.Model
UserID int
Token string
IPAddress string
UserAgent string
ExpiresAt time.Time
LastLoginAt time.Time
CreatedAt time.Time
User User `bun:"rel:belongs-to,join:user_id=id"` UserID uint
Token string
IPAddress string
UserAgent string
ExpiresAt time.Time
CreatedAt time.Time
User User
} }

View file

@ -7,6 +7,7 @@ import (
const ( const (
DriverSqlite3 = "sqlite3" DriverSqlite3 = "sqlite3"
DriverMysql = "mysql"
) )
type Config struct { type Config struct {
@ -17,7 +18,7 @@ type Config struct {
Password string Password string
} }
func New(cfg Config) (*sql.DB, error) { func NewSQL(cfg Config) (*sql.DB, error) {
switch cfg.Driver { switch cfg.Driver {
case DriverSqlite3: case DriverSqlite3:
name := cfg.Name name := cfg.Name

View file

@ -153,12 +153,18 @@ func parseRequest(r *http.Request, v any) error {
} }
case reflect.String: case reflect.String:
fieldVal.SetString(formValue) fieldVal.SetString(formValue)
case reflect.Int: case reflect.Int, reflect.Int32, reflect.Int64:
intVal, err := strconv.Atoi(formValue) intVal, err := strconv.Atoi(formValue)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse int: %v", err) return fmt.Errorf("failed to parse int: %v", err)
} }
fieldVal.SetInt(int64(intVal)) fieldVal.SetInt(int64(intVal))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intVal, err := strconv.Atoi(formValue)
if err != nil {
return fmt.Errorf("failed to parse int: %v", err)
}
fieldVal.SetUint(uint64(intVal))
case reflect.Float64: case reflect.Float64:
floatVal, err := strconv.ParseFloat(formValue, 64) floatVal, err := strconv.ParseFloat(formValue, 64)
if err != nil { if err != nil {