Encapsulate orm behaviour in the Database struct

This commit is contained in:
Andrea Fazzi 2020-01-15 11:27:00 +01:00
parent 2ad7e3d180
commit 3bb8dde670
14 changed files with 308 additions and 304 deletions

View file

@ -35,12 +35,13 @@ type PathPattern struct {
type Handlers struct { type Handlers struct {
Config *config.ConfigT Config *config.ConfigT
Models []interface{} Database *orm.Database
Models []interface{}
Login func(store *sessions.CookieStore, signingKey []byte) http.Handler Login func(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler
Logout func(store *sessions.CookieStore) http.Handler Logout func(store *sessions.CookieStore) http.Handler
Home func() http.Handler Home func() http.Handler
GetToken func(signingKey []byte) http.Handler GetToken func(db *orm.Database, signingKey []byte) http.Handler
Static func() http.Handler Static func() http.Handler
Recover func(next http.Handler) http.Handler Recover func(next http.Handler) http.Handler
@ -164,10 +165,11 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) {
} }
func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers { func NewHandlers(config *config.ConfigT, db *orm.Database, models []interface{}) *Handlers {
handlers := new(Handlers) handlers := new(Handlers)
handlers.Config = config handlers.Config = config
handlers.Database = db
handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey)) handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey))
@ -197,12 +199,12 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
// Authentication // Authentication
r.Handle("/login", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) r.Handle("/login", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
r.Handle("/logout", handlers.Logout(handlers.CookieStore)) r.Handle("/logout", handlers.Logout(handlers.CookieStore))
// School subscription // School subscription
r.Handle("/subscribe", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) r.Handle("/subscribe", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
// Home // Home
@ -216,7 +218,7 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
// Token handling // Token handling
r.Handle("/get_token", handlers.GetToken([]byte(config.Keys.JWTSigningKey))) r.Handle("/get_token", handlers.GetToken(handlers.Database, []byte(config.Keys.JWTSigningKey)))
// Static file server // Static file server
@ -293,7 +295,7 @@ func hasPermission(role, path string) bool {
func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
format := r.URL.Query().Get("format") format := r.URL.Query().Get("format")
getFn, err := orm.GetFunc(pattern.Path(model)) getFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil { if err != nil {
log.Println("Error:", err) log.Println("Error:", err)
respondWithError(w, r, err) respondWithError(w, r, err)
@ -306,7 +308,7 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat
renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else { } else {
data, err := getFn(mux.Vars(r), w, r) data, err := getFn(h.Database, mux.Vars(r), w, r)
if err != nil { if err != nil {
renderer.Render[format](w, r, err) renderer.Render[format](w, r, err)
} else { } else {
@ -317,14 +319,14 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat
} }
func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { func (h *Handlers) post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
var ( var (
data interface{} data interface{}
err error err error
) )
respFormat := renderer.GetContentFormat(r) respFormat := renderer.GetContentFormat(r)
postFn, err := orm.GetFunc(pattern.Path(model)) postFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil { if err != nil {
respondWithError(w, r, err) respondWithError(w, r, err)
@ -335,7 +337,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt
if !hasPermission(role, pattern.Path(model)) { if !hasPermission(role, pattern.Path(model)) {
renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else { } else {
data, err = postFn(mux.Vars(r), w, r) data, err = postFn(h.Database, mux.Vars(r), w, r)
if err != nil { if err != nil {
respondWithError(w, r, err) respondWithError(w, r, err)
} else if pattern.RedirectPattern != "" { } else if pattern.RedirectPattern != "" {
@ -353,7 +355,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt
} }
func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { func (h *Handlers) delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
var data interface{} var data interface{}
respFormat := renderer.GetContentFormat(r) respFormat := renderer.GetContentFormat(r)
@ -364,11 +366,11 @@ func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPa
if !hasPermission(role, pattern.Path(model)) { if !hasPermission(role, pattern.Path(model)) {
renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else { } else {
postFn, err := orm.GetFunc(pattern.Path(model)) postFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil { if err != nil {
renderer.Render[r.URL.Query().Get("format")](w, r, err) renderer.Render[r.URL.Query().Get("format")](w, r, err)
} }
data, err = postFn(mux.Vars(r), w, r) data, err = postFn(h.Database, mux.Vars(r), w, r)
if err != nil { if err != nil {
renderer.Render["html"](w, r, err) renderer.Render["html"](w, r, err)
} else if pattern.RedirectPattern != "" { } else if pattern.RedirectPattern != "" {
@ -402,10 +404,10 @@ func (h *Handlers) modelHandler(model string, pattern PathPattern) http.Handler
h.get(w, r, model, pattern) h.get(w, r, model, pattern)
case "POST": case "POST":
post(w, r, model, pattern) h.post(w, r, model, pattern)
case "DELETE": case "DELETE":
delete(w, r, model, pattern) h.delete(w, r, model, pattern)
} }
} }

View file

@ -59,14 +59,14 @@ func DefaultLogoutHandler(store *sessions.CookieStore) http.Handler {
return http.HandlerFunc(fn) return http.HandlerFunc(fn)
} }
func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Handler { func DefaultLoginHandler(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" { if r.Method == "GET" {
renderer.Render["html"](w, r, nil, r.URL.Query()) renderer.Render["html"](w, r, nil, r.URL.Query())
} }
if r.Method == "POST" { if r.Method == "POST" {
r.ParseForm() r.ParseForm()
token, err := getToken(r.FormValue("username"), r.FormValue("password"), signingKey) token, err := getToken(db, r.FormValue("username"), r.FormValue("password"), signingKey)
if err != nil { if err != nil {
http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther) http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther)
} else { } else {
@ -87,9 +87,7 @@ func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Ha
// FIXME: This is an hack for fast prototyping: users should have // FIXME: This is an hack for fast prototyping: users should have
// their own table on DB. // their own table on DB.
func checkCredential(username string, password string) (*UserToken, error) { func checkCredential(db *orm.Database, username string, password string) (*UserToken, error) {
var user orm.User
// Check if user is the administrator // Check if user is the administrator
if username == config.Config.Admin.Username && password == config.Config.Admin.Password { if username == config.Config.Admin.Username && password == config.Config.Admin.Password {
@ -104,31 +102,31 @@ func checkCredential(username string, password string) (*UserToken, error) {
var token *UserToken var token *UserToken
if err := orm.DB().Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil { user, err := db.GetUser(username, password)
if err != nil {
return nil, errors.New("Authentication failed!") return nil, errors.New("Authentication failed!")
} else { }
switch user.Role { switch user.Role {
case "participant": case "participant":
var participant orm.Participant var participant orm.Participant
if err := orm.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil { if err := db.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil {
return nil, errors.New("Authentication failed!") return nil, errors.New("Authentication failed!")
}
token = &UserToken{username, false, user.Role, strconv.Itoa(int(participant.ID))}
case "school":
var school orm.School
if err := orm.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
token = &UserToken{username, false, user.Role, strconv.Itoa(int(school.ID))}
} }
token = &UserToken{username, false, user.Role, strconv.Itoa(int(participant.ID))}
case "school":
var school orm.School
if err := db.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
token = &UserToken{username, false, user.Role, strconv.Itoa(int(school.ID))}
} }
return token, nil return token, nil
} }
// FIXME: Refactor the functions above please!!! // FIXME: Refactor the functions above please!!!
func getToken(username string, password string, signingKey []byte) ([]byte, error) { func getToken(db *orm.Database, username string, password string, signingKey []byte) ([]byte, error) {
user, err := checkCredential(username, password) user, err := checkCredential(db, username, password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,11 +152,11 @@ func getToken(username string, password string, signingKey []byte) ([]byte, erro
} }
// FIXME: Refactor the functions above please!!! // FIXME: Refactor the functions above please!!!
func DefaultGetTokenHandler(signingKey []byte) http.Handler { func DefaultGetTokenHandler(db *orm.Database, signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
username, password, _ := r.BasicAuth() username, password, _ := r.BasicAuth()
token, err := getToken(username, password, signingKey) token, err := getToken(db, username, password, signingKey)
if err != nil { if err != nil {
panic(err) panic(err)
} }

22
main.go
View file

@ -2,14 +2,12 @@ package main
import ( import (
"flag" "flag"
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
"time" "time"
"github.com/gorilla/handlers" "github.com/gorilla/handlers"
"github.com/jinzhu/gorm"
"git.andreafazzi.eu/andrea/oef/config" "git.andreafazzi.eu/andrea/oef/config"
oef_handlers "git.andreafazzi.eu/andrea/oef/handlers" oef_handlers "git.andreafazzi.eu/andrea/oef/handlers"
@ -22,7 +20,7 @@ const (
) )
var ( var (
db *gorm.DB db *orm.Database
err error err error
models = []interface{}{ models = []interface{}{
@ -54,7 +52,7 @@ func main() {
wait := true wait := true
for wait && count > 0 { for wait && count > 0 {
db, err = orm.New(fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options)) db, err = orm.NewDatabase(config.Config, models)
if err != nil { if err != nil {
count-- count--
log.Println(err) log.Println(err)
@ -65,26 +63,22 @@ func main() {
wait = false wait = false
} }
orm.Use(db) // REMOVE
// orm.Use(db)
if config.Config.Orm.AutoMigrate { if config.Config.Orm.AutoMigrate {
log.Print("Automigrating...") log.Print("Automigrating...")
orm.AutoMigrate(models...) db.AutoMigrate()
} }
log.Println("Eventually write categories on DB...") log.Println("Eventually write categories on DB...")
orm.CreateCategories() orm.CreateCategories(db)
log.Println("Eventually write regions on DB...") log.Println("Eventually write regions on DB...")
orm.CreateRegions() orm.CreateRegions(db)
log.Println("Map models <-> handlers")
if err := orm.MapHandlers(models); err != nil {
panic(err)
}
log.Println("OEF is listening to port 3000...") log.Println("OEF is listening to port 3000...")
if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, models).Router)); err != nil { if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, db, models).Router)); err != nil {
panic(err) panic(err)
} }

View file

@ -26,10 +26,10 @@ func (a *Answer) String() string {
return a.Text return a.Text
} }
func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (a *Answer) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
answer := new(Answer) answer := new(Answer)
if err := DB().Find(&answer.AllQuestions).Error; err != nil { if err := db._db.Find(&answer.AllQuestions).Error; err != nil {
return nil, err return nil, err
} }
@ -40,7 +40,7 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return nil, err return nil, err
} }
answer, err = CreateAnswer(answer) answer, err = CreateAnswer(db, answer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -48,36 +48,36 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R
} }
} }
func (a *Answer) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (a *Answer) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var answer Answer var answer Answer
id := args["id"] id := args["id"]
if err := DB().Preload("Question").First(&answer, id).Error; err != nil { if err := db._db.Preload("Question").First(&answer, id).Error; err != nil {
return nil, err return nil, err
} }
return &answer, nil return &answer, nil
} }
func (a *Answer) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (a *Answer) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var answers []*Answer var answers []*Answer
if err := DB().Preload("Question").Order("created_at").Find(&answers).Error; err != nil { if err := db._db.Preload("Question").Order("created_at").Find(&answers).Error; err != nil {
return nil, err return nil, err
} }
return answers, nil return answers, nil
} }
func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (a *Answer) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := a.Read(args, w, r) result, err := a.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
answer := result.(*Answer) answer := result.(*Answer)
if err := DB().Find(&answer.AllQuestions).Error; err != nil { if err := db._db.Find(&answer.AllQuestions).Error; err != nil {
return nil, err return nil, err
} }
@ -86,7 +86,7 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
return answer, nil return answer, nil
} else { } else {
answer, err := a.Read(args, w, r) answer, err := a.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -98,11 +98,11 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveAnswer(answer) _, err = SaveAnswer(db, answer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
answer, err = a.Read(args, w, r) answer, err = a.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -110,26 +110,26 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
} }
} }
func (model *Answer) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Answer) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
answer, err := model.Read(args, w, r) answer, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(answer.(*Answer)).Error; err != nil { if err := db._db.Unscoped().Delete(answer.(*Answer)).Error; err != nil {
return nil, err return nil, err
} }
return answer.(*Answer), nil return answer.(*Answer), nil
} }
func CreateAnswer(answer *Answer) (*Answer, error) { func CreateAnswer(db *Database, answer *Answer) (*Answer, error) {
if err := DB().Create(answer).Error; err != nil { if err := db._db.Create(answer).Error; err != nil {
return nil, err return nil, err
} }
return answer, nil return answer, nil
} }
func SaveAnswer(answer interface{}) (interface{}, error) { func SaveAnswer(db *Database, answer interface{}) (interface{}, error) {
if err := DB().Omit("Answers").Save(answer).Error; err != nil { if err := db._db.Omit("Answers").Save(answer).Error; err != nil {
return nil, err return nil, err
} }
return answer, nil return answer, nil

View file

@ -23,10 +23,10 @@ func (model *Category) String() string {
return model.Name return model.Name
} }
func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Category) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
category := new(Category) category := new(Category)
// if err := DB().Find(&category.AllContests).Error; err != nil { // if err := db._db.Find(&category.AllContests).Error; err != nil {
// return nil, err // return nil, err
// } // }
return category, nil return category, nil
@ -36,7 +36,7 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *
if err != nil { if err != nil {
return nil, err return nil, err
} }
category, err = CreateCategory(category) category, err = CreateCategory(db, category)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -44,36 +44,36 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *
} }
} }
func (model *Category) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Category) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var category Category var category Category
id := args["id"] id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&category, id).Error; err != nil { if err := db._db. /*.Preload("Something")*/ First(&category, id).Error; err != nil {
return nil, err return nil, err
} }
return &category, nil return &category, nil
} }
func (model *Category) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Category) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var categories []*Category var categories []*Category
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil { if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil {
return nil, err return nil, err
} }
return categories, nil return categories, nil
} }
func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Category) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
category := result.(*Category) category := result.(*Category)
// if err := DB().Find(&category.AllElements).Error; err != nil { // if err := db._db.Find(&category.AllElements).Error; err != nil {
// return nil, err // return nil, err
// } // }
@ -82,7 +82,7 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
return category, nil return category, nil
} else { } else {
category, err := model.Read(args, w, r) category, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,11 +90,11 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveCategory(category) _, err = SaveCategory(db, category)
if err != nil { if err != nil {
return nil, err return nil, err
} }
category, err = model.Read(args, w, r) category, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,26 +102,26 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
} }
} }
func (model *Category) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Category) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
category, err := model.Read(args, w, r) category, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(category.(*Category)).Error; err != nil { if err := db._db.Unscoped().Delete(category.(*Category)).Error; err != nil {
return nil, err return nil, err
} }
return category.(*Category), nil return category.(*Category), nil
} }
func CreateCategory(category *Category) (*Category, error) { func CreateCategory(db *Database, category *Category) (*Category, error) {
if err := DB().Create(category).Error; err != nil { if err := db._db.Create(category).Error; err != nil {
return nil, err return nil, err
} }
return category, nil return category, nil
} }
func SaveCategory(category interface{}) (interface{}, error) { func SaveCategory(db *Database, category interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(category).Error; err != nil { if err := db._db. /*.Omit("Something")*/ Save(category).Error; err != nil {
return nil, err return nil, err
} }
return category, nil return category, nil

View file

@ -39,7 +39,7 @@ func (c *Contest) String() string {
return c.Name return c.Name
} }
func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (c *Contest) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
return nil, nil return nil, nil
} else { } else {
@ -68,7 +68,7 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.
if err != nil { if err != nil {
return nil, err return nil, err
} }
contest, err = CreateContest(contest) contest, err = CreateContest(db, contest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -76,25 +76,25 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.
} }
} }
func (c *Contest) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (c *Contest) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var contest Contest var contest Contest
id := args["id"] id := args["id"]
if err := DB().Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil { if err := db._db.Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil {
return nil, err return nil, err
} }
return &contest, nil return &contest, nil
} }
func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (c *Contest) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var contests []*Contest var contests []*Contest
claims := r.Context().Value("user").(*jwt.Token).Claims.(jwt.MapClaims) claims := r.Context().Value("user").(*jwt.Token).Claims.(jwt.MapClaims)
if claims["admin"].(bool) { if claims["admin"].(bool) {
if err := DB().Order("created_at").Find(&contests).Error; err != nil { if err := db._db.Order("created_at").Find(&contests).Error; err != nil {
return nil, err return nil, err
} else { } else {
return contests, nil return contests, nil
@ -103,16 +103,16 @@ func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http
participant := &Participant{} participant := &Participant{}
if err := DB().Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil { if err := db._db.Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil {
return nil, err return nil, err
} else { } else {
return participant.Contests, nil return participant.Contests, nil
} }
} }
func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (c *Contest) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := c.Read(args, w, r) result, err := c.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -120,7 +120,7 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
return contest, nil return contest, nil
} else { } else {
contest, err := c.Read(args, w, r) contest, err := c.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -148,11 +148,11 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveContest(contest) _, err = SaveContest(db, contest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
contest, err = c.Read(args, w, r) contest, err = c.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,26 +160,26 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
} }
} }
func (model *Contest) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Contest) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
contest, err := model.Read(args, w, r) contest, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(contest.(*Contest)).Error; err != nil { if err := db._db.Unscoped().Delete(contest.(*Contest)).Error; err != nil {
return nil, err return nil, err
} }
return contest.(*Contest), nil return contest.(*Contest), nil
} }
func CreateContest(contest *Contest) (*Contest, error) { func CreateContest(db *Database, contest *Contest) (*Contest, error) {
if err := DB().Create(contest).Error; err != nil { if err := db._db.Create(contest).Error; err != nil {
return nil, err return nil, err
} }
return contest, nil return contest, nil
} }
func SaveContest(contest interface{}) (interface{}, error) { func SaveContest(db *Database, contest interface{}) (interface{}, error) {
if err := DB().Omit("Contests").Save(contest).Error; err != nil { if err := db._db.Omit("Contests").Save(contest).Error; err != nil {
return nil, err return nil, err
} }
return contest, nil return contest, nil
@ -189,13 +189,13 @@ func (c *Contest) isAlwaysActive() bool {
return c.StartTime.IsZero() || c.EndTime.IsZero() || c.Duration == 0 return c.StartTime.IsZero() || c.EndTime.IsZero() || c.Duration == 0
} }
func (c *Contest) generateQuestionsOrder() (string, error) { func (c *Contest) generateQuestionsOrder(db *gorm.DB) (string, error) {
var ( var (
order []string order []string
questions []*Question questions []*Question
) )
if err := DB().Find(&questions, Question{ContestID: c.ID}).Error; err != nil { if err := db.Find(&questions, Question{ContestID: c.ID}).Error; err != nil {
return "", err return "", err
} }

View file

@ -1,6 +1,7 @@
package orm package orm
import ( import (
"errors"
"fmt" "fmt"
"net/http" "net/http"
"path" "path"
@ -18,51 +19,60 @@ type IDer interface {
GetID() uint GetID() uint
} }
type GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) type GetFn func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
type Database struct { type Database struct {
Config *config.ConfigT Config *config.ConfigT
_db *gorm.DB models []interface{}
fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) _db *gorm.DB
fns map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
} }
func NewDatabase(config *config.ConfigT) (*Database, error) { func NewDatabase(config *config.ConfigT, models []interface{}) (*Database, error) {
var err error
db := new(Database) db := new(Database)
db.fns = make(db*gorm.DB, map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0) db.fns = make(map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
db._db, err := gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options)) db.mapHandlers(models)
db._db, err = gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Orm.Connection, config.Orm.Options))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return db return db, err
} }
func (db *Database) AutoMigrate(models ...interface{}) { func (db *Database) AutoMigrate() {
if err := db._db.AutoMigrate(models...).Error; err != nil { if err := db._db.AutoMigrate(db.models...).Error; err != nil {
panic(err) panic(err)
} }
} }
func CreateCategories() { func (db *Database) GetUser(username, password string) (*User, error) {
var user *User
if err := db._db.Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
return user, nil
}
func (db *Database) DB() *gorm.DB {
return db._db
}
func CreateCategories(db *Database) {
for _, name := range categories { for _, name := range categories {
var category Category var category Category
if err := currDB.FirstOrCreate(&category, Category{Name: name}).Error; err != nil { if err := db._db.FirstOrCreate(&category, Category{Name: name}).Error; err != nil {
panic(err) panic(err)
} }
} }
} }
func Use(db *gorm.DB) { func (db *Database) mapHandlers(models []interface{}) error {
currDB = db
}
func DB() *gorm.DB {
return currDB
}
func MapHandlers(models []interface{}) error {
for _, model := range models { for _, model := range models {
name := inflection.Plural(strings.ToLower(ModelName(model))) name := inflection.Plural(strings.ToLower(ModelName(model)))
for p, action := range map[string]string{ for p, action := range map[string]string{
@ -80,7 +90,7 @@ func MapHandlers(models []interface{}) error {
if strings.HasSuffix(p, "/") { if strings.HasSuffix(p, "/") {
joinedPath += "/" joinedPath += "/"
} }
fns[joinedPath] = method.Interface().(func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)) db.fns[joinedPath] = method.Interface().(func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error))
} }
} }
@ -88,19 +98,19 @@ func MapHandlers(models []interface{}) error {
return nil return nil
} }
func GetFunc(path string) (GetFn, error) { func (db *Database) GetFunc(path string) (GetFn, error) {
fn, ok := fns[path] fn, ok := db.fns[path]
if !ok { if !ok {
return nil, fmt.Errorf("Can't map path %s to any model methods.", path) return nil, fmt.Errorf("Can't map path %s to any model methods.", path)
} }
return fn, nil return fn, nil
} }
func GetNothing(args map[string]string) (interface{}, error) { func GetNothing(db *Database, args map[string]string) (interface{}, error) {
return nil, nil return nil, nil
} }
func PostNothing(args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) { func PostNothing(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) {
return nil, nil return nil, nil
} }

