diff --git a/src/grung/db.py b/src/grung/db.py index ddcb6b3..a8d2da0 100644 --- a/src/grung/db.py +++ b/src/grung/db.py @@ -1,17 +1,15 @@ import inspect -import re from collections.abc import Iterable -from functools import reduce -from operator import ior from pathlib import Path from typing import List -from tinydb import Query, TinyDB, table +from tinydb import TinyDB, table from tinydb.storages import MemoryStorage from tinydb.table import Document -from grung.exceptions import CircularReferenceError, UniqueConstraintError -from grung.types import Record +from grung.exceptions import CircularReferenceError +from grung.objects import Record +from grung.validators import TypeValidator class RecordTable(table.Table): @@ -26,6 +24,12 @@ class RecordTable(table.Table): def insert(self, document): document.before_insert(self.db) + + # check field types before attempting serialization + validator = TypeValidator() + for field in document._metadata.fields.values(): + validator.validate(document, field, self.db) + doc = document.serialize() self._check_constraints(doc) @@ -68,29 +72,17 @@ class RecordTable(table.Table): def _check_constraints(self, document) -> bool: self._check_for_recursion(document) - self._check_unique(document) + for field in document._metadata.fields.values(): + field.validate(document, self.db) def _check_for_recursion(self, document) -> bool: ref = document.reference for field in document._metadata.fields.values(): if isinstance(field.default, Iterable) and ref in document[field.name]: - raise CircularReferenceError(ref, field) + raise CircularReferenceError(document, field, ref, "builtin") elif document[field.name] == ref: - raise CircularReferenceError(ref, field) - - def _check_unique(self, document) -> bool: - matches = [] - queries = reduce( - ior, - [ - Query()[field.name].matches(f"^{document[field.name]}$", flags=re.IGNORECASE) - for field in document._metadata.fields.values() - if field.unique - ], - ) - matches = [dict(match) for 
match in super().search(queries) if match.doc_id != document.doc_id] - if matches != []: - raise UniqueConstraintError(document, queries, matches) + raise CircularReferenceError(document, field, ref, "builtin") + return True class GrungDB(TinyDB): diff --git a/src/grung/examples.py b/src/grung/examples.py index 7753a78..bc8e9bb 100644 --- a/src/grung/examples.py +++ b/src/grung/examples.py @@ -1,18 +1,21 @@ -from grung.types import ( +import re + +from grung.objects import ( BackReference, BinaryFilePointer, Collection, DateTime, Dict, - Field, Integer, List, Password, Record, RecordDict, + String, TextFilePointer, Timestamp, ) +from grung.validators import LengthValidator, MinMaxValidator, PatternValidator class User(Record): @@ -20,13 +23,13 @@ class User(Record): def fields(cls): return [ *super().fields(), - Field("name", primary_key=True), - Integer("number", default=0), - Field("email", unique=True), + String("name", primary_key=True, validators=[LengthValidator(min=3, max=30)]), + Integer("number", default=0, validators=[MinMaxValidator(min=0, max=255)]), + String("email", unique=True, validators=[PatternValidator(re.compile(r"[^@]+@[\w\-\.]+$"))]), Password("password"), DateTime("created"), Timestamp("last_updated"), - BackReference("groups", Group), + BackReference("groups", value_type=Group), ] @@ -35,10 +38,10 @@ class Group(Record): def fields(cls): return [ *super().fields(), - Field("name", primary_key=True), - Collection("members", User), - Collection("groups", Group), - BackReference("parent", Group), + String("name", primary_key=True), + Collection("members", member_type=User), + Collection("groups", member_type=Group), + BackReference("parent", value_type=Group), ] @@ -47,10 +50,10 @@ class Album(Record): def fields(cls): inherited = [f for f in super().fields() if f.name != "name"] return inherited + [ - Field("name"), + String("name"), Dict("credits"), List("tracks"), - BackReference("artist", Artist), + BackReference("artist", 
value_type=Artist), BinaryFilePointer("cover", extension=".jpg"), TextFilePointer("review"), ] @@ -60,4 +63,4 @@ class Artist(User): @classmethod def fields(cls): inherited = [f for f in super().fields() if f.name != "name"] - return inherited + [Field("name"), RecordDict("albums", Album)] + return inherited + [String("name"), RecordDict("albums", member_type=Album)] diff --git a/src/grung/exceptions.py b/src/grung/exceptions.py index 5b474d8..15f27e1 100644 --- a/src/grung/exceptions.py +++ b/src/grung/exceptions.py @@ -1,42 +1,92 @@ -class UniqueConstraintError(Exception): +class ValidationError(Exception): + """ + Thrown when a record's field does not meet validation criteria. + """ + + messages = [] + template = ( + "\n" + " * Record: {_record}\n" + " * Field: {_field}\n" + " * Value: {_value}\n" + " * Validator: {_validator}\n" + "\n" + "{_messages}" + ) + + def __init__(self, record, field, validator, messages=[], **kwargs): + super().__init__( + self.template.format( + _record=dict(record), + _field=field, + _value=record[field.name], + _validator=validator, + _messages="\n".join(messages or self.__class__.messages), + **kwargs, + ) + ) + + +class InvalidFieldTypeError(ValidationError): + """ + Thrown when a document's field value does not match the field value_type. + """ + + messages = ["The value of the field is not an instance of the field's value_type."] + + +class UniqueConstraintError(ValidationError): """ Thrown when a db write operation cannot complete due to a field's unique constraint. 
""" - def __init__(self, document, query, collisions): + def __init__(self, record, field, validator, query, matches): super().__init__( - "\n" - f" * Record: {dict(document)}\n" - f" * Query: {query}\n" - f" * Error: Unique constraint failure\n" - " * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions) + record, + field, + validator, + messages=[ + f"Query: {query}", + "The record matches the following existing records:\n\n" + "\n".join(str(m) for m in matches), + ], ) -class PointerReferenceError(Exception): - """ - Thrown when a document field containing a document could not be resolve to an existing record in the database. - """ - - def __init__(self, reference): - super().__init__( - "\n" - f" * Reference: {reference}\n" - f" * Error: Invalid Pointer\n" - " * This collection member does not refer an existing record. Do you need to save it first?" - ) - - -class CircularReferenceError(Exception): +class CircularReferenceError(ValidationError): """ Thrown when a record contains a reference to itself. """ - def __init__(self, reference, field): - super().__init__( - "\n" - f" * Reference: {reference}\n" - f" * Field: {field.name}\n" - f" * Error: Circular Reference\n" - f" * This record contains a reference to itself. This will lead to infinite recursion." - ) + messages = ["This record contains a reference to itself. This will lead to infinite recursion."] + + +class MalformedPointerError(ValidationError): + """ + Thrown when a Pointer's value is not a valid reference string. + """ + + messages = ["A Pointer's value must follow the format 'TABLE_NAME::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE'."] + + +class PointerReferenceError(Exception): + """ + Thrown when a record field containing a record could not be resolve to an existing record in the database. + """ + + +class InvalidLengthError(ValidationError): + """ + Thrown when a field does not meet its length constraint. 
+ """ + + +class InvalidSizeError(ValidationError): + """ + Thrown when a field's size is too large or too small. + """ + + +class PatternMatchError(ValidationError): + """ + Thrown when a field does not match the specified pattern. + """ diff --git a/src/grung/objects.py b/src/grung/objects.py new file mode 100644 index 0000000..12cbb98 --- /dev/null +++ b/src/grung/objects.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +import hashlib +import hmac +import os +import re +import typing +from collections import namedtuple +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +import nanoid +from tinydb import TinyDB, where + +import grung.types +from grung.exceptions import PointerReferenceError +from grung.validators import PointerReferenceValidator, UniqueValidator + +Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"]) + + +@dataclass +class Field(grung.types.Field): + """ + Represents a single field in a Record. + """ + + name: str + default: str = None + unique: bool = False + primary_key: bool = False + validators: list = field(default_factory=lambda: []) + + value_type = str + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + pass + + def after_insert(self, db: TinyDB, record: Record) -> None: + pass + + def serialize(self, value: value_type, record: Record | None = None) -> str: + if value is not None: + return str(value) + + def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: + return value + + def validate(self, record: Record, db: TinyDB): + if self.unique: + UniqueValidator().validate(record, self, db) + for validator in self.validators: + validator.validate(record, self, db) + + +class Record(grung.types.Record): + """ + Base type for a single database record. 
+ """ + + def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): + self.doc_id = doc_id + fields = self.__class__.fields() + + pkey = [field for field in fields if field.primary_key] + if len(pkey) > 1: + raise Exception(f"Cannnot have more than one primary key: {pkey}") + elif pkey: + pkey = pkey[0] + else: + # 1% collision rate at ~2M records + pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True) + fields.append(pkey) + + pkey.unique = True + + self._metadata = Metadata( + table=self.__class__.__name__, + primary_key=pkey.name, + fields={f.name: f for f in fields}, + backrefs=lambda value_type: ( + field for field in fields if type(field) == BackReference and field.value_type == value_type + ), + ) + super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) + + @classmethod + def fields(cls): + return [] + + def serialize(self): + """ + Serialize every field on the record + """ + rec = {} + for name, _field in self._metadata.fields.items(): + rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field + return self.__class__(rec, doc_id=self.doc_id) + + def deserialize(self, db, recurse: bool = True): + """ + Deserialize every field on the record + """ + rec = {} + for name, _field in self._metadata.fields.items(): + rec[name] = _field.deserialize(self[name], db, recurse=recurse) + return self.__class__(rec, doc_id=self.doc_id) + + def before_insert(self, db: TinyDB) -> None: + for name, _field in self._metadata.fields.items(): + _field.before_insert(self[name], db, self) + + def after_insert(self, db: TinyDB) -> None: + for name, _field in self._metadata.fields.items(): + _field.after_insert(db, self) + + def update(self, **data): + for key, value in data.items(): + self[key] = value + + @property + def reference(self): + return Pointer.reference(self) + + @property + def path(self): + return Path(self._metadata.table) / self[self._metadata.primary_key] + + def 
__setattr__(self, key, value): + if key in self: + self[key] = value + super().__setattr__(key, value) + + def __getattr__(self, attr_name): + if attr_name in self: + return self.get(attr_name) + raise AttributeError(f"No such attribute: {attr_name}") + + def __hash__(self): + return hash(str(dict(self))) + + def __repr__(self): + return ( + f"{self.__class__.__name__}[{self.doc_id}](" + + ", ".join([f"{key}={val}" for (key, val) in self.items()]) + + ")" + ) + + +@dataclass +class String(Field): + pass + + +@dataclass +class Integer(Field): + value_type = int + default: int = 0 + + def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: + return int(value) + + +@dataclass +class Dict(Field): + default: dict = field(default_factory=lambda: {}) + + value_type = dict + + def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]: + return dict((key, str(value)) for key, value in values.items()) + + def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]: + return values + + +@dataclass +class List(Field): + default: list = field(default_factory=lambda: []) + + value_type = list + + def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]: + return values + + def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]: + return values + + +@dataclass +class DateTime(Field): + default: datetime = datetime.utcfromtimestamp(0) + + value_type = datetime + + def serialize(self, value: value_type, record: Record | None = None) -> str: + return (value - datetime.utcfromtimestamp(0)).total_seconds() + + def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: + return datetime.utcfromtimestamp(int(value)) + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + if not value: + record[self.name] = datetime.utcnow().replace(microsecond=0) + + +@dataclass +class 
Timestamp(DateTime): + value_type = datetime + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + super().before_insert(None, db, record) + + +@dataclass +class Password(Field): + value_type = str + default: str = None + + # Relatively weak. Consider using stronger initial values in production applications. + salt_size = 4 + digest_size = 16 + + @classmethod + def is_digest(cls, passwd: str): + if not passwd: + return False + offset = 2 * cls.salt_size # each byte is 2 hex chars + try: + if passwd[offset] != ":": + return False + digest = passwd[(offset + 1) :] # noqa + if len(digest) != cls.digest_size * 2: + return False + return re.match(r"^[0-9a-f]+$", digest) + except IndexError: + return False + + @classmethod + def get_digest(cls, passwd: str, salt: bytes = None): + if not salt: + salt = os.urandom(cls.salt_size) + digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest() + return digest, salt.hex() + + @classmethod + def compare(cls, passwd: value_type, stored: value_type): + stored_salt, stored_digest = stored.split(":") + input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt)) + return hmac.compare_digest(input_digest, stored_digest) + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + if value and not self.__class__.is_digest(value): + digest, salt = self.__class__.get_digest(value) + record[self.name] = f"{salt}:{digest}" + + +@dataclass +class Pointer(Field): + """ + Store a string reference to a record. 
+ """ + + name: str = "" + value_type: grung.types.Record = Record + + def serialize(self, value: value_type | str, record: Record | None = None) -> str: + return Pointer.reference(value) + + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + return Pointer.dereference(value, db, recurse) + + @classmethod + def reference(cls, value: Record | str): + if isinstance(value, str): + PointerReferenceValidator().validate_string(value) + return value + + if value: + return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}" + + return None + + @classmethod + def dereference(cls, value: str, db: TinyDB, recurse: bool = True): + if not value: + return + elif type(value) == str: + table_name, pkey, pval = value.split("::") + if pval: + table = db.table(table_name) + rec = table.get(where(pkey) == pval, recurse=recurse) + if not rec: + raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!") + return rec + return value + + +@dataclass +class BackReference(Pointer): + pass + + +@dataclass +class BinaryFilePointer(Field): + """ + Write the contents of this field to disk and store the path in the db. 
+ """ + + name: str + extension: str = ".blob" + + value_type = bytes + + def relpath(self, record): + return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}" + + def reference(self, record): + return f"/::{self.relpath(record)}" + + def dereference(self, reference, db): + relpath = reference.replace("/::", "", 1) + try: + return (db.path / relpath).read_bytes() + except FileNotFoundError: + return None + + def serialize(self, value: value_type | str, record: Record | None = None) -> str: + return self.reference(record) + + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + if not value: + return None + return self.dereference(value, db) + + def prepare(self, data: value_type): + """ + Return bytes to be written to disk + """ + if not data: + return + if not isinstance(data, self.value_type): + return data.encode() + return data + + def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: + if not value: + return + relpath = self.relpath(record) + path = db.path / relpath + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(self.prepare(value)) + + +@dataclass +class TextFilePointer(BinaryFilePointer): + """ + Write the contents of this field to disk and store the path in the db. + """ + + name: str + extension: str = ".txt" + + value_type = str + + def prepare(self, data: value_type): + if isinstance(data, bytes): + return data + return str(data).encode() + + def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: + if not value: + return None + buf = super().deserialize(value, db) + return buf.decode() if buf else None + + +@dataclass +class Collection(Field): + """ + A collection of pointers. 
+ """ + + default: typing.List[value_type] = field(default_factory=lambda: []) + member_type: type = Record + + value_type = list + + def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]: + return [Pointer.reference(val) for val in values] + + def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]: + """ + Recursively deserialize the objects in this collection + """ + recs = [] + if not recurse: + return values + for val in values: + recs.append(Pointer.dereference(val, db=db, recurse=False)) + return recs + + def after_insert(self, db: TinyDB, record: Record) -> None: + """ + Populate any backreferences in the members of this collection with the parent record's uid. + """ + if not record[self.name]: + return + + for member in record[self.name]: + target = Pointer.dereference(member, db=db, recurse=False) + for backref in target._metadata.backrefs(type(record)): + target[backref.name] = record + db.save(target) + + +@dataclass +class RecordDict(Field): + default: typing.Dict[(str, Record)] = field(default_factory=lambda: {}) + member_type: type = Record + + value_type = dict + + def serialize( + self, values: typing.Dict[(str, value_type)], record: Record | None = None + ) -> typing.Dict[(str, str)]: + return dict((key, Pointer.reference(val)) for (key, val) in values.items()) + + def deserialize( + self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False + ) -> typing.Dict[(str, value_type)]: + if not recurse: + return values + return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items()) + + def after_insert(self, db: TinyDB, record: Record) -> None: + """ + Populate any backreferences in the members of this mapping with the parent record's uid. 
+ """ + if not record[self.name]: + return + + for key, pointer in record[self.name].items(): + target = Pointer.dereference(pointer, db=db, recurse=False) + for backref in target._metadata.backrefs(type(record)): + target[backref.name] = record + db.save(target) diff --git a/src/grung/types.py b/src/grung/types.py index a6614a7..4bf4a22 100644 --- a/src/grung/types.py +++ b/src/grung/types.py @@ -1,420 +1,9 @@ -from __future__ import annotations - -import hashlib -import hmac -import os -import re import typing -from collections import namedtuple -from dataclasses import dataclass, field -from datetime import datetime -from pathlib import Path - -import nanoid -from tinydb import TinyDB, where - -from grung.exceptions import PointerReferenceError - -Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"]) -@dataclass class Field: - """ - Represents a single field in a Record. - """ - - name: str - value_type: type = str - default: str = None - unique: bool = False - primary_key: bool = False - - def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: - pass - - def after_insert(self, db: TinyDB, record: Record) -> None: - pass - - def serialize(self, value: value_type, record: Record | None = None) -> str: - if value is not None: - return str(value) - - def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: - return value - - -@dataclass -class Integer(Field): - value_type = int - default: int = 0 - - def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: - return int(value) - - -@dataclass -class Dict(Field): - value_type: type = str - default: dict = field(default_factory=lambda: {}) - - def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]: - return dict((key, str(value)) for key, value in values.items()) - - def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]: - return values - - 
-@dataclass -class List(Field): - value_type: type = list - default: list = field(default_factory=lambda: []) - - def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]: - return values - - def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]: - return values - - -@dataclass -class DateTime(Field): - value_type: datetime - default: datetime = datetime.utcfromtimestamp(0) - - def serialize(self, value: value_type, record: Record | None = None) -> str: - return (value - datetime.utcfromtimestamp(0)).total_seconds() - - def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type: - return datetime.utcfromtimestamp(int(value)) - - def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: - if not value: - record[self.name] = datetime.utcnow().replace(microsecond=0) - - -@dataclass -class Timestamp(DateTime): - value_type: datetime - - def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: - super().before_insert(None, db, record) - - -@dataclass -class Password(Field): - value_type = str - default: str = None - - # Relatively weak. Consider using stronger initial values in production applications. 
- salt_size = 4 - digest_size = 16 - - @classmethod - def is_digest(cls, passwd: str): - if not passwd: - return False - offset = 2 * cls.salt_size # each byte is 2 hex chars - try: - if passwd[offset] != ":": - return False - digest = passwd[(offset + 1) :] # noqa - if len(digest) != cls.digest_size * 2: - return False - return re.match(r"^[0-9a-f]+$", digest) - except IndexError: - return False - - @classmethod - def get_digest(cls, passwd: str, salt: bytes = None): - if not salt: - salt = os.urandom(cls.salt_size) - digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest() - return digest, salt.hex() - - @classmethod - def compare(cls, passwd: value_type, stored: value_type): - stored_salt, stored_digest = stored.split(":") - input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt)) - return hmac.compare_digest(input_digest, stored_digest) - - def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: - if value and not self.__class__.is_digest(value): - digest, salt = self.__class__.get_digest(value) - record[self.name] = f"{salt}:{digest}" - - -class Record(typing.Dict[(str, Field)]): - """ - Base type for a single database record. 
- """ - - def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params): - self.doc_id = doc_id - fields = self.__class__.fields() - - pkey = [field for field in fields if field.primary_key] - if len(pkey) > 1: - raise Exception(f"Cannnot have more than one primary key: {pkey}") - elif pkey: - pkey = pkey[0] - else: - # 1% collision rate at ~2M records - pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True) - fields.append(pkey) - - pkey.unique = True - - self._metadata = Metadata( - table=self.__class__.__name__, - primary_key=pkey.name, - fields={f.name: f for f in fields}, - backrefs=lambda value_type: ( - field for field in fields if type(field) == BackReference and field.value_type == value_type - ), - ) - super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params)) - - @classmethod - def fields(cls): - return [] - - def serialize(self): - """ - Serialize every field on the record - """ - rec = {} - for name, _field in self._metadata.fields.items(): - rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field - return self.__class__(rec, doc_id=self.doc_id) - - def deserialize(self, db, recurse: bool = True): - """ - Deserialize every field on the record - """ - rec = {} - for name, _field in self._metadata.fields.items(): - rec[name] = _field.deserialize(self[name], db, recurse=recurse) - return self.__class__(rec, doc_id=self.doc_id) - - def before_insert(self, db: TinyDB) -> None: - for name, _field in self._metadata.fields.items(): - _field.before_insert(self[name], db, self) - - def after_insert(self, db: TinyDB) -> None: - for name, _field in self._metadata.fields.items(): - _field.after_insert(db, self) - - @property - def reference(self): - return Pointer.reference(self) - - @property - def path(self): - return Path(self._metadata.table) / self[self._metadata.primary_key] - - def __setattr__(self, key, value): - if key in self: - self[key] = value - 
super().__setattr__(key, value) - - def __getattr__(self, attr_name): - if attr_name in self: - return self.get(attr_name) - raise AttributeError(f"No such attribute: {attr_name}") - - def __hash__(self): - return hash(str(dict(self))) - - def __repr__(self): - return ( - f"{self.__class__.__name__}[{self.doc_id}](" - + ", ".join([f"{key}={val}" for (key, val) in self.items()]) - + ")" - ) - - -@dataclass -class Pointer(Field): - """ - Store a string reference to a record. - """ - - name: str = "" - value_type: type = Record - - def serialize(self, value: value_type | str, record: Record | None = None) -> str: - return Pointer.reference(value) - - def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: - return Pointer.dereference(value, db, recurse) - - @classmethod - def reference(cls, value: Record | str): - # XXX This could be smarter - if isinstance(value, str): - if "::" not in value: - raise PointerReferenceError("Value {value} does not look like a reference!") - return value - if value: - return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}" - return None - - @classmethod - def dereference(cls, value: str, db: TinyDB, recurse: bool = True): - if not value: - return - elif type(value) == str: - table_name, pkey, pval = value.split("::") - if pval: - table = db.table(table_name) - rec = table.get(where(pkey) == pval, recurse=recurse) - if not rec: - raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!") - return rec - return value - - -@dataclass -class BackReference(Pointer): pass -@dataclass -class BinaryFilePointer(Field): - """ - Write the contents of this field to disk and store the path in the db. 
- """ - - name: str - value_type: type = bytes - extension: str = ".blob" - - def relpath(self, record): - return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}" - - def reference(self, record): - return f"/::{self.relpath(record)}" - - def dereference(self, reference, db): - relpath = reference.replace("/::", "", 1) - try: - return (db.path / relpath).read_bytes() - except FileNotFoundError: - return None - - def serialize(self, value: value_type | str, record: Record | None = None) -> str: - return self.reference(record) - - def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: - if not value: - return None - return self.dereference(value, db) - - def prepare(self, data: value_type): - """ - Return bytes to be written to disk - """ - if not data: - return - if not isinstance(data, self.value_type): - return data.encode() - return data - - def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None: - if not value: - return - relpath = self.relpath(record) - path = db.path / relpath - path.parent.mkdir(parents=True, exist_ok=True) - path.write_bytes(self.prepare(value)) - - -@dataclass -class TextFilePointer(BinaryFilePointer): - """ - Write the contents of this field to disk and store the path in the db. - """ - - name: str - value_type: type = str - extension: str = ".txt" - - def prepare(self, data: value_type): - if isinstance(data, bytes): - return data - return str(data).encode() - - def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type: - if not value: - return None - buf = super().deserialize(value, db) - return buf.decode() if buf else None - - -@dataclass -class Collection(Field): - """ - A collection of pointers. 
- """ - - value_type: type = Record - default: typing.List[value_type] = field(default_factory=lambda: []) - - def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]: - return [Pointer.reference(val) for val in values] - - def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]: - """ - Recursively deserialize the objects in this collection - """ - recs = [] - if not recurse: - return values - for val in values: - recs.append(Pointer.dereference(val, db=db, recurse=False)) - return recs - - def after_insert(self, db: TinyDB, record: Record) -> None: - """ - Populate any backreferences in the members of this collection with the parent record's uid. - """ - if not record[self.name]: - return - - for member in record[self.name]: - target = Pointer.dereference(member, db=db, recurse=False) - for backref in target._metadata.backrefs(type(record)): - target[backref.name] = record - db.save(target) - - -@dataclass -class RecordDict(Field): - value_type: type = Record - default: typing.Dict[(str, Record)] = field(default_factory=lambda: {}) - - def serialize( - self, values: typing.Dict[(str, value_type)], record: Record | None = None - ) -> typing.Dict[(str, str)]: - return dict((key, Pointer.reference(val)) for (key, val) in values.items()) - - def deserialize( - self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False - ) -> typing.Dict[(str, value_type)]: - if not recurse: - return values - return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items()) - - def after_insert(self, db: TinyDB, record: Record) -> None: - """ - Populate any backreferences in the members of this mapping with the parent record's uid. 
- """ - if not record[self.name]: - return - - for key, pointer in record[self.name].items(): - target = Pointer.dereference(pointer, db=db, recurse=False) - for backref in target._metadata.backrefs(type(record)): - target[backref.name] = record - db.save(target) +class Record(typing.Dict[(str, Field)]): + pass diff --git a/src/grung/validators.py b/src/grung/validators.py new file mode 100644 index 0000000..d845552 --- /dev/null +++ b/src/grung/validators.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass + +from tinydb import Query, TinyDB + +from grung.exceptions import ( + InvalidFieldTypeError, + InvalidLengthError, + InvalidSizeError, + MalformedPointerError, + PatternMatchError, + UniqueConstraintError, + ValidationError, +) +from grung.types import Field, Record + + +@dataclass +class Validator: + def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool: + raise ValidationError(record, field, self) + + +@dataclass +class TypeValidator: + def validate_list(self, record: Record, field: Field, db: TinyDB = None) -> bool: + messages = [] + for i in range(len(record[field.name])): + member = record[field.name][i] + if isinstance(member, str): + class_name = field.member_type.__name__ + try: + return PointerReferenceValidator().validate_string(member, class_name) + except MalformedPointerError as e: + messages.append(str(e)) + elif not isinstance(member, field.member_type): + messages.append(f"{field.name}[{i}] must be a {field.member_type}, not a {type(member)}.") + if messages: + raise InvalidFieldTypeError(record, field, self, messages=messages) + + def validate_dict(self, record: Record, field: Field, db: TinyDB = None) -> bool: + for key, member in record[field.name].items(): + if not isinstance(member, field.member_type): + raise InvalidFieldTypeError( + record, + field, + self, + messages=[f"{field.name}[{key} must be a {field.member_type}, not a {type(member)}."], + ) + return True + 
+    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
+        if record[field.name] is None:
+            return True
+
+        if not isinstance(record[field.name], field.value_type):
+            raise InvalidFieldTypeError(
+                record,
+                field,
+                self,
+                messages=[f"{field.name} must be a {field.value_type}, not a {type(record[field.name])}."],
+            )
+        if not hasattr(field, "member_type"):
+            return True
+
+        if field.value_type == dict:
+            self.validate_dict(record, field, db)
+        elif field.value_type == list:
+            self.validate_list(record, field, db)
+        else:
+            raise RuntimeError("Expected a validation for iterable but didn't get one!")
+        return True
+
+
+@dataclass
+class PointerReferenceValidator(Validator):
+    """
+    Verify that the Pointer is either a string reference to the correct member type,
+    or a record instance of member_type that hasn't been serialized.
+    """
+
+    def validate_string(self, value: str, type_name: str = "") -> bool:
+        (table, primary_key, val) = value.split("::")
+        if type_name and table != type_name:
+            raise MalformedPointerError(
+                {"string": value},
+                "",
+                self,
+                messages=[f"field should reference '{type_name}', not '{table}'."],
+            )
+        if not primary_key:
+            raise MalformedPointerError(
+                {"string": value},
+                "",
+                self,
+                messages=["Pointers must specify the primary_key field name."],
+            )
+
+    def validate(self, record: Record | str, field: Field, db: TinyDB = None) -> bool:
+        if isinstance(record, str):
+            try:
+                self.validate_string(record, field.value_type.__name__)
+            except ValueError:
+                raise MalformedPointerError({field.name: record}, field, self)
+            return True
+
+        if record[field.name] is None:
+            return True
+
+        if not isinstance(record, Record):
+            raise MalformedPointerError(record, field, self)
+
+        if not isinstance(record[field.name], field.value_type):
+            raise MalformedPointerError(record, field, self)
+
+        return True
+
+
+@dataclass
+class UniqueValidator(Validator):
+    def validate(self, record: Record, field: Field, db: TinyDB) -> bool:
+        
""" + Returns true if the field's value is unique across all records in the table. + """ + if record[field.name] is None: + return True + + query = Query()[field.name].matches(f"^{record[field.name]}$", flags=re.IGNORECASE) + table = db.table(record._metadata.table) + matches = [dict(match) for match in table.search(query) if match.doc_id != record.doc_id] + if matches != []: + raise UniqueConstraintError(record, field, self, query=query, matches=matches) + return True + + +@dataclass +class LengthValidator(Validator): + min: int = 0 + max: int = 0 + + def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool: + """ + Returns True if the length of the field's value is between min and max, inclusive. + """ + if record[field.name] is None: + return True + + length = len(record[field.name]) + if length < self.min or length > self.max: + raise InvalidLengthError( + record, + field, + self, + messages=[f"The field length must be between {self.min} and {self.max}, inclusive."], + ) + return True + + +@dataclass +class MinMaxValidator(Validator): + min: int = 0 + max: int = 0 + + def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool: + """ + Returns True if the size of the field's integer value is between min and max, inclusive. 
+ """ + if record[field.name] is None: + return True + + size = int(record[field.name]) + if size < self.min or size > self.max: + raise InvalidSizeError( + record, + field, + self, + messages=[f"The field size must be between {self.min} and {self.max}, inclusive."], + ) + return True + + +@dataclass +class PatternValidator(Validator): + pattern: re.Pattern + + def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool: + if record[field.name] is None: + return True + + if not self.pattern.match(record[field.name]): + raise PatternMatchError( + record, + field, + self, + messages=[f"The field value must match the pattern {self.pattern}"], + ) + return True diff --git a/test/test_db.py b/test/test_db.py index edff94b..03310f9 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -10,7 +10,14 @@ from tinydb.storages import MemoryStorage from grung import examples from grung.db import GrungDB -from grung.exceptions import CircularReferenceError, UniqueConstraintError +from grung.exceptions import ( + CircularReferenceError, + InvalidFieldTypeError, + InvalidLengthError, + InvalidSizeError, + PatternMatchError, + UniqueConstraintError, +) @pytest.fixture @@ -82,7 +89,7 @@ def test_subgroups(db): # recursion! 
with pytest.raises(CircularReferenceError): - tos.members = [tos] + tos.groups = [tos] db.save(tos) @@ -201,3 +208,29 @@ def test_file_pointers(db): location_on_disk = db.path / album._metadata.fields["review"].relpath(album) assert location_on_disk.read_text() == album.review + + +@pytest.mark.parametrize( + "updates, expected", + [ + ({"name": ""}, InvalidLengthError), + ({"name": "a name longer than 30 characters is what we have here"}, InvalidLengthError), + ({"name": 23}, InvalidFieldTypeError), + ({"number": -1}, InvalidSizeError), + ({"number": 256}, InvalidSizeError), + ({"email": "foo+alias@"}, PatternMatchError), + ], + ids=[ + "name too short", + "name too long", + "name is not a string", + "number too small", + "number too big", + "invalid email addres", + ], +) +def test_validators(updates, expected, db): + user = db.save(examples.User(name="john", email="john@foo", password="fnord", created=datetime.utcnow())) + with pytest.raises(expected): + user.update(**updates) + db.save(user)