oef/orm/orm.go

134 lines
3.2 KiB
Go
Raw Normal View History

2019-11-04 15:00:46 +01:00
package orm
import (
"errors"
2019-11-04 15:00:46 +01:00
"fmt"
"net/http"
"path"
"reflect"
"strings"
"git.andreafazzi.eu/andrea/oef/config"
2019-11-04 15:00:46 +01:00
"github.com/jinzhu/gorm"
"github.com/jinzhu/inflection"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
// IDer is implemented by values that can report a numeric identifier.
type IDer interface {
// GetID returns the identifier of the receiver.
GetID() uint
}
// GetFn is the signature of a model handler as returned by Database.GetFunc.
// It receives the database, a string-to-string argument map, and the HTTP
// response writer and request, and returns an arbitrary value or an error.
type GetFn func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
2019-11-04 15:00:46 +01:00
2020-01-14 16:28:27 +01:00
type Database struct {
Config *config.ConfigT
2019-11-04 15:00:46 +01:00
models []interface{}
_db *gorm.DB
fns map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
2019-11-04 15:00:46 +01:00
}
func NewDatabase(config *config.ConfigT, models []interface{}) (*Database, error) {
var err error
2020-01-14 16:28:27 +01:00
db := new(Database)
db.Config = config
db.fns = make(map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
db.mapHandlers(models)
db._db, err = gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Orm.Connection, config.Orm.Options))
2019-11-04 15:00:46 +01:00
if err != nil {
return nil, err
}
return db, err
2019-11-04 15:00:46 +01:00
}
func (db *Database) AutoMigrate() {
if err := db._db.AutoMigrate(db.models...).Error; err != nil {
2019-11-04 15:00:46 +01:00
panic(err)
}
}
func (db *Database) GetUser(username, password string) (*User, error) {
var user User
if err := db._db.Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil {
return nil, errors.New("Authentication failed!")
2019-12-09 08:27:46 +01:00
}
return &user, nil
2019-12-09 08:27:46 +01:00
}
func (db *Database) DB() *gorm.DB {
return db._db
2019-11-04 15:00:46 +01:00
}
func CreateCategories(db *Database) {
for _, name := range categories {
var category Category
if err := db._db.FirstOrCreate(&category, Category{Name: name}).Error; err != nil {
panic(err)
}
}
2019-11-04 15:00:46 +01:00
}
func (db *Database) mapHandlers(models []interface{}) error {
2019-11-04 15:00:46 +01:00
for _, model := range models {
2020-01-02 13:01:21 +01:00
name := inflection.Plural(strings.ToLower(ModelName(model)))
2019-11-04 15:00:46 +01:00
for p, action := range map[string]string{
"": "ReadAll",
"create/": "Create",
"{id}": "Read",
"{id}/update": "Update",
"{id}/delete": "Delete",
} {
method := reflect.ValueOf(model).MethodByName(action)
if !method.IsValid() {
return fmt.Errorf("Action %s is not defined for model %s", action, name)
}
joinedPath := path.Join("/", name, p)
if strings.HasSuffix(p, "/") {
joinedPath += "/"
}
db.fns[joinedPath] = method.Interface().(func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error))
2019-11-04 15:00:46 +01:00
}
}
return nil
}
func (db *Database) GetFunc(path string) (GetFn, error) {
fn, ok := db.fns[path]
2019-11-04 15:00:46 +01:00
if !ok {
return nil, fmt.Errorf("Can't map path %s to any model methods.", path)
}
return fn, nil
}
func GetNothing(db *Database, args map[string]string) (interface{}, error) {
2019-11-04 15:00:46 +01:00
return nil, nil
}
func PostNothing(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) {
2019-11-04 15:00:46 +01:00
return nil, nil
}
2020-01-02 13:01:21 +01:00
// ModelName returns the bare type name of a model value, without package
// qualifier or pointer/slice decoration:
//
//	T, *T            -> "T"
//	[]T, []*T        -> "T"
//	*[]T, *[]*T      -> "T"
//
// A pointer to a named slice type yields the name of that slice type.
// An untyped nil yields "" instead of panicking. Types that have no name
// after unwrapping (e.g. nested slices, maps) also yield "".
func ModelName(s interface{}) string {
	t := reflect.TypeOf(s)
	if t == nil {
		// reflect.TypeOf returns nil for a nil interface value; calling
		// Kind on it would panic.
		return ""
	}

	switch t.Kind() {
	case reflect.Ptr:
		elem := t.Elem()
		// Only unwrap anonymous slices; a named slice type keeps its own name.
		if elem.Kind() == reflect.Slice && elem.Name() == "" {
			return modelElemName(elem.Elem())
		}
		return elem.Name()
	case reflect.Slice:
		return modelElemName(t.Elem())
	default:
		return t.Name()
	}
}

// modelElemName returns the type name of a slice element, looking through a
// single level of pointer indirection.
func modelElemName(elem reflect.Type) string {
	if elem.Kind() == reflect.Ptr {
		elem = elem.Elem()
	}
	return elem.Name()
}