View file

@ -92,9 +92,9 @@ func (model *Participant) String() string {
// return nil // return nil
// } // }
func (model *Participant) exists() (*User, error) { func (model *Participant) exists(db *Database) (*User, error) {
var user User var user User
if err := DB().First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound { if err := db._db.First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
return nil, err return nil, err
} else if err == gorm.ErrRecordNotFound { } else if err == gorm.ErrRecordNotFound {
return nil, nil return nil, nil
@ -129,7 +129,7 @@ func (model *Participant) AfterSave(tx *gorm.DB) error {
return err return err
} }
order, err := contest.generateQuestionsOrder() order, err := contest.generateQuestionsOrder(tx)
if err != nil { if err != nil {
return err return err
} }
@ -148,21 +148,21 @@ func (model *Participant) AfterDelete(tx *gorm.DB) error {
return nil return nil
} }
func (model *Participant) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Participant) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
participant := new(Participant) participant := new(Participant)
if isSchool(r) { if isSchool(r) {
if err := DB().Find(&participant.AllCategories).Error; err != nil { if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err return nil, err
} }
} else { } else {
if err := DB().Find(&participant.AllCategories).Error; err != nil { if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err return nil, err
} }
if err := DB().Find(&participant.AllContests).Error; err != nil { if err := db._db.Find(&participant.AllContests).Error; err != nil {
return nil, err return nil, err
} }
if err := DB().Find(&participant.AllSchools).Error; err != nil { if err := db._db.Find(&participant.AllSchools).Error; err != nil {
return nil, err return nil, err
} }
} }
@ -175,8 +175,8 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
} }
// Check if participant exists // Check if participant exists
if user, err := participant.exists(); err == nil && user != nil { if user, err := participant.exists(db); err == nil && user != nil {
if err := DB().Where("user_id = ?", user.ID).Find(&participant).Error; err != nil { if err := db._db.Where("user_id = ?", user.ID).Find(&participant).Error; err != nil {
return nil, err return nil, err
} }
// err := setFlashMessage(w, r, "participantExists") // err := setFlashMessage(w, r, "participantExists")
@ -199,10 +199,10 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
// Check if a participant of the same category exists // Check if a participant of the same category exists
var school School var school School
if err := DB().First(&school, participant.SchoolID).Error; err != nil { if err := db._db.First(&school, participant.SchoolID).Error; err != nil {
return nil, err return nil, err
} }
hasCategory, err := school.HasCategory(participant) hasCategory, err := school.HasCategory(db, participant)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -212,20 +212,20 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
participant.UserModifierCreate = NewUserModifierCreate(r) participant.UserModifierCreate = NewUserModifierCreate(r)
participant, err = CreateParticipant(participant) participant, err = CreateParticipant(db, participant)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var response Response var response Response
if err := DB().First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil { if err := db._db.First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil {
return nil, err return nil, err
} }
response.UserModifierCreate = NewUserModifierCreate(r) response.UserModifierCreate = NewUserModifierCreate(r)
if err := DB().Save(&response).Error; err != nil { if err := db._db.Save(&response).Error; err != nil {
return nil, err return nil, err
} }
@ -233,14 +233,14 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
} }
} }
func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Participant) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var participant Participant var participant Participant
id := args["id"] id := args["id"]
// School user can access to its participants only! // School user can access to its participants only!
if isSchool(r) { if isSchool(r) {
if err := DB().Preload("School").First(&participant, id).Error; err != nil { if err := db._db.Preload("School").First(&participant, id).Error; err != nil {
return nil, err return nil, err
} }
@ -249,12 +249,12 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r
return nil, errors.NotAuthorized return nil, errors.NotAuthorized
} }
if err := DB().Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil { if err := db._db.Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil {
return nil, err return nil, err
} }
} else { } else {
if err := DB().Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil { if err := db._db.Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil {
return nil, err return nil, err
} }
} }
@ -262,7 +262,7 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r
return &participant, nil return &participant, nil
} }
func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Participant) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var participants []*Participant var participants []*Participant
// School user can access to its participants only! // School user can access to its participants only!
@ -271,20 +271,20 @@ func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil { if err := db._db.Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil {
return nil, err return nil, err
} }
} else { } else {
if err := DB().Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil { if err := db._db.Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil {
return nil, err return nil, err
} }
} }
return participants, nil return participants, nil
} }
func (model *Participant) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Participant) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -292,21 +292,21 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
participant := result.(*Participant) participant := result.(*Participant)
if isSchool(r) { if isSchool(r) {
if err := DB().Find(&participant.AllCategories).Error; err != nil { if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err return nil, err
} }
participant.SelectedCategory = make(map[uint]string) participant.SelectedCategory = make(map[uint]string)
participant.SelectedCategory[participant.CategoryID] = "selected" participant.SelectedCategory[participant.CategoryID] = "selected"
} else { } else {
if err := DB().Find(&participant.AllCategories).Error; err != nil { if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err return nil, err
} }
participant.SelectedCategory = make(map[uint]string) participant.SelectedCategory = make(map[uint]string)
participant.SelectedCategory[participant.CategoryID] = "selected" participant.SelectedCategory[participant.CategoryID] = "selected"
if err := DB().Find(&participant.AllContests).Error; err != nil { if err := db._db.Find(&participant.AllContests).Error; err != nil {
return nil, err return nil, err
} }
@ -315,7 +315,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
participant.SelectedContest[c.ID] = "selected" participant.SelectedContest[c.ID] = "selected"
} }
if err := DB().Find(&participant.AllSchools).Error; err != nil { if err := db._db.Find(&participant.AllSchools).Error; err != nil {
return nil, err return nil, err
} }
@ -324,7 +324,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
} }
return participant, nil return participant, nil
} else { } else {
participant, err := model.Read(args, w, r) participant, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -333,7 +333,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
return nil, err return nil, err
} }
if user, err := participant.(*Participant).exists(); err == nil && user != nil { if user, err := participant.(*Participant).exists(db); err == nil && user != nil {
if user.ID != participant.(*Participant).UserID { if user.ID != participant.(*Participant).UserID {
// err := setFlashMessage(w, r, "participantExists") // err := setFlashMessage(w, r, "participantExists")
// if err != nil { // if err != nil {
@ -347,10 +347,10 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
// Check if a participant of the same category exists // Check if a participant of the same category exists
var school School var school School
if err := DB().First(&school, participant.(*Participant).SchoolID).Error; err != nil { if err := db._db.First(&school, participant.(*Participant).SchoolID).Error; err != nil {
return nil, err return nil, err
} }
hasCategory, err := school.HasCategory(participant.(*Participant)) hasCategory, err := school.HasCategory(db, participant.(*Participant))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -358,35 +358,35 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
return nil, errors.CategoryExists return nil, errors.CategoryExists
} }
if err := DB().Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil { if err := db._db.Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil {
return nil, err return nil, err
} }
participant.(*Participant).UserModifierUpdate = NewUserModifierUpdate(r) participant.(*Participant).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveParticipant(participant) _, err = SaveParticipant(db, participant)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil { if err := db._db.Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil {
return nil, err return nil, err
} }
participant, err = model.Read(args, w, r) participant, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var response Response var response Response
if err := DB().First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil { if err := db._db.First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil {
return nil, err return nil, err
} }
response.UserModifierUpdate = NewUserModifierUpdate(r) response.UserModifierUpdate = NewUserModifierUpdate(r)
if err := DB().Save(&response).Error; err != nil { if err := db._db.Save(&response).Error; err != nil {
return nil, err return nil, err
} }
@ -394,31 +394,31 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
} }
} }
func (model *Participant) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Participant) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
participant, err := model.Read(args, w, r) participant, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(participant.(*Participant)).Error; err != nil { if err := db._db.Unscoped().Delete(participant.(*Participant)).Error; err != nil {
return nil, err return nil, err
} }
return participant.(*Participant), nil return participant.(*Participant), nil
} }
func CreateParticipant(participant *Participant) (*Participant, error) { func CreateParticipant(db *Database, participant *Participant) (*Participant, error) {
if err := DB().Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil { if err := db._db.Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil {
return nil, err return nil, err
} }
if err := DB().Create(participant).Error; err != nil { if err := db._db.Create(participant).Error; err != nil {
return nil, err return nil, err
} }
return participant, nil return participant, nil
} }
func SaveParticipant(participant interface{}) (interface{}, error) { func SaveParticipant(db *Database, participant interface{}) (interface{}, error) {
participant.(*Participant).FiscalCode = strings.ToUpper(participant.(*Participant).FiscalCode) participant.(*Participant).FiscalCode = strings.ToUpper(participant.(*Participant).FiscalCode)
if err := DB().Omit("Category", "School").Save(participant).Error; err != nil { if err := db._db.Omit("Category", "School").Save(participant).Error; err != nil {
return nil, err return nil, err
} }
return participant, nil return participant, nil

View file

@ -33,10 +33,10 @@ func (q *Question) BeforeCreate(tx *gorm.DB) error {
return nil return nil
} }
func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (q *Question) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
question := new(Question) question := new(Question)
if err := DB().Find(&question.AllContests).Error; err != nil { if err := db._db.Find(&question.AllContests).Error; err != nil {
return nil, err return nil, err
} }
return question, nil return question, nil
@ -46,7 +46,7 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http
if err != nil { if err != nil {
return nil, err return nil, err
} }
question, err = CreateQuestion(question) question, err = CreateQuestion(db, question)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -54,12 +54,12 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http
} }
} }
func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (q *Question) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var question Question var question Question
id := args["id"] id := args["id"]
if err := DB().Preload("Answers", func(db *gorm.DB) *gorm.DB { if err := db._db.Preload("Answers", func(db *gorm.DB) *gorm.DB {
return db.Order("answers.correct DESC") return db.Order("answers.correct DESC")
}).Preload("Contest").First(&question, id).Error; err != nil { }).Preload("Contest").First(&question, id).Error; err != nil {
return nil, err return nil, err
@ -68,24 +68,24 @@ func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.R
return &question, nil return &question, nil
} }
func (q *Question) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (q *Question) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var questions []*Question var questions []*Question
if err := DB().Preload("Contest").Order("created_at").Find(&questions).Error; err != nil { if err := db._db.Preload("Contest").Order("created_at").Find(&questions).Error; err != nil {
return nil, err return nil, err
} }
return questions, nil return questions, nil
} }
func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (q *Question) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := q.Read(args, w, r) result, err := q.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
question := result.(*Question) question := result.(*Question)
if err := DB().Find(&question.AllContests).Error; err != nil { if err := db._db.Find(&question.AllContests).Error; err != nil {
return nil, err return nil, err
} }
@ -94,7 +94,7 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
return question, nil return question, nil
} else { } else {
question, err := q.Read(args, w, r) question, err := q.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -102,11 +102,11 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveQuestion(question) _, err = SaveQuestion(db, question)
if err != nil { if err != nil {
return nil, err return nil, err
} }
question, err = q.Read(args, w, r) question, err = q.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -114,26 +114,26 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
} }
} }
func (model *Question) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Question) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
question, err := model.Read(args, w, r) question, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(question.(*Question)).Error; err != nil { if err := db._db.Unscoped().Delete(question.(*Question)).Error; err != nil {
return nil, err return nil, err
} }
return question.(*Question), nil return question.(*Question), nil
} }
func CreateQuestion(question *Question) (*Question, error) { func CreateQuestion(db *Database, question *Question) (*Question, error) {
if err := DB().Create(question).Error; err != nil { if err := db._db.Create(question).Error; err != nil {
return nil, err return nil, err
} }
return question, nil return question, nil
} }
func SaveQuestion(question interface{}) (interface{}, error) { func SaveQuestion(db *Database, question interface{}) (interface{}, error) {
if err := DB().Omit("Answers", "Contest").Save(question).Error; err != nil { if err := db._db.Omit("Answers", "Contest").Save(question).Error; err != nil {
return nil, err return nil, err
} }
return question, nil return question, nil

View file

@ -60,10 +60,10 @@ type Region struct {
Name string Name string
} }
func CreateRegions() { func CreateRegions(db *Database) {
for _, name := range regions { for _, name := range regions {
var region Region var region Region
if err := currDB.FirstOrCreate(&region, Region{Name: name}).Error; err != nil { if err := db._db.FirstOrCreate(&region, Region{Name: name}).Error; err != nil {
panic(err) panic(err)
} }
} }
@ -75,10 +75,10 @@ func (model *Region) String() string {
return model.Name return model.Name
} }
func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Region) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
region := new(Region) region := new(Region)
// if err := DB().Find(&region.AllContests).Error; err != nil { // if err := db._db.Find(&region.AllContests).Error; err != nil {
// return nil, err // return nil, err
// } // }
return region, nil return region, nil
@ -88,7 +88,7 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht
if err != nil { if err != nil {
return nil, err return nil, err
} }
region, err = CreateRegion(region) region, err = CreateRegion(db, region)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -96,36 +96,36 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht
} }
} }
func (model *Region) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Region) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var region Region var region Region
id := args["id"] id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&region, id).Error; err != nil { if err := db._db. /*.Preload("Something")*/ First(&region, id).Error; err != nil {
return nil, err return nil, err
} }
return &region, nil return &region, nil
} }
func (model *Region) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Region) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var regions []*Region var regions []*Region
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&regions).Error; err != nil { if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&regions).Error; err != nil {
return nil, err return nil, err
} }
return regions, nil return regions, nil
} }
func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Region) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
region := result.(*Region) region := result.(*Region)
// if err := DB().Find(&region.AllElements).Error; err != nil { // if err := db._db.Find(&region.AllElements).Error; err != nil {
// return nil, err // return nil, err
// } // }
@ -134,7 +134,7 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
return region, nil return region, nil
} else { } else {
region, err := model.Read(args, w, r) region, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -142,11 +142,11 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveRegion(region) _, err = SaveRegion(db, region)
if err != nil { if err != nil {
return nil, err return nil, err
} }
region, err = model.Read(args, w, r) region, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -154,26 +154,26 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
} }
} }
func (model *Region) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Region) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
region, err := model.Read(args, w, r) region, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(region.(*Region)).Error; err != nil { if err := db._db.Unscoped().Delete(region.(*Region)).Error; err != nil {
return nil, err return nil, err
} }
return region.(*Region), nil return region.(*Region), nil
} }
func CreateRegion(region *Region) (*Region, error) { func CreateRegion(db *Database, region *Region) (*Region, error) {
if err := DB().Create(region).Error; err != nil { if err := db._db.Create(region).Error; err != nil {
return nil, err return nil, err
} }
return region, nil return region, nil
} }
func SaveRegion(region interface{}) (interface{}, error) { func SaveRegion(db *Database, region interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(region).Error; err != nil { if err := db._db. /*.Omit("Something")*/ Save(region).Error; err != nil {
return nil, err return nil, err
} }
return region, nil return region, nil

View file

@ -77,13 +77,13 @@ func (model *Response) BeforeSave(tx *gorm.DB) error {
return nil return nil
} }
func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Response) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
response := new(Response) response := new(Response)
contestID := r.URL.Query().Get("contest_id") contestID := r.URL.Query().Get("contest_id")
if err := DB().Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil { if err := db._db.Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil {
return nil, err return nil, err
} }
return response, nil return response, nil
@ -96,7 +96,7 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *
response.UserModifierCreate = NewUserModifierCreate(r) response.UserModifierCreate = NewUserModifierCreate(r)
response, err = CreateResponse(response) response, err = CreateResponse(db, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -104,12 +104,12 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *
} }
} }
func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Response) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var response Response var response Response
id := args["id"] id := args["id"]
if err := DB().Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil { if err := db._db.Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil {
return nil, err return nil, err
} }
@ -126,7 +126,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
for _, sr := range response.SingleResponses { for _, sr := range response.SingleResponses {
var answer Answer var answer Answer
if err := DB().First(&answer, sr.AnswerID).Error; err != nil { if err := db._db.First(&answer, sr.AnswerID).Error; err != nil {
return nil, err return nil, err
} }
@ -151,7 +151,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
// Fetch questions in the given order // Fetch questions in the given order
field := fmt.Sprintf("FIELD(id,%s)", strings.Replace(response.QuestionsOrder, " ", ",", -1)) field := fmt.Sprintf("FIELD(id,%s)", strings.Replace(response.QuestionsOrder, " ", ",", -1))
if err := DB().Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB { if err := db._db.Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB {
return db.Order("RAND()") return db.Order("RAND()")
}).Find(&response.Questions).Error; err != nil { }).Find(&response.Questions).Error; err != nil {
return nil, err return nil, err
@ -160,17 +160,17 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
return &response, nil return &response, nil
} }
func (model *Response) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Response) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var responses []*Response var responses []*Response
if err := DB().Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil { if err := db._db.Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil {
return nil, err return nil, err
} }
return responses, nil return responses, nil
} }
func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Response) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -184,10 +184,10 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
return nil, errors.OutOfTime return nil, errors.OutOfTime
} }
if response.StartTime.IsZero() { if response.StartTime.IsZero() {
if err := DB().Model(&response).Update("start_time", time.Now()).Error; err != nil { if err := db._db.Model(&response).Update("start_time", time.Now()).Error; err != nil {
return nil, err return nil, err
} }
if err := DB().Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil { if err := db._db.Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil {
return nil, err return nil, err
} }
log.Println("StartTime/EndTime", response.StartTime, response.EndTime) log.Println("StartTime/EndTime", response.StartTime, response.EndTime)
@ -197,7 +197,7 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
return response, nil return response, nil
} else { } else {
response, err := model.Read(args, w, r) response, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,11 +214,11 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
response.(*Response).UserModifierUpdate = NewUserModifierUpdate(r) response.(*Response).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveResponse(response) _, err = SaveResponse(db, response)
if err != nil { if err != nil {
return nil, err return nil, err
} }
response, err = model.Read(args, w, r) response, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -226,26 +226,26 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
} }
} }
func (model *Response) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *Response) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
response, err := model.Read(args, w, r) response, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(response.(*Response)).Error; err != nil { if err := db._db.Unscoped().Delete(response.(*Response)).Error; err != nil {
return nil, err return nil, err
} }
return response.(*Response), nil return response.(*Response), nil
} }
func CreateResponse(response *Response) (*Response, error) { func CreateResponse(db *Database, response *Response) (*Response, error) {
if err := DB().Create(response).Error; err != nil { if err := db._db.Create(response).Error; err != nil {
return nil, err return nil, err
} }
return response, nil return response, nil
} }
func SaveResponse(response interface{}) (interface{}, error) { func SaveResponse(db *Database, response interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(response).Error; err != nil { if err := db._db. /*.Omit("Something")*/ Save(response).Error; err != nil {
return nil, err return nil, err
} }
return response, nil return response, nil

View file

@ -70,9 +70,9 @@ func (model *School) To() string {
return model.Email return model.Email
} }
func (model *School) exists() (*User, error) { func (model *School) exists(db *Database) (*User, error) {
var user User var user User
if err := DB().First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound { if err := db._db.First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
return nil, err return nil, err
} else if err == gorm.ErrRecordNotFound { } else if err == gorm.ErrRecordNotFound {
return nil, nil return nil, nil
@ -118,11 +118,11 @@ func (model *School) AfterCreate(tx *gorm.DB) error {
return nil return nil
} }
func (model *School) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *School) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
school := new(School) school := new(School)
if err := DB().Find(&school.AllRegions).Error; err != nil { if err := db._db.Find(&school.AllRegions).Error; err != nil {
return nil, err return nil, err
} }
@ -135,8 +135,8 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
} }
// Check if this user already exists in the users table. // Check if this user already exists in the users table.
if user, err := school.exists(); err == nil && user != nil { if user, err := school.exists(db); err == nil && user != nil {
if err := DB().Where("user_id = ?", user.ID).Find(&school).Error; err != nil { if err := db._db.Where("user_id = ?", user.ID).Find(&school).Error; err != nil {
return nil, err return nil, err
} }
// err := setFlashMessage(w, r, "schoolExists") // err := setFlashMessage(w, r, "schoolExists")
@ -150,7 +150,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
school.UserModifierCreate = NewUserModifierCreate(r) school.UserModifierCreate = NewUserModifierCreate(r)
school, err = CreateSchool(school) school, err = CreateSchool(db, school)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,7 +158,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
} }
} }
func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *School) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var school School var school School
id := args["id"] id := args["id"]
@ -168,31 +168,31 @@ func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http
return nil, errors.NotAuthorized return nil, errors.NotAuthorized
} }
if err := DB().Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil { if err := db._db.Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil {
return nil, err return nil, err
} }
return &school, nil return &school, nil
} }
func (model *School) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *School) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var schools []*School var schools []*School
if err := DB().Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil { if err := db._db.Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil {
return nil, err return nil, err
} }
return schools, nil return schools, nil
} }
func (model *School) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *School) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
school := result.(*School) school := result.(*School)
if err := DB().Find(&school.AllRegions).Error; err != nil { if err := db._db.Find(&school.AllRegions).Error; err != nil {
return nil, err return nil, err
} }
@ -201,7 +201,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
return school, nil return school, nil
} else { } else {
school, err := model.Read(args, w, r) school, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -211,7 +211,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
} }
// Check if the modified school code belong to an existing school. // Check if the modified school code belong to an existing school.
if user, err := school.(*School).exists(); err == nil && user != nil { if user, err := school.(*School).exists(db); err == nil && user != nil {
if user.ID != school.(*School).UserID { if user.ID != school.(*School).UserID {
// err := setFlashMessage(w, r, "schoolExists") // err := setFlashMessage(w, r, "schoolExists")
// if err != nil { // if err != nil {
@ -225,11 +225,11 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
school.(*School).UserModifierUpdate = NewUserModifierUpdate(r) school.(*School).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveSchool(school) _, err = SaveSchool(db, school)
if err != nil { if err != nil {
return nil, err return nil, err
} }
school, err = model.Read(args, w, r) school, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -237,35 +237,35 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
} }
} }
func (model *School) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *School) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
school, err := model.Read(args, w, r) school, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(school.(*School)).Error; err != nil { if err := db._db.Unscoped().Delete(school.(*School)).Error; err != nil {
return nil, err return nil, err
} }
return school.(*School), nil return school.(*School), nil
} }
func CreateSchool(school *School) (*School, error) { func CreateSchool(db *Database, school *School) (*School, error) {
if err := DB().Create(school).Error; err != nil { if err := db._db.Create(school).Error; err != nil {
return nil, err return nil, err
} }
return school, nil return school, nil
} }
func SaveSchool(school interface{}) (interface{}, error) { func SaveSchool(db *Database, school interface{}) (interface{}, error) {
if err := DB().Omit("Region").Save(school).Error; err != nil { if err := db._db.Omit("Region").Save(school).Error; err != nil {
return nil, err return nil, err
} }
return school, nil return school, nil
} }
func (model *School) HasCategory(participant *Participant) (bool, error) { func (model *School) HasCategory(db *Database, participant *Participant) (bool, error) {
var participants []*Participant var participants []*Participant
if err := DB().Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil { if err := db._db.Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil {
return false, err return false, err
} }
return len(participants) > 0, nil return len(participants) > 0, nil

View file

@ -47,7 +47,7 @@ func (model *User) BeforeSave(tx *gorm.DB) error {
return nil return nil
} }
func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *User) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
user := new(User) user := new(User)
return user, nil return user, nil
@ -57,7 +57,7 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err = CreateUser(user) user, err = CreateUser(db, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -65,36 +65,36 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http
} }
} }
func (model *User) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *User) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var user User var user User
id := args["id"] id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&user, id).Error; err != nil { if err := db._db. /*.Preload("Something")*/ First(&user, id).Error; err != nil {
return nil, err return nil, err
} }
return &user, nil return &user, nil
} }
func (model *User) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *User) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var users []*User var users []*User
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil { if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil {
return nil, err return nil, err
} }
return users, nil return users, nil
} }
func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *User) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" { if r.Method == "GET" {
result, err := model.Read(args, w, r) result, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user := result.(*User) user := result.(*User)
// if err := DB().Find(&user.AllElements).Error; err != nil { // if err := db._db.Find(&user.AllElements).Error; err != nil {
// return nil, err // return nil, err
// } // }
@ -103,7 +103,7 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
return user, nil return user, nil
} else { } else {
user, err := model.Read(args, w, r) user, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,11 +111,11 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = SaveUser(user) _, err = SaveUser(db, user)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err = model.Read(args, w, r) user, err = model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -123,26 +123,26 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
} }
} }
func (model *User) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { func (model *User) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
user, err := model.Read(args, w, r) user, err := model.Read(db, args, w, r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := DB().Unscoped().Delete(user.(*User)).Error; err != nil { if err := db._db.Unscoped().Delete(user.(*User)).Error; err != nil {
return nil, err return nil, err
} }
return user.(*User), nil return user.(*User), nil
} }
func CreateUser(user *User) (*User, error) { func CreateUser(db *Database, user *User) (*User, error) {
if err := DB().Create(user).Error; err != nil { if err := db._db.Create(user).Error; err != nil {
return nil, err return nil, err
} }
return user, nil return user, nil
} }
func SaveUser(user interface{}) (interface{}, error) { func SaveUser(db *Database, user interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(user).Error; err != nil { if err := db._db. /*.Omit("Something")*/ Save(user).Error; err != nil {
return nil, err return nil, err
} }
return user, nil return user, nil

View file

@ -39,20 +39,20 @@ func NewUserModifierCreate(r *http.Request) *UserModifierCreate {
} }
} }
func (um *UserModifierCreate) CreatedBy() (*UserAction, error) { func (um *UserModifierCreate) CreatedBy(db *Database) (*UserAction, error) {
action := new(UserAction) action := new(UserAction)
switch um.CreatorRole { switch um.CreatorRole {
case "participant": case "participant":
var participant Participant var participant Participant
if err := DB().Preload("User").First(&participant, um.CreatorID).Error; err != nil { if err := db._db.Preload("User").First(&participant, um.CreatorID).Error; err != nil {
return nil, err return nil, err
} }
action.User = *participant.User action.User = *participant.User
case "school": case "school":
var school School var school School
if err := DB().Preload("User").First(&school, um.CreatorID).Error; err != nil { if err := db._db.Preload("User").First(&school, um.CreatorID).Error; err != nil {
return nil, err return nil, err
} }
action.User = *school.User action.User = *school.User
@ -80,19 +80,19 @@ func NewUserModifierUpdate(r *http.Request) *UserModifierUpdate {
} }
} }
func (um *UserModifierUpdate) UpdatedBy() (*UserAction, error) { func (um *UserModifierUpdate) UpdatedBy(db *Database) (*UserAction, error) {
action := new(UserAction) action := new(UserAction)
switch um.UpdaterRole { switch um.UpdaterRole {
case "participant": case "participant":
var participant Participant var participant Participant
if err := DB().Preload("User").First(&participant, um.UpdaterID).Error; err != nil { if err := db._db.Preload("User").First(&participant, um.UpdaterID).Error; err != nil {
return nil, err return nil, err
} }
action.User = *participant.User action.User = *participant.User
case "school": case "school":
var school School var school School
if err := DB().Preload("User").First(&school, um.UpdaterID).Error; err != nil { if err := db._db.Preload("User").First(&school, um.UpdaterID).Error; err != nil {
return nil, err return nil, err
} }
action.User = *school.User action.User = *school.User