Source code for mongokat.collection

from .utils import dotdict, json_clone
from bson import ObjectId
from bson.codec_options import CodecOptions
from pymongo import ReadPreference, WriteConcern, ReturnDocument, read_preferences
import collections
import base64
from .document import Document
from .exceptions import MultipleResultsFound, ImmutableDocumentError, ProtectedFieldsError


def _param_fields(kwargs, fields):
  """
    Normalize the "fields" argument to most find methods
  """
  if fields is None:
    return
  if type(fields) in [list, set, frozenset, tuple]:
    fields = {x: True for x in fields}
  if type(fields) == dict:
    fields.setdefault("_id", False)
  kwargs["projection"] = fields


def find_method(func):
  """
    Decorator that manages smart defaults or transforms for common find methods:

     - fields/projection: list of fields to be returned. Contrary to pymongo, _id won't be added automatically
     - json: performs a json_clone on the results. Beware of performance!
     - timeout
     - return_document
  """
  def wrapped(*args, **kwargs):

    # Normalize the fields argument if passed as a positional param.
    if len(args) == 3 and func.__name__ in ("find", "find_one", "find_by_id", "find_by_ids"):
      _param_fields(kwargs, args[2])
      args = (args[0], args[1])
    elif "fields" in kwargs:
      _param_fields(kwargs, kwargs["fields"])
      del kwargs["fields"]
    elif "projection" in kwargs:
      _param_fields(kwargs, kwargs["projection"])

    if "timeout" in kwargs:
      kwargs["no_cursor_timeout"] = not bool(kwargs["timeout"])
      del kwargs["timeout"]

    if "spec" in kwargs:
      kwargs["filter"] = kwargs["spec"]
      del kwargs["spec"]

    if kwargs.get("return_document") == "after":
        kwargs["return_document"] = ReturnDocument.AFTER
    elif kwargs.get("return_document") == "before":
        kwargs["return_document"] = ReturnDocument.BEFORE

    ret = func(*args, **kwargs)

    if kwargs.get("json"):
      ret = json_clone(ret)

    return ret

  return wrapped


def patch_cursor(cursor, batch_size=None, limit=None, skip=None, sort=None, **kwargs):
  """
    Adds batch_size, limit, sort parameters to a DB cursor
  """

  if type(batch_size) == int:
    cursor.batch_size(batch_size)

  if limit is not None:
    cursor.limit(limit)

  if sort is not None:
    cursor.sort(sort)

  if skip is not None:
    cursor.skip(skip)


