Encapsulate orm behaviour in the Database struct

This commit is contained in:
Andrea Fazzi 2020-01-15 11:27:00 +01:00
parent 2ad7e3d180
commit 3bb8dde670
14 changed files with 308 additions and 304 deletions

View file

@ -35,12 +35,13 @@ type PathPattern struct {
type Handlers struct {
Config *config.ConfigT
Database *orm.Database
Models []interface{}
Login func(store *sessions.CookieStore, signingKey []byte) http.Handler
Login func(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler
Logout func(store *sessions.CookieStore) http.Handler
Home func() http.Handler
GetToken func(signingKey []byte) http.Handler
GetToken func(db *orm.Database, signingKey []byte) http.Handler
Static func() http.Handler
Recover func(next http.Handler) http.Handler
@ -164,10 +165,11 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) {
}
func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
func NewHandlers(config *config.ConfigT, db *orm.Database, models []interface{}) *Handlers {
handlers := new(Handlers)
handlers.Config = config
handlers.Database = db
handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey))
@ -197,12 +199,12 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
// Authentication
r.Handle("/login", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
r.Handle("/login", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
r.Handle("/logout", handlers.Logout(handlers.CookieStore))
// School subscription
r.Handle("/subscribe", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
r.Handle("/subscribe", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
// Home
@ -216,7 +218,7 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
// Token handling
r.Handle("/get_token", handlers.GetToken([]byte(config.Keys.JWTSigningKey)))
r.Handle("/get_token", handlers.GetToken(handlers.Database, []byte(config.Keys.JWTSigningKey)))
// Static file server
@ -293,7 +295,7 @@ func hasPermission(role, path string) bool {
func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
format := r.URL.Query().Get("format")
getFn, err := orm.GetFunc(pattern.Path(model))
getFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil {
log.Println("Error:", err)
respondWithError(w, r, err)
@ -306,7 +308,7 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat
renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else {
data, err := getFn(mux.Vars(r), w, r)
data, err := getFn(h.Database, mux.Vars(r), w, r)
if err != nil {
renderer.Render[format](w, r, err)
} else {
@ -317,14 +319,14 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat
}
func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
func (h *Handlers) post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
var (
data interface{}
err error
)
respFormat := renderer.GetContentFormat(r)
postFn, err := orm.GetFunc(pattern.Path(model))
postFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil {
respondWithError(w, r, err)
@ -335,7 +337,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt
if !hasPermission(role, pattern.Path(model)) {
renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else {
data, err = postFn(mux.Vars(r), w, r)
data, err = postFn(h.Database, mux.Vars(r), w, r)
if err != nil {
respondWithError(w, r, err)
} else if pattern.RedirectPattern != "" {
@ -353,7 +355,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt
}
func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
func (h *Handlers) delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
var data interface{}
respFormat := renderer.GetContentFormat(r)
@ -364,11 +366,11 @@ func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPa
if !hasPermission(role, pattern.Path(model)) {
renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else {
postFn, err := orm.GetFunc(pattern.Path(model))
postFn, err := h.Database.GetFunc(pattern.Path(model))
if err != nil {
renderer.Render[r.URL.Query().Get("format")](w, r, err)
}
data, err = postFn(mux.Vars(r), w, r)
data, err = postFn(h.Database, mux.Vars(r), w, r)
if err != nil {
renderer.Render["html"](w, r, err)
} else if pattern.RedirectPattern != "" {
@ -402,10 +404,10 @@ func (h *Handlers) modelHandler(model string, pattern PathPattern) http.Handler
h.get(w, r, model, pattern)
case "POST":
post(w, r, model, pattern)
h.post(w, r, model, pattern)
case "DELETE":
delete(w, r, model, pattern)
h.delete(w, r, model, pattern)
}
}

View file

@ -59,14 +59,14 @@ func DefaultLogoutHandler(store *sessions.CookieStore) http.Handler {
return http.HandlerFunc(fn)
}
func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Handler {
func DefaultLoginHandler(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
renderer.Render["html"](w, r, nil, r.URL.Query())
}
if r.Method == "POST" {
r.ParseForm()
token, err := getToken(r.FormValue("username"), r.FormValue("password"), signingKey)
token, err := getToken(db, r.FormValue("username"), r.FormValue("password"), signingKey)
if err != nil {
http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther)
} else {
@ -87,9 +87,7 @@ func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Ha
// FIXME: This is an hack for fast prototyping: users should have
// their own table on DB.
func checkCredential(username string, password string) (*UserToken, error) {
var user orm.User
func checkCredential(db *orm.Database, username string, password string) (*UserToken, error) {
// Check if user is the administrator
if username == config.Config.Admin.Username && password == config.Config.Admin.Password {
@ -104,31 +102,31 @@ func checkCredential(username string, password string) (*UserToken, error) {
var token *UserToken
if err := orm.DB().Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil {
user, err := db.GetUser(username, password)
if err != nil {
return nil, errors.New("Authentication failed!")
} else {
}
switch user.Role {
case "participant":
var participant orm.Participant
if err := orm.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil {
if err := db.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
token = &UserToken{username, false, user.Role, strconv.Itoa(int(participant.ID))}
case "school":
var school orm.School
if err := orm.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil {
if err := db.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
token = &UserToken{username, false, user.Role, strconv.Itoa(int(school.ID))}
}
}
return token, nil
}
// FIXME: Refactor the functions above please!!!
func getToken(username string, password string, signingKey []byte) ([]byte, error) {
user, err := checkCredential(username, password)
func getToken(db *orm.Database, username string, password string, signingKey []byte) ([]byte, error) {
user, err := checkCredential(db, username, password)
if err != nil {
return nil, err
}
@ -154,11 +152,11 @@ func getToken(username string, password string, signingKey []byte) ([]byte, erro
}
// FIXME: Refactor the functions above please!!!
func DefaultGetTokenHandler(signingKey []byte) http.Handler {
func DefaultGetTokenHandler(db *orm.Database, signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
username, password, _ := r.BasicAuth()
token, err := getToken(username, password, signingKey)
token, err := getToken(db, username, password, signingKey)
if err != nil {
panic(err)
}

22
main.go
View file

@ -2,14 +2,12 @@ package main
import (
"flag"
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/gorilla/handlers"
"github.com/jinzhu/gorm"
"git.andreafazzi.eu/andrea/oef/config"
oef_handlers "git.andreafazzi.eu/andrea/oef/handlers"
@ -22,7 +20,7 @@ const (
)
var (
db *gorm.DB
db *orm.Database
err error
models = []interface{}{
@ -54,7 +52,7 @@ func main() {
wait := true
for wait && count > 0 {
db, err = orm.New(fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options))
db, err = orm.NewDatabase(config.Config, models)
if err != nil {
count--
log.Println(err)
@ -65,26 +63,22 @@ func main() {
wait = false
}
orm.Use(db)
// REMOVE
// orm.Use(db)
if config.Config.Orm.AutoMigrate {
log.Print("Automigrating...")
orm.AutoMigrate(models...)
db.AutoMigrate()
}
log.Println("Eventually write categories on DB...")
orm.CreateCategories()
orm.CreateCategories(db)
log.Println("Eventually write regions on DB...")
orm.CreateRegions()
log.Println("Map models <-> handlers")
if err := orm.MapHandlers(models); err != nil {
panic(err)
}
orm.CreateRegions(db)
log.Println("OEF is listening to port 3000...")
if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, models).Router)); err != nil {
if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, db, models).Router)); err != nil {
panic(err)
}

View file

@ -26,10 +26,10 @@ func (a *Answer) String() string {
return a.Text
}
func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (a *Answer) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
answer := new(Answer)
if err := DB().Find(&answer.AllQuestions).Error; err != nil {
if err := db._db.Find(&answer.AllQuestions).Error; err != nil {
return nil, err
}
@ -40,7 +40,7 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R
if err != nil {
return nil, err
}
answer, err = CreateAnswer(answer)
answer, err = CreateAnswer(db, answer)
if err != nil {
return nil, err
}
@ -48,36 +48,36 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R
}
}
func (a *Answer) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (a *Answer) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var answer Answer
id := args["id"]
if err := DB().Preload("Question").First(&answer, id).Error; err != nil {
if err := db._db.Preload("Question").First(&answer, id).Error; err != nil {
return nil, err
}
return &answer, nil
}
func (a *Answer) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (a *Answer) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var answers []*Answer
if err := DB().Preload("Question").Order("created_at").Find(&answers).Error; err != nil {
if err := db._db.Preload("Question").Order("created_at").Find(&answers).Error; err != nil {
return nil, err
}
return answers, nil
}
func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (a *Answer) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := a.Read(args, w, r)
result, err := a.Read(db, args, w, r)
if err != nil {
return nil, err
}
answer := result.(*Answer)
if err := DB().Find(&answer.AllQuestions).Error; err != nil {
if err := db._db.Find(&answer.AllQuestions).Error; err != nil {
return nil, err
}
@ -86,7 +86,7 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
return answer, nil
} else {
answer, err := a.Read(args, w, r)
answer, err := a.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -98,11 +98,11 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
if err != nil {
return nil, err
}
_, err = SaveAnswer(answer)
_, err = SaveAnswer(db, answer)
if err != nil {
return nil, err
}
answer, err = a.Read(args, w, r)
answer, err = a.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -110,26 +110,26 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R
}
}
func (model *Answer) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
answer, err := model.Read(args, w, r)
func (model *Answer) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
answer, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(answer.(*Answer)).Error; err != nil {
if err := db._db.Unscoped().Delete(answer.(*Answer)).Error; err != nil {
return nil, err
}
return answer.(*Answer), nil
}
func CreateAnswer(answer *Answer) (*Answer, error) {
if err := DB().Create(answer).Error; err != nil {
func CreateAnswer(db *Database, answer *Answer) (*Answer, error) {
if err := db._db.Create(answer).Error; err != nil {
return nil, err
}
return answer, nil
}
func SaveAnswer(answer interface{}) (interface{}, error) {
if err := DB().Omit("Answers").Save(answer).Error; err != nil {
func SaveAnswer(db *Database, answer interface{}) (interface{}, error) {
if err := db._db.Omit("Answers").Save(answer).Error; err != nil {
return nil, err
}
return answer, nil

View file

@ -23,10 +23,10 @@ func (model *Category) String() string {
return model.Name
}
func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Category) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
category := new(Category)
// if err := DB().Find(&category.AllContests).Error; err != nil {
// if err := db._db.Find(&category.AllContests).Error; err != nil {
// return nil, err
// }
return category, nil
@ -36,7 +36,7 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *
if err != nil {
return nil, err
}
category, err = CreateCategory(category)
category, err = CreateCategory(db, category)
if err != nil {
return nil, err
}
@ -44,36 +44,36 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *
}
}
func (model *Category) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Category) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var category Category
id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&category, id).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ First(&category, id).Error; err != nil {
return nil, err
}
return &category, nil
}
func (model *Category) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Category) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var categories []*Category
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil {
return nil, err
}
return categories, nil
}
func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Category) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
category := result.(*Category)
// if err := DB().Find(&category.AllElements).Error; err != nil {
// if err := db._db.Find(&category.AllElements).Error; err != nil {
// return nil, err
// }
@ -82,7 +82,7 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
return category, nil
} else {
category, err := model.Read(args, w, r)
category, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -90,11 +90,11 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
if err != nil {
return nil, err
}
_, err = SaveCategory(category)
_, err = SaveCategory(db, category)
if err != nil {
return nil, err
}
category, err = model.Read(args, w, r)
category, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -102,26 +102,26 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *
}
}
func (model *Category) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
category, err := model.Read(args, w, r)
func (model *Category) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
category, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(category.(*Category)).Error; err != nil {
if err := db._db.Unscoped().Delete(category.(*Category)).Error; err != nil {
return nil, err
}
return category.(*Category), nil
}
func CreateCategory(category *Category) (*Category, error) {
if err := DB().Create(category).Error; err != nil {
func CreateCategory(db *Database, category *Category) (*Category, error) {
if err := db._db.Create(category).Error; err != nil {
return nil, err
}
return category, nil
}
func SaveCategory(category interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(category).Error; err != nil {
func SaveCategory(db *Database, category interface{}) (interface{}, error) {
if err := db._db. /*.Omit("Something")*/ Save(category).Error; err != nil {
return nil, err
}
return category, nil

View file

@ -39,7 +39,7 @@ func (c *Contest) String() string {
return c.Name
}
func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (c *Contest) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
return nil, nil
} else {
@ -68,7 +68,7 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.
if err != nil {
return nil, err
}
contest, err = CreateContest(contest)
contest, err = CreateContest(db, contest)
if err != nil {
return nil, err
}
@ -76,25 +76,25 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.
}
}
func (c *Contest) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (c *Contest) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var contest Contest
id := args["id"]
if err := DB().Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil {
if err := db._db.Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil {
return nil, err
}
return &contest, nil
}
func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (c *Contest) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var contests []*Contest
claims := r.Context().Value("user").(*jwt.Token).Claims.(jwt.MapClaims)
if claims["admin"].(bool) {
if err := DB().Order("created_at").Find(&contests).Error; err != nil {
if err := db._db.Order("created_at").Find(&contests).Error; err != nil {
return nil, err
} else {
return contests, nil
@ -103,16 +103,16 @@ func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http
participant := &Participant{}
if err := DB().Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil {
if err := db._db.Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil {
return nil, err
} else {
return participant.Contests, nil
}
}
func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (c *Contest) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := c.Read(args, w, r)
result, err := c.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -120,7 +120,7 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
return contest, nil
} else {
contest, err := c.Read(args, w, r)
contest, err := c.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -148,11 +148,11 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
if err != nil {
return nil, err
}
_, err = SaveContest(contest)
_, err = SaveContest(db, contest)
if err != nil {
return nil, err
}
contest, err = c.Read(args, w, r)
contest, err = c.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -160,26 +160,26 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.
}
}
func (model *Contest) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
contest, err := model.Read(args, w, r)
func (model *Contest) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
contest, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(contest.(*Contest)).Error; err != nil {
if err := db._db.Unscoped().Delete(contest.(*Contest)).Error; err != nil {
return nil, err
}
return contest.(*Contest), nil
}
func CreateContest(contest *Contest) (*Contest, error) {
if err := DB().Create(contest).Error; err != nil {
func CreateContest(db *Database, contest *Contest) (*Contest, error) {
if err := db._db.Create(contest).Error; err != nil {
return nil, err
}
return contest, nil
}
func SaveContest(contest interface{}) (interface{}, error) {
if err := DB().Omit("Contests").Save(contest).Error; err != nil {
func SaveContest(db *Database, contest interface{}) (interface{}, error) {
if err := db._db.Omit("Contests").Save(contest).Error; err != nil {
return nil, err
}
return contest, nil
@ -189,13 +189,13 @@ func (c *Contest) isAlwaysActive() bool {
return c.StartTime.IsZero() || c.EndTime.IsZero() || c.Duration == 0
}
func (c *Contest) generateQuestionsOrder() (string, error) {
func (c *Contest) generateQuestionsOrder(db *gorm.DB) (string, error) {
var (
order []string
questions []*Question
)
if err := DB().Find(&questions, Question{ContestID: c.ID}).Error; err != nil {
if err := db.Find(&questions, Question{ContestID: c.ID}).Error; err != nil {
return "", err
}

View file

@ -1,6 +1,7 @@
package orm
import (
"errors"
"fmt"
"net/http"
"path"
@ -18,51 +19,60 @@ type IDer interface {
GetID() uint
}
type GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
type GetFn func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
type Database struct {
Config *config.ConfigT
models []interface{}
_db *gorm.DB
fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
fns map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
}
func NewDatabase(config *config.ConfigT) (*Database, error) {
func NewDatabase(config *config.ConfigT, models []interface{}) (*Database, error) {
var err error
db := new(Database)
db.fns = make(db*gorm.DB, map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
db._db, err := gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options))
db.fns = make(map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
db.mapHandlers(models)
db._db, err = gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Orm.Connection, config.Orm.Options))
if err != nil {
return nil, err
}
return db
return db, err
}
func (db *Database) AutoMigrate(models ...interface{}) {
if err := db._db.AutoMigrate(models...).Error; err != nil {
func (db *Database) AutoMigrate() {
if err := db._db.AutoMigrate(db.models...).Error; err != nil {
panic(err)
}
}
func CreateCategories() {
func (db *Database) GetUser(username, password string) (*User, error) {
var user *User
if err := db._db.Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil {
return nil, errors.New("Authentication failed!")
}
return user, nil
}
func (db *Database) DB() *gorm.DB {
return db._db
}
func CreateCategories(db *Database) {
for _, name := range categories {
var category Category
if err := currDB.FirstOrCreate(&category, Category{Name: name}).Error; err != nil {
if err := db._db.FirstOrCreate(&category, Category{Name: name}).Error; err != nil {
panic(err)
}
}
}
func Use(db *gorm.DB) {
currDB = db
}
func DB() *gorm.DB {
return currDB
}
func MapHandlers(models []interface{}) error {
func (db *Database) mapHandlers(models []interface{}) error {
for _, model := range models {
name := inflection.Plural(strings.ToLower(ModelName(model)))
for p, action := range map[string]string{
@ -80,7 +90,7 @@ func MapHandlers(models []interface{}) error {
if strings.HasSuffix(p, "/") {
joinedPath += "/"
}
fns[joinedPath] = method.Interface().(func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error))
db.fns[joinedPath] = method.Interface().(func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error))
}
}
@ -88,19 +98,19 @@ func MapHandlers(models []interface{}) error {
return nil
}
func GetFunc(path string) (GetFn, error) {
fn, ok := fns[path]
func (db *Database) GetFunc(path string) (GetFn, error) {
fn, ok := db.fns[path]
if !ok {
return nil, fmt.Errorf("Can't map path %s to any model methods.", path)
}
return fn, nil
}
func GetNothing(args map[string]string) (interface{}, error) {
func GetNothing(db *Database, args map[string]string) (interface{}, error) {
return nil, nil
}
func PostNothing(args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) {
func PostNothing(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) {
return nil, nil
}

View file

@ -92,9 +92,9 @@ func (model *Participant) String() string {
// return nil
// }
func (model *Participant) exists() (*User, error) {
func (model *Participant) exists(db *Database) (*User, error) {
var user User
if err := DB().First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
if err := db._db.First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
return nil, err
} else if err == gorm.ErrRecordNotFound {
return nil, nil
@ -129,7 +129,7 @@ func (model *Participant) AfterSave(tx *gorm.DB) error {
return err
}
order, err := contest.generateQuestionsOrder()
order, err := contest.generateQuestionsOrder(tx)
if err != nil {
return err
}
@ -148,21 +148,21 @@ func (model *Participant) AfterDelete(tx *gorm.DB) error {
return nil
}
func (model *Participant) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Participant) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
participant := new(Participant)
if isSchool(r) {
if err := DB().Find(&participant.AllCategories).Error; err != nil {
if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err
}
} else {
if err := DB().Find(&participant.AllCategories).Error; err != nil {
if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err
}
if err := DB().Find(&participant.AllContests).Error; err != nil {
if err := db._db.Find(&participant.AllContests).Error; err != nil {
return nil, err
}
if err := DB().Find(&participant.AllSchools).Error; err != nil {
if err := db._db.Find(&participant.AllSchools).Error; err != nil {
return nil, err
}
}
@ -175,8 +175,8 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
}
// Check if participant exists
if user, err := participant.exists(); err == nil && user != nil {
if err := DB().Where("user_id = ?", user.ID).Find(&participant).Error; err != nil {
if user, err := participant.exists(db); err == nil && user != nil {
if err := db._db.Where("user_id = ?", user.ID).Find(&participant).Error; err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "participantExists")
@ -199,10 +199,10 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
// Check if a participant of the same category exists
var school School
if err := DB().First(&school, participant.SchoolID).Error; err != nil {
if err := db._db.First(&school, participant.SchoolID).Error; err != nil {
return nil, err
}
hasCategory, err := school.HasCategory(participant)
hasCategory, err := school.HasCategory(db, participant)
if err != nil {
return nil, err
}
@ -212,20 +212,20 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
participant.UserModifierCreate = NewUserModifierCreate(r)
participant, err = CreateParticipant(participant)
participant, err = CreateParticipant(db, participant)
if err != nil {
return nil, err
}
var response Response
if err := DB().First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil {
if err := db._db.First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil {
return nil, err
}
response.UserModifierCreate = NewUserModifierCreate(r)
if err := DB().Save(&response).Error; err != nil {
if err := db._db.Save(&response).Error; err != nil {
return nil, err
}
@ -233,14 +233,14 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
}
}
func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Participant) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var participant Participant
id := args["id"]
// School user can access to its participants only!
if isSchool(r) {
if err := DB().Preload("School").First(&participant, id).Error; err != nil {
if err := db._db.Preload("School").First(&participant, id).Error; err != nil {
return nil, err
}
@ -249,12 +249,12 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r
return nil, errors.NotAuthorized
}
if err := DB().Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil {
if err := db._db.Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil {
return nil, err
}
} else {
if err := DB().Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil {
if err := db._db.Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil {
return nil, err
}
}
@ -262,7 +262,7 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r
return &participant, nil
}
func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Participant) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var participants []*Participant
// School user can access to its participants only!
@ -271,20 +271,20 @@ func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter,
if err != nil {
return nil, err
}
if err := DB().Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil {
if err := db._db.Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil {
return nil, err
}
} else {
if err := DB().Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil {
if err := db._db.Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil {
return nil, err
}
}
return participants, nil
}
func (model *Participant) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Participant) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -292,21 +292,21 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
participant := result.(*Participant)
if isSchool(r) {
if err := DB().Find(&participant.AllCategories).Error; err != nil {
if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err
}
participant.SelectedCategory = make(map[uint]string)
participant.SelectedCategory[participant.CategoryID] = "selected"
} else {
if err := DB().Find(&participant.AllCategories).Error; err != nil {
if err := db._db.Find(&participant.AllCategories).Error; err != nil {
return nil, err
}
participant.SelectedCategory = make(map[uint]string)
participant.SelectedCategory[participant.CategoryID] = "selected"
if err := DB().Find(&participant.AllContests).Error; err != nil {
if err := db._db.Find(&participant.AllContests).Error; err != nil {
return nil, err
}
@ -315,7 +315,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
participant.SelectedContest[c.ID] = "selected"
}
if err := DB().Find(&participant.AllSchools).Error; err != nil {
if err := db._db.Find(&participant.AllSchools).Error; err != nil {
return nil, err
}
@ -324,7 +324,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
}
return participant, nil
} else {
participant, err := model.Read(args, w, r)
participant, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -333,7 +333,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
return nil, err
}
if user, err := participant.(*Participant).exists(); err == nil && user != nil {
if user, err := participant.(*Participant).exists(db); err == nil && user != nil {
if user.ID != participant.(*Participant).UserID {
// err := setFlashMessage(w, r, "participantExists")
// if err != nil {
@ -347,10 +347,10 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
// Check if a participant of the same category exists
var school School
if err := DB().First(&school, participant.(*Participant).SchoolID).Error; err != nil {
if err := db._db.First(&school, participant.(*Participant).SchoolID).Error; err != nil {
return nil, err
}
hasCategory, err := school.HasCategory(participant.(*Participant))
hasCategory, err := school.HasCategory(db, participant.(*Participant))
if err != nil {
return nil, err
}
@ -358,35 +358,35 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
return nil, errors.CategoryExists
}
if err := DB().Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil {
if err := db._db.Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil {
return nil, err
}
participant.(*Participant).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveParticipant(participant)
_, err = SaveParticipant(db, participant)
if err != nil {
return nil, err
}
if err := DB().Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil {
if err := db._db.Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil {
return nil, err
}
participant, err = model.Read(args, w, r)
participant, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
var response Response
if err := DB().First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil {
if err := db._db.First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil {
return nil, err
}
response.UserModifierUpdate = NewUserModifierUpdate(r)
if err := DB().Save(&response).Error; err != nil {
if err := db._db.Save(&response).Error; err != nil {
return nil, err
}
@ -394,31 +394,31 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
}
}
func (model *Participant) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
participant, err := model.Read(args, w, r)
func (model *Participant) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
participant, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(participant.(*Participant)).Error; err != nil {
if err := db._db.Unscoped().Delete(participant.(*Participant)).Error; err != nil {
return nil, err
}
return participant.(*Participant), nil
}
func CreateParticipant(participant *Participant) (*Participant, error) {
if err := DB().Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil {
func CreateParticipant(db *Database, participant *Participant) (*Participant, error) {
if err := db._db.Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil {
return nil, err
}
if err := DB().Create(participant).Error; err != nil {
if err := db._db.Create(participant).Error; err != nil {
return nil, err
}
return participant, nil
}
func SaveParticipant(participant interface{}) (interface{}, error) {
func SaveParticipant(db *Database, participant interface{}) (interface{}, error) {
participant.(*Participant).FiscalCode = strings.ToUpper(participant.(*Participant).FiscalCode)
if err := DB().Omit("Category", "School").Save(participant).Error; err != nil {
if err := db._db.Omit("Category", "School").Save(participant).Error; err != nil {
return nil, err
}
return participant, nil

View file

@ -33,10 +33,10 @@ func (q *Question) BeforeCreate(tx *gorm.DB) error {
return nil
}
func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (q *Question) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
question := new(Question)
if err := DB().Find(&question.AllContests).Error; err != nil {
if err := db._db.Find(&question.AllContests).Error; err != nil {
return nil, err
}
return question, nil
@ -46,7 +46,7 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http
if err != nil {
return nil, err
}
question, err = CreateQuestion(question)
question, err = CreateQuestion(db, question)
if err != nil {
return nil, err
}
@ -54,12 +54,12 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http
}
}
func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (q *Question) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var question Question
id := args["id"]
if err := DB().Preload("Answers", func(db *gorm.DB) *gorm.DB {
if err := db._db.Preload("Answers", func(db *gorm.DB) *gorm.DB {
return db.Order("answers.correct DESC")
}).Preload("Contest").First(&question, id).Error; err != nil {
return nil, err
@ -68,24 +68,24 @@ func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.R
return &question, nil
}
func (q *Question) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (q *Question) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var questions []*Question
if err := DB().Preload("Contest").Order("created_at").Find(&questions).Error; err != nil {
if err := db._db.Preload("Contest").Order("created_at").Find(&questions).Error; err != nil {
return nil, err
}
return questions, nil
}
func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (q *Question) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := q.Read(args, w, r)
result, err := q.Read(db, args, w, r)
if err != nil {
return nil, err
}
question := result.(*Question)
if err := DB().Find(&question.AllContests).Error; err != nil {
if err := db._db.Find(&question.AllContests).Error; err != nil {
return nil, err
}
@ -94,7 +94,7 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
return question, nil
} else {
question, err := q.Read(args, w, r)
question, err := q.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -102,11 +102,11 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
if err != nil {
return nil, err
}
_, err = SaveQuestion(question)
_, err = SaveQuestion(db, question)
if err != nil {
return nil, err
}
question, err = q.Read(args, w, r)
question, err = q.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -114,26 +114,26 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http
}
}
func (model *Question) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
question, err := model.Read(args, w, r)
func (model *Question) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
question, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(question.(*Question)).Error; err != nil {
if err := db._db.Unscoped().Delete(question.(*Question)).Error; err != nil {
return nil, err
}
return question.(*Question), nil
}
func CreateQuestion(question *Question) (*Question, error) {
if err := DB().Create(question).Error; err != nil {
func CreateQuestion(db *Database, question *Question) (*Question, error) {
if err := db._db.Create(question).Error; err != nil {
return nil, err
}
return question, nil
}
func SaveQuestion(question interface{}) (interface{}, error) {
if err := DB().Omit("Answers", "Contest").Save(question).Error; err != nil {
func SaveQuestion(db *Database, question interface{}) (interface{}, error) {
if err := db._db.Omit("Answers", "Contest").Save(question).Error; err != nil {
return nil, err
}
return question, nil

View file

@ -60,10 +60,10 @@ type Region struct {
Name string
}
func CreateRegions() {
func CreateRegions(db *Database) {
for _, name := range regions {
var region Region
if err := currDB.FirstOrCreate(&region, Region{Name: name}).Error; err != nil {
if err := db._db.FirstOrCreate(&region, Region{Name: name}).Error; err != nil {
panic(err)
}
}
@ -75,10 +75,10 @@ func (model *Region) String() string {
return model.Name
}
func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Region) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
region := new(Region)
// if err := DB().Find(&region.AllContests).Error; err != nil {
// if err := db._db.Find(&region.AllContests).Error; err != nil {
// return nil, err
// }
return region, nil
@ -88,7 +88,7 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
region, err = CreateRegion(region)
region, err = CreateRegion(db, region)
if err != nil {
return nil, err
}
@ -96,36 +96,36 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht
}
}
func (model *Region) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Region) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var region Region
id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&region, id).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ First(&region, id).Error; err != nil {
return nil, err
}
return &region, nil
}
func (model *Region) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Region) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var regions []*Region
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&regions).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&regions).Error; err != nil {
return nil, err
}
return regions, nil
}
func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Region) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
region := result.(*Region)
// if err := DB().Find(&region.AllElements).Error; err != nil {
// if err := db._db.Find(&region.AllElements).Error; err != nil {
// return nil, err
// }
@ -134,7 +134,7 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
return region, nil
} else {
region, err := model.Read(args, w, r)
region, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -142,11 +142,11 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
if err != nil {
return nil, err
}
_, err = SaveRegion(region)
_, err = SaveRegion(db, region)
if err != nil {
return nil, err
}
region, err = model.Read(args, w, r)
region, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -154,26 +154,26 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht
}
}
func (model *Region) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
region, err := model.Read(args, w, r)
func (model *Region) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
region, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(region.(*Region)).Error; err != nil {
if err := db._db.Unscoped().Delete(region.(*Region)).Error; err != nil {
return nil, err
}
return region.(*Region), nil
}
func CreateRegion(region *Region) (*Region, error) {
if err := DB().Create(region).Error; err != nil {
func CreateRegion(db *Database, region *Region) (*Region, error) {
if err := db._db.Create(region).Error; err != nil {
return nil, err
}
return region, nil
}
func SaveRegion(region interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(region).Error; err != nil {
func SaveRegion(db *Database, region interface{}) (interface{}, error) {
if err := db._db. /*.Omit("Something")*/ Save(region).Error; err != nil {
return nil, err
}
return region, nil

View file

@ -77,13 +77,13 @@ func (model *Response) BeforeSave(tx *gorm.DB) error {
return nil
}
func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Response) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
response := new(Response)
contestID := r.URL.Query().Get("contest_id")
if err := DB().Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil {
if err := db._db.Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil {
return nil, err
}
return response, nil
@ -96,7 +96,7 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *
response.UserModifierCreate = NewUserModifierCreate(r)
response, err = CreateResponse(response)
response, err = CreateResponse(db, response)
if err != nil {
return nil, err
}
@ -104,12 +104,12 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *
}
}
func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Response) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var response Response
id := args["id"]
if err := DB().Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil {
if err := db._db.Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil {
return nil, err
}
@ -126,7 +126,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
for _, sr := range response.SingleResponses {
var answer Answer
if err := DB().First(&answer, sr.AnswerID).Error; err != nil {
if err := db._db.First(&answer, sr.AnswerID).Error; err != nil {
return nil, err
}
@ -151,7 +151,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
// Fetch questions in the given order
field := fmt.Sprintf("FIELD(id,%s)", strings.Replace(response.QuestionsOrder, " ", ",", -1))
if err := DB().Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB {
if err := db._db.Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB {
return db.Order("RAND()")
}).Find(&response.Questions).Error; err != nil {
return nil, err
@ -160,17 +160,17 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht
return &response, nil
}
func (model *Response) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Response) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var responses []*Response
if err := DB().Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil {
if err := db._db.Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil {
return nil, err
}
return responses, nil
}
func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *Response) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -184,10 +184,10 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
return nil, errors.OutOfTime
}
if response.StartTime.IsZero() {
if err := DB().Model(&response).Update("start_time", time.Now()).Error; err != nil {
if err := db._db.Model(&response).Update("start_time", time.Now()).Error; err != nil {
return nil, err
}
if err := DB().Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil {
if err := db._db.Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil {
return nil, err
}
log.Println("StartTime/EndTime", response.StartTime, response.EndTime)
@ -197,7 +197,7 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
return response, nil
} else {
response, err := model.Read(args, w, r)
response, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -214,11 +214,11 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
response.(*Response).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveResponse(response)
_, err = SaveResponse(db, response)
if err != nil {
return nil, err
}
response, err = model.Read(args, w, r)
response, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -226,26 +226,26 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *
}
}
func (model *Response) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
response, err := model.Read(args, w, r)
func (model *Response) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
response, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(response.(*Response)).Error; err != nil {
if err := db._db.Unscoped().Delete(response.(*Response)).Error; err != nil {
return nil, err
}
return response.(*Response), nil
}
func CreateResponse(response *Response) (*Response, error) {
if err := DB().Create(response).Error; err != nil {
func CreateResponse(db *Database, response *Response) (*Response, error) {
if err := db._db.Create(response).Error; err != nil {
return nil, err
}
return response, nil
}
func SaveResponse(response interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(response).Error; err != nil {
func SaveResponse(db *Database, response interface{}) (interface{}, error) {
if err := db._db. /*.Omit("Something")*/ Save(response).Error; err != nil {
return nil, err
}
return response, nil

View file

@ -70,9 +70,9 @@ func (model *School) To() string {
return model.Email
}
func (model *School) exists() (*User, error) {
func (model *School) exists(db *Database) (*User, error) {
var user User
if err := DB().First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
if err := db._db.First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound {
return nil, err
} else if err == gorm.ErrRecordNotFound {
return nil, nil
@ -118,11 +118,11 @@ func (model *School) AfterCreate(tx *gorm.DB) error {
return nil
}
func (model *School) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *School) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
school := new(School)
if err := DB().Find(&school.AllRegions).Error; err != nil {
if err := db._db.Find(&school.AllRegions).Error; err != nil {
return nil, err
}
@ -135,8 +135,8 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
}
// Check if this user already exists in the users table.
if user, err := school.exists(); err == nil && user != nil {
if err := DB().Where("user_id = ?", user.ID).Find(&school).Error; err != nil {
if user, err := school.exists(db); err == nil && user != nil {
if err := db._db.Where("user_id = ?", user.ID).Find(&school).Error; err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "schoolExists")
@ -150,7 +150,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
school.UserModifierCreate = NewUserModifierCreate(r)
school, err = CreateSchool(school)
school, err = CreateSchool(db, school)
if err != nil {
return nil, err
}
@ -158,7 +158,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
}
}
func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *School) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var school School
id := args["id"]
@ -168,31 +168,31 @@ func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http
return nil, errors.NotAuthorized
}
if err := DB().Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil {
if err := db._db.Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil {
return nil, err
}
return &school, nil
}
func (model *School) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *School) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var schools []*School
if err := DB().Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil {
if err := db._db.Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil {
return nil, err
}
return schools, nil
}
func (model *School) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *School) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
school := result.(*School)
if err := DB().Find(&school.AllRegions).Error; err != nil {
if err := db._db.Find(&school.AllRegions).Error; err != nil {
return nil, err
}
@ -201,7 +201,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
return school, nil
} else {
school, err := model.Read(args, w, r)
school, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -211,7 +211,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
}
// Check if the modified school code belong to an existing school.
if user, err := school.(*School).exists(); err == nil && user != nil {
if user, err := school.(*School).exists(db); err == nil && user != nil {
if user.ID != school.(*School).UserID {
// err := setFlashMessage(w, r, "schoolExists")
// if err != nil {
@ -225,11 +225,11 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
school.(*School).UserModifierUpdate = NewUserModifierUpdate(r)
_, err = SaveSchool(school)
_, err = SaveSchool(db, school)
if err != nil {
return nil, err
}
school, err = model.Read(args, w, r)
school, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -237,35 +237,35 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
}
}
func (model *School) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
school, err := model.Read(args, w, r)
func (model *School) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
school, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(school.(*School)).Error; err != nil {
if err := db._db.Unscoped().Delete(school.(*School)).Error; err != nil {
return nil, err
}
return school.(*School), nil
}
func CreateSchool(school *School) (*School, error) {
if err := DB().Create(school).Error; err != nil {
func CreateSchool(db *Database, school *School) (*School, error) {
if err := db._db.Create(school).Error; err != nil {
return nil, err
}
return school, nil
}
func SaveSchool(school interface{}) (interface{}, error) {
if err := DB().Omit("Region").Save(school).Error; err != nil {
func SaveSchool(db *Database, school interface{}) (interface{}, error) {
if err := db._db.Omit("Region").Save(school).Error; err != nil {
return nil, err
}
return school, nil
}
func (model *School) HasCategory(participant *Participant) (bool, error) {
func (model *School) HasCategory(db *Database, participant *Participant) (bool, error) {
var participants []*Participant
if err := DB().Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil {
if err := db._db.Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil {
return false, err
}
return len(participants) > 0, nil

View file

@ -47,7 +47,7 @@ func (model *User) BeforeSave(tx *gorm.DB) error {
return nil
}
func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *User) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
user := new(User)
return user, nil
@ -57,7 +57,7 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http
if err != nil {
return nil, err
}
user, err = CreateUser(user)
user, err = CreateUser(db, user)
if err != nil {
return nil, err
}
@ -65,36 +65,36 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http
}
}
func (model *User) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *User) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var user User
id := args["id"]
if err := DB(). /*.Preload("Something")*/ First(&user, id).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ First(&user, id).Error; err != nil {
return nil, err
}
return &user, nil
}
func (model *User) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *User) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
var users []*User
if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil {
if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
func (model *User) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
if r.Method == "GET" {
result, err := model.Read(args, w, r)
result, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
user := result.(*User)
// if err := DB().Find(&user.AllElements).Error; err != nil {
// if err := db._db.Find(&user.AllElements).Error; err != nil {
// return nil, err
// }
@ -103,7 +103,7 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
return user, nil
} else {
user, err := model.Read(args, w, r)
user, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -111,11 +111,11 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
if err != nil {
return nil, err
}
_, err = SaveUser(user)
_, err = SaveUser(db, user)
if err != nil {
return nil, err
}
user, err = model.Read(args, w, r)
user, err = model.Read(db, args, w, r)
if err != nil {
return nil, err
}
@ -123,26 +123,26 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http
}
}
func (model *User) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
user, err := model.Read(args, w, r)
func (model *User) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) {
user, err := model.Read(db, args, w, r)
if err != nil {
return nil, err
}
if err := DB().Unscoped().Delete(user.(*User)).Error; err != nil {
if err := db._db.Unscoped().Delete(user.(*User)).Error; err != nil {
return nil, err
}
return user.(*User), nil
}
func CreateUser(user *User) (*User, error) {
if err := DB().Create(user).Error; err != nil {
func CreateUser(db *Database, user *User) (*User, error) {
if err := db._db.Create(user).Error; err != nil {
return nil, err
}
return user, nil
}
func SaveUser(user interface{}) (interface{}, error) {
if err := DB(). /*.Omit("Something")*/ Save(user).Error; err != nil {
func SaveUser(db *Database, user interface{}) (interface{}, error) {
if err := db._db. /*.Omit("Something")*/ Save(user).Error; err != nil {
return nil, err
}
return user, nil

View file

@ -39,20 +39,20 @@ func NewUserModifierCreate(r *http.Request) *UserModifierCreate {
}
}
func (um *UserModifierCreate) CreatedBy() (*UserAction, error) {
func (um *UserModifierCreate) CreatedBy(db *Database) (*UserAction, error) {
action := new(UserAction)
switch um.CreatorRole {
case "participant":
var participant Participant
if err := DB().Preload("User").First(&participant, um.CreatorID).Error; err != nil {
if err := db._db.Preload("User").First(&participant, um.CreatorID).Error; err != nil {
return nil, err
}
action.User = *participant.User
case "school":
var school School
if err := DB().Preload("User").First(&school, um.CreatorID).Error; err != nil {
if err := db._db.Preload("User").First(&school, um.CreatorID).Error; err != nil {
return nil, err
}
action.User = *school.User
@ -80,19 +80,19 @@ func NewUserModifierUpdate(r *http.Request) *UserModifierUpdate {
}
}
func (um *UserModifierUpdate) UpdatedBy() (*UserAction, error) {
func (um *UserModifierUpdate) UpdatedBy(db *Database) (*UserAction, error) {
action := new(UserAction)
switch um.UpdaterRole {
case "participant":
var participant Participant
if err := DB().Preload("User").First(&participant, um.UpdaterID).Error; err != nil {
if err := db._db.Preload("User").First(&participant, um.UpdaterID).Error; err != nil {
return nil, err
}
action.User = *participant.User
case "school":
var school School
if err := DB().Preload("User").First(&school, um.UpdaterID).Error; err != nil {
if err := db._db.Preload("User").First(&school, um.UpdaterID).Error; err != nil {
return nil, err
}
action.User = *school.User