diff --git a/mongo/collection.go b/mongo/collection.go index b397174dfa8459465bf77f5ea0255d4275f3c025..df339a014f64098693ec87c89b2dcc52023cd5a5 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -24,6 +24,9 @@ package mongo import ( "fmt" "strings" + "sync" + + "reflect" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" @@ -203,18 +206,26 @@ func (self *Collection) Truncate() error { // Appends an item (map or struct) into the collection. func (self *Collection) Append(item interface{}) (interface{}, error) { var err error - var id bson.ObjectId - id = bson.NewObjectId() + id := getId(item) - // Allocating a new ID. - if err = self.collection.Insert(bson.M{"_id": id}); err != nil { - return nil, err - } + if self.parent.VersionAtLeast(2, 6, 0, 0) { + // this breaks MongoDb older than 2.6 + if _, err = self.collection.Upsert(bson.M{"_id": id}, item); err != nil { + return nil, err + } + } else { + // Allocating a new ID. + if err = self.collection.Insert(bson.M{"_id": id}); err != nil { + return nil, err + } - // Now append data the user wants to append. - if err = self.collection.Update(bson.M{"_id": id}, item); err != nil { - return nil, err + // Now append data the user wants to append. + if err = self.collection.Update(bson.M{"_id": id}, item); err != nil { + // Cleanup allocated ID + self.collection.Remove(bson.M{"_id": id}) + return nil, err + } } // Does the item satisfy the db.ID interface? @@ -232,21 +243,6 @@ func (self *Collection) Append(item interface{}) (interface{}, error) { } return id, nil - - /* - var id bson.ObjectId - var err error - - id = bson.NewObjectId() - - _, err = self.collection.Upsert(bson.M{"_id": id}, item); - - if err != nil { - return nil, err - } - - return id, nil - */ } // Returns true if the collection exists. @@ -294,3 +290,69 @@ func toNative(val interface{}) interface{} { return val } + +var ( + // idCache should be a struct if we're going to cache more than just _id field here + idCache = make(map[reflect.Type]string, 0) + idCacheMutex sync.RWMutex +) + +// Fetches object _id or generates a new one if object doesn't have one or the one it has is invalid +func getId(item interface{}) bson.ObjectId { + v := reflect.ValueOf(item) + + switch v.Kind() { + case reflect.Map: + if inItem, ok := item.(map[string]interface{}); ok { + if id, ok := inItem["_id"]; ok { + bsonId, ok := id.(bson.ObjectId) + if ok { + return bsonId + } + } + } + case reflect.Struct: + t := v.Type() + + idCacheMutex.RLock() + fieldName, found := idCache[t] + idCacheMutex.RUnlock() + + if !found { + for n := 0; n < t.NumField(); n++ { + field := t.Field(n) + if field.PkgPath != "" { + continue // Private field + } + + tag := field.Tag.Get("bson") + if tag == "" { + tag = field.Tag.Get("db") + } + + if tag == "" { + continue + } + + parts := strings.Split(tag, ",") + + if parts[0] == "_id" { + fieldName = field.Name + idCacheMutex.RLock() + idCache[t] = fieldName + idCacheMutex.RUnlock() + break + } + } + } + if fieldName != "" { + if bsonId, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok { + if bsonId.Valid() { + return bsonId + } + } + } + } + + return bson.NewObjectId() +} diff --git a/mongo/database.go b/mongo/database.go index 85c40d6e32526805394b597cb9176f3880bb3f94..b5148e2b97e5a6556c67d6e1b33bf9458665ccd4 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -41,6 +41,7 @@ type Source struct { connURL db.ConnectionURL session *mgo.Session database *mgo.Database + version []int } func debugEnabled() bool { @@ -75,6 +76,7 @@ func (s *Source) Clone() (db.Database, error) { connURL: s.connURL, session: s.session.Copy(), database: s.database, + version: s.version, } return clone, nil } @@ -201,3 +203,28 @@ func (s *Source) Collection(names ...string) (db.Collection, error) { return col, err } + +func (s *Source) VersionAtLeast(version ...int) bool { + // only fetch this once - it makes a db call + if len(s.version) == 0 { + buildInfo, err := s.database.Session.BuildInfo() + if err != nil { + return false + } + s.version = buildInfo.VersionArray + } + + for i := range version { + if i == len(s.version) { + return false + } + if s.version[i] < version[i] { + return false + } + + if s.version[i] > version[i] { + return true + } + } + return true +}