[docs]class Collection(object): """ mongokat.Collection wraps a pymongo.collection.Collection """ __collection__ = None __database__ = None document_class = Document structure = None immutable = False protected_fields = () def __init__(self, collection=None, database=None, client=None): """ You can pass a pymongo collection object directly, or rely on the __collection__ and/or __database__ attributes """ if collection: self.collection = collection self.database = collection.database self.client = self.database.client elif database and self.__collection__: self.database = database self.client = self.database.client self.collection = self.database[self.__collection__] elif client and self.__database__ and self.__collection__: self.client = client self.database = self.client[self.__database__] self.collection = self.database[self.__collection__] else: raise Exception("Not enough parameters given to identify the right collection!") def __call__(self, *args, **kwargs): """ Instanciates a new *Document* from this collection """ kwargs["mongokat_collection"] = self return self.document_class(*args, **kwargs) # # # READ-ONLY METHODS # #
[docs] def exists(self, query, **args): """ Returns True if the search matches at least one document """ return bool(self.find(query, **args).limit(1).count())
[docs] def count(self, *args, **kwargs): return self._collection_with_options(kwargs).count(*args, **kwargs)
[docs] def distinct(self, *args, **kwargs): return self._collection_with_options(kwargs).distinct(*args, **kwargs)
[docs] def group(self, *args, **kwargs): return self._collection_with_options(kwargs).group(*args, **kwargs)
@find_method
[docs] def aggregate(self, *args, **kwargs): # Fix weird pymongo inconsistency https://github.com/mongodb/mongo-python-driver/blob/6865ba72edcda31c717037435e7985e9e4139dd9/test/test_crud.py#L85 if "batch_size" in kwargs: kwargs["batchSize"] = kwargs["batch_size"] del kwargs["batch_size"] return self._collection_with_options(kwargs).aggregate(*args, **kwargs)
@find_method
[docs] def find(self, *args, **kwargs): return self._collection_with_options(kwargs).find(*args, **kwargs)
def _collection_with_options(self, kwargs): """ Returns a copy of the pymongo collection with various options set up """ # class DocumentClassWithFields(self.document_class): # _fetched_fields = kwargs.get("projection") # mongokat_collection = self read_preference = kwargs.get("read_preference") or getattr(self.collection, "read_preference", None) or ReadPreference.PRIMARY if "read_preference" in kwargs: del kwargs["read_preference"] # Simplified tag usage if "read_use" in kwargs: if kwargs["read_use"] == "primary": read_preference = ReadPreference.PRIMARY elif kwargs["read_use"] == "secondary": read_preference = ReadPreference.SECONDARY elif kwargs["read_use"] == "nearest": read_preference = ReadPreference.NEAREST elif kwargs["read_use"]: read_preference = read_preferences.Secondary(tag_sets=[{"use": kwargs["read_use"]}]) del kwargs["read_use"] write_concern = None if kwargs.get("w") is 0: write_concern = WriteConcern(w=0) elif kwargs.get("write_concern"): write_concern = kwargs.get("write_concern") codec_options = CodecOptions( document_class=( self.document_class, { "fetched_fields": kwargs.get("projection"), "mongokat_collection": self } ) ) return self.collection.with_options( codec_options=codec_options, read_preference=read_preference, write_concern=write_concern ) @find_method
[docs] def find_one(self, *args, **kwargs): """ Get a single document from the database. """ doc = self._collection_with_options(kwargs).find_one(*args, **kwargs) if doc is None: return None return doc
@find_method
[docs] def find_by_id(self, _id, **kwargs): """ Pass me anything that looks like an _id : str, ObjectId, {"_id": str}, {"_id": ObjectId} """ if type(_id) == dict and _id.get("_id"): return self.find_one({"_id": ObjectId(_id["_id"])}, **kwargs) return self.find_one({"_id": ObjectId(_id)}, **kwargs)
@find_method
[docs] def find_by_ids(self, _ids, projection=None, **kwargs): """ Does a big _id:$in query on any iterator """ id_list = [ObjectId(_id) for _id in _ids] if len(_ids) == 0: return [] # FIXME : this should be an empty cursor ! # Optimized path when only fetching the _id field. # Be mindful this might not filter missing documents that may not have been returned, had we done the query. if projection is not None and list(projection.keys()) == ["_id"]: return [self({"_id": x}, fetched_fields={"_id": True}) for x in id_list] else: return self.find({"_id": {"$in": id_list}}, projection=projection, **kwargs)
@find_method
[docs] def find_by_b64id(self, _id, **kwargs): """ Pass me a base64-encoded ObjectId """ return self.find_one({"_id": ObjectId(base64.b64decode(_id))}, **kwargs)
@find_method
[docs] def find_by_b64ids(self, _ids, **kwargs): """ Pass me a list of base64-encoded ObjectId """ return self.find_by_ids([ObjectId(base64.b64decode(_id)) for _id in _ids], **kwargs)
[docs] def list_column(self, *args, **kwargs): """ Return one field as a list """ return list(self.iter_column(*args, **kwargs))
[docs] def iter_column(self, query=None, field="_id", **kwargs): """ Return one field as an iterator. Beware that if your query returns records where the field is not set, it will raise a KeyError. """ find_kwargs = { "projection": {"_id": False} } find_kwargs["projection"][field] = True cursor = self._collection_with_options(kwargs).find(query, **find_kwargs) # We only want 1 field: bypass the ORM patch_cursor(cursor, **kwargs) return (dotdict(x)[field] for x in cursor)
[docs] def find_random(self, **kwargs): """ return one random document from the collection """ import random max = self.count(**kwargs) if max: num = random.randint(0, max - 1) return next(self.find(**kwargs).skip(num))
[docs] def one(self, *args, **kwargs): bson_obj = self.find(*args, **kwargs) count = bson_obj.count() if count > 1: raise MultipleResultsFound("%s results found" % count) elif count == 1: return next(bson_obj)
# # # WRITE METHODS # #
[docs] def insert(self, data, return_object=False): """ Inserts the data as a new document. """ obj = self(data) # pylint: disable=E1102 obj.save() if return_object: return obj else: return obj["_id"]
# http://api.mongodb.org/python/current/api/pymongo/collection.html
[docs] def bulk_write(self, *args, **kwargs): """ Hook are not supported for this method! """ return self.collection.bulk_write(*args, **kwargs)
[docs] def insert_one(self, document, **kwargs): ret = self.collection.insert_one(document, **kwargs) self.trigger("after_save", ids=[ret.inserted_id], replacements=[document]) return ret
[docs] def insert_many(self, documents, **kwargs): ret = self.collection.insert_many(documents, **kwargs) self.trigger("after_save", ids=ret.inserted_ids, replacements=documents) return ret
[docs] def replace_one(self, filter, replacement, **kwargs): if self.immutable: raise ImmutableDocumentError() if not kwargs.get("allow_protected_fields"): self._check_protected_fields(replacement) else: del kwargs["allow_protected_fields"] before_doc = None if self.has_trigger("before_save") or self.has_trigger("after_save"): before_doc = self.find_one(filter, read_use="primary", projection=["_id"]) if before_doc: self.trigger("before_save", replacements=[replacement], ids=[before_doc["_id"]]) ret = self.collection.replace_one(filter, replacement, **kwargs) if ret.modified_count is 0: return ret elif ret.upserted_id: self.trigger("after_save", replacements=[replacement], ids=[ret.upserted_id]) elif before_doc: self.trigger("after_save", replacements=[replacement], ids=[before_doc["_id"]]) return ret
[docs] def update_one(self, filter, update, **kwargs): if self.immutable: raise ImmutableDocumentError() if "$set" in update: if not kwargs.get("allow_protected_fields"): self._check_protected_fields(update["$set"]) else: del kwargs["allow_protected_fields"] before_doc = None if self.has_trigger("before_save") or self.has_trigger("after_save"): before_doc = self.find_one(filter, read_use="primary", projection=["_id"]) if before_doc: self.trigger("before_save", update=update, ids=[before_doc["_id"]]) ret = self.collection.update_one(filter, update, **kwargs) if ret.modified_count is 0: return ret elif ret.upserted_id: self.trigger("after_save", update=update, ids=[ret.upserted_id]) elif before_doc: self.trigger("after_save", update=update, ids=[before_doc["_id"]]) return ret
[docs] def update_many(self, filter, update, **kwargs): if self.immutable: raise ImmutableDocumentError() if "$set" in update: if not kwargs.get("allow_protected_fields"): self._check_protected_fields(update["$set"]) else: del kwargs["allow_protected_fields"] before_ids = None if self.has_trigger("before_save") or self.has_trigger("after_save"): before_ids = self.list_column(filter, read_use="primary") if before_ids: self.trigger("before_save", update=update, ids=before_ids) ret = self.collection.update_many(filter, update, **kwargs) if ret.modified_count is 0: return ret elif before_ids: self.trigger("after_save", ids=before_ids, update=update) return ret
[docs] def delete_one(self, filter, **kwargs): doc = None if self.has_trigger("before_delete") or self.has_trigger("after_delete"): doc = self.find_one(filter, read_use="primary") self.trigger("before_delete", documents=[doc]) ret = self.collection.delete_one(filter, **kwargs) if doc is not None: self.trigger("after_delete", documents=[doc]) return ret
[docs] def delete_many(self, filter, **kwargs): docs = [] if self.has_trigger("before_delete") or self.has_trigger("after_delete"): docs = list(self.find(filter, read_use="primary")) self.trigger("before_delete", documents=docs) ret = self.collection.delete_many(filter, **kwargs) if len(docs) > 0: self.trigger("after_delete", documents=docs) return ret
@find_method
[docs] def find_one_and_delete(self, filter, **kwargs): self.trigger("before_delete", filter=filter) ret = self.collection.find_one_and_delete(filter, **kwargs) if ret is None: return None doc = self(ret, fetched_fields=kwargs.get("projection")) self.trigger("after_delete", documents=[doc]) return doc
@find_method
[docs] def find_one_and_replace(self, filter, replacement, **kwargs): if self.immutable: raise ImmutableDocumentError() if not kwargs.get("allow_protected_fields"): self._check_protected_fields(replacement) else: del kwargs["allow_protected_fields"] ret = self.collection.find_one_and_replace(filter, replacement, **kwargs) if ret is None: return None doc = self(ret, fetched_fields=kwargs.get("projection")) self.trigger("after_save", documents=[doc], replacements=[replacement]) return doc
@find_method
[docs] def find_one_and_update(self, filter, update, **kwargs): if self.immutable: raise ImmutableDocumentError() if "$set" in update: if not kwargs.get("allow_protected_fields"): self._check_protected_fields(update["$set"]) else: del kwargs["allow_protected_fields"] if self.has_trigger("before_save"): before_id = self.find_one(filter, read_use="primary", projection=["_id"]) if before_id: self.trigger("before_save", update=update, ids=[before_id["_id"]]) ret = self.collection.find_one_and_update(filter, update, **kwargs) if ret is None: return None doc = self(ret, fetched_fields=kwargs.get("projection")) self.trigger("after_save", documents=[doc], update=update) return doc
# # # EVENTS MANAGEMENT # #
[docs] def has_trigger(self, event): """ Does this trigger need to run? """ return hasattr(self.document_class, event)
[docs] def trigger(self, event, filter=None, update=None, documents=None, ids=None, replacements=None): """ Trigger the after_save hook on documents, if present. """ if not self.has_trigger(event): return if documents is not None: pass elif ids is not None: documents = self.find_by_ids(ids, read_use="primary") elif filter is not None: documents = self.find(filter, read_use="primary") else: raise Exception("Trigger couldn't filter documents") for doc in documents: getattr(doc, event)(update=update, replacements=replacements)
# # # FOR BACKWARDS-COMPATIBILITY # # @property def connection(self): return self.client @property def db(self): return self.database
[docs] def save(self, to_save, **kwargs): if self.immutable and "_id" in to_save: raise ImmutableDocumentError() if not kwargs.get("allow_protected_fields"): self._check_protected_fields(to_save) else: del kwargs["allow_protected_fields"] if "safe" in kwargs: kwargs["w"] = 0 if not kwargs["safe"] else 1 del kwargs["safe"] if self.has_trigger("before_save") and "_id" in to_save: self.trigger("before_save", replacements=[to_save], ids=[to_save["_id"]]) _id = self.collection.save(to_save, **kwargs) self.trigger("after_save", replacements=[to_save], ids=[_id]) return _id
[docs] def update(self, spec, document, **kwargs): if self.immutable: raise ImmutableDocumentError() if "$set" in document: if not kwargs.get("allow_protected_fields"): self._check_protected_fields(document["$set"]) else: del kwargs["allow_protected_fields"] before_ids = None if self.has_trigger("before_save") or self.has_trigger("after_save"): before_ids = self.list_column(spec, read_use="primary") if before_ids: self.trigger("before_save", ids=before_ids, update=document) ret = self.collection.update(spec, document, **kwargs) self.trigger("after_save", ids=before_ids, update=document) return ret
[docs] def remove(self, spec_or_id=None, **kwargs): docs = [] if self.has_trigger("before_delete") or self.has_trigger("after_delete"): limit = 0 if spec_or_id is None: filter = {} elif not isinstance(spec_or_id, collections.Mapping): filter = {"_id": spec_or_id} else: filter = spec_or_id limit = 1 if kwargs.get("multi") is False else 0 docs = list(self.find(filter, read_use="primary", limit=limit)) self.trigger("before_delete", documents=docs) ret = self.collection.remove(spec_or_id=spec_or_id, **kwargs) if len(docs) > 0: self.trigger("after_delete", documents=docs) return ret
[docs] def find_and_modify(self, query={}, update=None, **kwargs): if self.immutable: raise ImmutableDocumentError() if "$set" in update: if not kwargs.get("allow_protected_fields"): self._check_protected_fields(update["$set"]) else: del kwargs["allow_protected_fields"] ret = self.collection.find_and_modify(query=query, update=update, **kwargs) if ret is None: return None self.trigger("after_save", ids=[ret["_id"]], update=update) return self(ret, fetched_fields=kwargs.get("projection"))
[docs] def get_from_id(self, _id): return self.find_one({"_id": _id})
[docs] def fetch(self, spec=None, *args, **kwargs): """ return all document which match the structure of the object `fetch()` takes the same arguments than the the pymongo.collection.find method. The query is launch against the db and collection of the object. """ if spec is None: spec = {} for key in self.structure: if key in spec: if isinstance(spec[key], dict): spec[key].update({'$exists': True}) else: spec[key] = {'$exists': True} return self.find(spec, *args, **kwargs)
[docs] def fetch_one(self, *args, **kwargs): """ return one document which match the structure of the object `fetch_one()` takes the same arguments than the the pymongo.collection.find method. If multiple documents are found, raise a MultipleResultsFound exception. If no document is found, return None The query is launch against the db and collection of the object. """ bson_obj = self.fetch(*args, **kwargs) count = bson_obj.count() if count > 1: raise MultipleResultsFound("%s results found" % count) elif count == 1: # return self(bson_obj.next(), fetched_fields=kwargs.get("projection")) return next(bson_obj)
def _check_protected_fields(self, data): if len(self.protected_fields): forbidden_fields = set(data.keys()) & set(self.protected_fields) if len(forbidden_fields) > 0: raise ProtectedFieldsError("cannot set those keys without allow_protected_fields : %s" % forbidden_fields)