oef/orm/orm.go
2019-12-09 11:15:42 +01:00

147 lines
3 KiB
Go

package orm
import (
"fmt"
"net/http"
"path"
"reflect"
"strings"
"git.andreafazzi.eu/andrea/oef/config"
"github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
"github.com/jinzhu/inflection"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
type (
	// IDer is implemented by values that can report a numeric identifier
	// through GetID.
	IDer interface {
		GetID() uint
	}

	// GetFn is the handler signature that MapHandlers expects model actions
	// to have and that GetFunc hands back. It receives a string map of
	// arguments (presumably route variables such as "id" — confirm against
	// the router), the response writer and the request. It returns an
	// arbitrary result value or an error.
	GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
)
// Runtime state shared by the functions of this package.
var (
	// fns maps a URL path, as built by MapHandlers, to the model method that
	// serves it. init allocates it and MapHandlers fills it in.
	fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)

	// currDB is the connection installed by Use and returned by DB.
	currDB *gorm.DB

	// store is a cookie-based session store keyed with the configured
	// cookie-store key. It is built during package variable initialization.
	// NOTE(review): if config.Config is populated later at runtime (e.g.
	// from main), the key read here may still be empty — confirm.
	store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey))
)

// Seed data written to the database by CreateCategories and CreateRegions.
var (
	categories = []string{"Junior", "Senior"}

	// regions lists Italian region names.
	regions = []string{
		"Abruzzo", "Basilicata", "Calabria", "Campania",
		"Emilia-Romagna", "Friuli-Venezia Giulia", "Lazio", "Liguria",
		"Lombardia", "Marche", "Molise", "Piemonte",
		"Puglia", "Sardegna", "Sicilia", "Toscana",
		"Trentino-Alto Adige", "Umbria", "Valle d'Aosta", "Veneto",
	}
)
// init allocates the (initially empty) route table so that MapHandlers can
// store handlers in it without writing to a nil map.
func init() {
	fns = map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error){}
}
// New opens a MySQL database through gorm using the given connection string.
// On failure it returns a nil *gorm.DB together with gorm's error, even if
// gorm itself produced a handle. New does not install the connection; pass
// the result to Use for that.
func New(connection string) (*gorm.DB, error) {
	conn, openErr := gorm.Open("mysql", connection)
	if openErr == nil {
		return conn, nil
	}
	return nil, openErr
}
// AutoMigrate runs gorm's AutoMigrate for the given models on the connection
// installed with Use. It panics with gorm's error if the migration fails.
// NOTE(review): a nil currDB (Use never called) is not checked here.
func AutoMigrate(models ...interface{}) {
	result := currDB.AutoMigrate(models...)
	if result.Error != nil {
		panic(result.Error)
	}
}
// CreateCategories ensures that a Category row exists for every name in
// categories. It uses gorm's FirstOrCreate with a Name condition, so names
// that are already present are not inserted again. It panics on the first
// database error.
func CreateCategories() {
	for i := range categories {
		found := Category{}
		where := Category{Name: categories[i]}
		if err := currDB.FirstOrCreate(&found, where).Error; err != nil {
			panic(err)
		}
	}
}
// CreateRegions ensures that a Region row exists for every name in regions.
// It uses gorm's FirstOrCreate with a Name condition, so names that are
// already present are not inserted again. It panics on the first database
// error.
func CreateRegions() {
	for i := range regions {
		found := Region{}
		where := Region{Name: regions[i]}
		if err := currDB.FirstOrCreate(&found, where).Error; err != nil {
			panic(err)
		}
	}
}
// Use installs db as the package-wide connection. AutoMigrate,
// CreateCategories and CreateRegions use it, and DB returns it.
// It simply overwrites the package-level currDB; no locking is performed.
func Use(db *gorm.DB) {
currDB = db
}
// DB returns the connection previously installed with Use, or nil if Use has
// not been called yet.
func DB() *gorm.DB {
return currDB
}
// MapHandlers registers the CRUD action methods of every model in models in
// the package route table, where GetFunc can later look them up.
//
// The route prefix of a model is its type name, lower-cased and pluralized
// with inflection.Plural. With "users" as the prefix, the registered paths
// are:
//
//	/users              -> ReadAll
//	/users/create/      -> Create
//	/users/{id}         -> Read
//	/users/{id}/update  -> Update
//	/users/{id}/delete  -> Delete
//
// Each action must be an exported method with the GetFn signature.
//
// MapHandlers returns an error in three cases:
//   - a model is nil;
//   - a model lacks one of the actions;
//   - a model defines an action with a different signature.
//
// Routes registered before the error remain registered.
func MapHandlers(models []interface{}) error {
	// A fixed, ordered list (rather than a map literal, whose iteration
	// order is random) makes registration and error reporting deterministic.
	routes := []struct {
		pattern string
		action  string
	}{
		{"", "ReadAll"},
		{"create/", "Create"},
		{"{id}", "Read"},
		{"{id}/update", "Update"},
		{"{id}/delete", "Delete"},
	}
	for i, model := range models {
		if model == nil {
			// modelName cannot derive a name from a nil interface.
			return fmt.Errorf("Model at index %d is nil", i)
		}
		name := inflection.Plural(strings.ToLower(modelName(model)))
		value := reflect.ValueOf(model)
		for _, route := range routes {
			method := value.MethodByName(route.action)
			if !method.IsValid() {
				return fmt.Errorf("Action %s is not defined for model %s", route.action, name)
			}
			// A checked assertion turns a method with the wrong signature
			// into an error instead of a runtime panic.
			fn, ok := method.Interface().(func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error))
			if !ok {
				return fmt.Errorf("Action %s of model %s has type %s, expected %s", route.action, name, method.Type(), reflect.TypeOf(fn))
			}
			joinedPath := path.Join("/", name, route.pattern)
			// path.Join drops trailing slashes; keep the one the pattern
			// explicitly asks for (e.g. "create/").
			if strings.HasSuffix(route.pattern, "/") {
				joinedPath += "/"
			}
			fns[joinedPath] = fn
		}
	}
	return nil
}
// GetFunc looks up the handler that MapHandlers registered under exactly this
// path. It returns an error when nothing is registered for the path.
func GetFunc(path string) (GetFn, error) {
	if fn, found := fns[path]; found {
		return fn, nil
	}
	return nil, fmt.Errorf("Can't map path %s to any model methods.", path)
}
// GetNothing is a no-op placeholder: it ignores args and always yields a nil
// result with a nil error. Note that, unlike GetFn, it takes only the
// argument map.
func GetNothing(args map[string]string) (result interface{}, err error) {
	return
}
// PostNothing is a no-op placeholder for a POST-style action: it ignores all
// of its arguments and always yields a nil IDer with a nil error.
func PostNothing(args map[string]string, w http.ResponseWriter, r *http.Request) (created IDer, err error) {
	return
}
// modelName returns the name of the dynamic type of s, looking through any
// number of pointer indirections: User, *User and **User all yield "User".
// It returns "" for a nil interface, which has no type, and for unnamed
// types such as []int, for which reflect reports an empty name.
func modelName(s interface{}) string {
	t := reflect.TypeOf(s)
	if t == nil {
		return ""
	}
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t.Name()
}