Add validators

parent b7d7ef9638
commit 64aef7c18b
src/grung/db.py

@@ -1,17 +1,15 @@
 import inspect
-import re
 from collections.abc import Iterable
-from functools import reduce
-from operator import ior
 from pathlib import Path
 from typing import List

-from tinydb import Query, TinyDB, table
+from tinydb import TinyDB, table
 from tinydb.storages import MemoryStorage
 from tinydb.table import Document

-from grung.exceptions import CircularReferenceError, UniqueConstraintError
-from grung.types import Record
+from grung.exceptions import CircularReferenceError
+from grung.objects import Record
+from grung.validators import TypeValidator


 class RecordTable(table.Table):

@@ -26,6 +24,12 @@ class RecordTable(table.Table):
     def insert(self, document):
         document.before_insert(self.db)

+        # check field types before attempting serialization
+        validator = TypeValidator()
+        for field in document._metadata.fields.values():
+            validator.validate(document, field, self.db)
+
         doc = document.serialize()
         self._check_constraints(doc)

@@ -68,29 +72,17 @@ class RecordTable(table.Table):

     def _check_constraints(self, document) -> bool:
         self._check_for_recursion(document)
-        self._check_unique(document)
+        for field in document._metadata.fields.values():
+            field.validate(document, self.db)

     def _check_for_recursion(self, document) -> bool:
         ref = document.reference
         for field in document._metadata.fields.values():
             if isinstance(field.default, Iterable) and ref in document[field.name]:
-                raise CircularReferenceError(ref, field)
+                raise CircularReferenceError(document, field, ref, "builtin")
             elif document[field.name] == ref:
-                raise CircularReferenceError(ref, field)
+                raise CircularReferenceError(document, field, ref, "builtin")
+        return True

-    def _check_unique(self, document) -> bool:
-        matches = []
-        queries = reduce(
-            ior,
-            [
-                Query()[field.name].matches(f"^{document[field.name]}$", flags=re.IGNORECASE)
-                for field in document._metadata.fields.values()
-                if field.unique
-            ],
-        )
-        matches = [dict(match) for match in super().search(queries) if match.doc_id != document.doc_id]
-        if matches != []:
-            raise UniqueConstraintError(document, queries, matches)


 class GrungDB(TinyDB):
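For illustration (not part of the commit): with this hunk, `insert` runs a `TypeValidator` over every field before serialization, and `_check_constraints` delegates per-field checks, including uniqueness, to each field's own `validate`. A minimal sketch of the resulting behavior, assuming `GrungDB` forwards TinyDB's constructor arguments (its `__init__` is elided from this diff) and that `save` routes through `RecordTable.insert`, as the tests below suggest:

# Sketch only; the storage argument is an assumption borrowed from TinyDB.
from tinydb.storages import MemoryStorage

from grung.db import GrungDB
from grung.examples import User
from grung.exceptions import InvalidFieldTypeError

db = GrungDB(storage=MemoryStorage)
try:
    # TypeValidator rejects the mistyped value before serialize() can
    # stringify it into the database.
    db.save(User(name="ada", email="ada@foo.dev", number="not-an-int"))
except InvalidFieldTypeError as err:
    print(err)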
src/grung/examples.py

@@ -1,18 +1,21 @@
-from grung.types import (
+import re
+
+from grung.objects import (
     BackReference,
     BinaryFilePointer,
     Collection,
     DateTime,
     Dict,
-    Field,
     Integer,
     List,
     Password,
     Record,
     RecordDict,
+    String,
     TextFilePointer,
     Timestamp,
 )
+from grung.validators import LengthValidator, MinMaxValidator, PatternValidator


 class User(Record):

@@ -20,13 +23,13 @@ class User(Record):
     def fields(cls):
         return [
             *super().fields(),
-            Field("name", primary_key=True),
-            Integer("number", default=0),
-            Field("email", unique=True),
+            String("name", primary_key=True, validators=[LengthValidator(min=3, max=30)]),
+            Integer("number", default=0, validators=[MinMaxValidator(min=0, max=255)]),
+            String("email", unique=True, validators=[PatternValidator(re.compile(r"[^@]+@[\w\-\.]+$"))]),
             Password("password"),
             DateTime("created"),
             Timestamp("last_updated"),
-            BackReference("groups", Group),
+            BackReference("groups", value_type=Group),
         ]

@@ -35,10 +38,10 @@ class Group(Record):
     def fields(cls):
         return [
             *super().fields(),
-            Field("name", primary_key=True),
-            Collection("members", User),
-            Collection("groups", Group),
-            BackReference("parent", Group),
+            String("name", primary_key=True),
+            Collection("members", member_type=User),
+            Collection("groups", member_type=Group),
+            BackReference("parent", value_type=Group),
         ]

@@ -47,10 +50,10 @@ class Album(Record):
     def fields(cls):
         inherited = [f for f in super().fields() if f.name != "name"]
         return inherited + [
-            Field("name"),
+            String("name"),
             Dict("credits"),
             List("tracks"),
-            BackReference("artist", Artist),
+            BackReference("artist", value_type=Artist),
             BinaryFilePointer("cover", extension=".jpg"),
             TextFilePointer("review"),
         ]

@@ -60,4 +63,4 @@ class Artist(User):
     @classmethod
     def fields(cls):
         inherited = [f for f in super().fields() if f.name != "name"]
-        return inherited + [Field("name"), RecordDict("albums", Album)]
+        return inherited + [String("name"), RecordDict("albums", member_type=Album)]
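The declarative pattern above generalizes: any `Record` subclass can attach per-field validators. A hypothetical model, sketched only to show the style (the `Bookmark` class, its field names, and its pattern are illustrative, not from the commit):

import re

from grung.objects import Record, String
from grung.validators import LengthValidator, PatternValidator


class Bookmark(Record):  # hypothetical example model
    @classmethod
    def fields(cls):
        return [
            *super().fields(),
            String("title", validators=[LengthValidator(min=1, max=80)]),
            String("url", validators=[PatternValidator(re.compile(r"^https?://"))]),
        ]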
src/grung/exceptions.py

@@ -1,42 +1,92 @@
-class UniqueConstraintError(Exception):
+class ValidationError(Exception):
+    """
+    Thrown when a record's field does not meet validation criteria.
+    """
+
+    messages = []
+    template = (
+        "\n"
+        " * Record: {_record}\n"
+        " * Field: {_field}\n"
+        " * Value: {_value}\n"
+        " * Validator: {_validator}\n"
+        "\n"
+        "{_messages}"
+    )
+
+    def __init__(self, record, field, validator, messages=[], **kwargs):
+        super().__init__(
+            self.template.format(
+                _record=dict(record),
+                _field=field,
+                _value=record[field.name],
+                _validator=validator,
+                _messages="\n".join(messages or self.__class__.messages),
+                **kwargs,
+            )
+        )
+
+
+class InvalidFieldTypeError(ValidationError):
+    """
+    Thrown when a document's field value does not match the field value_type.
+    """
+
+    messages = ["The value of the field is not an instance of the field's value_type."]
+
+
+class UniqueConstraintError(ValidationError):
     """
     Thrown when a db write operation cannot complete due to a field's unique constraint.
     """

-    def __init__(self, document, query, collisions):
+    def __init__(self, record, field, validator, query, matches):
         super().__init__(
-            "\n"
-            f" * Record: {dict(document)}\n"
-            f" * Query: {query}\n"
-            f" * Error: Unique constraint failure\n"
-            " * The record matches the following existing records:\n\n" + "\n".join(str(c) for c in collisions)
+            record,
+            field,
+            validator,
+            messages=[
+                f"Query: {query}",
+                "The record matches the following existing records:\n\n" + "\n".join(str(m) for m in matches),
+            ],
         )


-class PointerReferenceError(Exception):
-    """
-    Thrown when a document field containing a document could not be resolve to an existing record in the database.
-    """
-
-    def __init__(self, reference):
-        super().__init__(
-            "\n"
-            f" * Reference: {reference}\n"
-            f" * Error: Invalid Pointer\n"
-            " * This collection member does not refer an existing record. Do you need to save it first?"
-        )
-
-
-class CircularReferenceError(Exception):
+class CircularReferenceError(ValidationError):
     """
     Thrown when a record contains a reference to itself.
     """

-    def __init__(self, reference, field):
-        super().__init__(
-            "\n"
-            f" * Reference: {reference}\n"
-            f" * Field: {field.name}\n"
-            f" * Error: Circular Reference\n"
-            f" * This record contains a reference to itself. This will lead to infinite recursion."
-        )
+    messages = ["This record contains a reference to itself. This will lead to infinite recursion."]
+
+
+class MalformedPointerError(ValidationError):
+    """
+    Thrown when a Pointer's value is not a valid reference string.
+    """
+
+    messages = ["A Pointer's value must follow the format 'TABLE_NAME::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE'."]
+
+
+class PointerReferenceError(Exception):
+    """
+    Thrown when a record field containing a record could not be resolved to an existing record in the database.
+    """
+
+
+class InvalidLengthError(ValidationError):
+    """
+    Thrown when a field does not meet its length constraint.
+    """
+
+
+class InvalidSizeError(ValidationError):
+    """
+    Thrown when a field's size is too large or too small.
+    """
+
+
+class PatternMatchError(ValidationError):
+    """
+    Thrown when a field does not match the specified pattern.
+    """
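With the shared template in `ValidationError`, a new error type only needs a docstring and a default `messages` list; the base class renders the record/field/value/validator block. A hypothetical subclass, for illustration only:

from grung.exceptions import ValidationError


class ChecksumError(ValidationError):  # hypothetical, not part of the commit
    """
    Thrown when a field's stored checksum does not match its contents.
    """

    messages = ["The field's checksum does not match its contents."]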
src/grung/objects.py (new file, 448 lines)

@@ -0,0 +1,448 @@
from __future__ import annotations

import hashlib
import hmac
import os
import re
import typing
from collections import namedtuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import nanoid
from tinydb import TinyDB, where

import grung.types
from grung.exceptions import PointerReferenceError
from grung.validators import PointerReferenceValidator, UniqueValidator

Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])


@dataclass
class Field(grung.types.Field):
    """
    Represents a single field in a Record.
    """

    name: str
    default: str = None
    unique: bool = False
    primary_key: bool = False
    validators: list = field(default_factory=lambda: [])

    value_type = str

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        pass

    def after_insert(self, db: TinyDB, record: Record) -> None:
        pass

    def serialize(self, value: value_type, record: Record | None = None) -> str:
        if value is not None:
            return str(value)

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        return value

    def validate(self, record: Record, db: TinyDB):
        if self.unique:
            UniqueValidator().validate(record, self, db)
        for validator in self.validators:
            validator.validate(record, self, db)


class Record(grung.types.Record):
    """
    Base type for a single database record.
    """

    def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
        self.doc_id = doc_id
        fields = self.__class__.fields()

        pkey = [field for field in fields if field.primary_key]
        if len(pkey) > 1:
            raise Exception(f"Cannot have more than one primary key: {pkey}")
        elif pkey:
            pkey = pkey[0]
        else:
            # 1% collision rate at ~2M records
            pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
            fields.append(pkey)

        pkey.unique = True

        self._metadata = Metadata(
            table=self.__class__.__name__,
            primary_key=pkey.name,
            fields={f.name: f for f in fields},
            backrefs=lambda value_type: (
                field for field in fields if type(field) == BackReference and field.value_type == value_type
            ),
        )
        super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))

    @classmethod
    def fields(cls):
        return []

    def serialize(self):
        """
        Serialize every field on the record
        """
        rec = {}
        for name, _field in self._metadata.fields.items():
            rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
        return self.__class__(rec, doc_id=self.doc_id)

    def deserialize(self, db, recurse: bool = True):
        """
        Deserialize every field on the record
        """
        rec = {}
        for name, _field in self._metadata.fields.items():
            rec[name] = _field.deserialize(self[name], db, recurse=recurse)
        return self.__class__(rec, doc_id=self.doc_id)

    def before_insert(self, db: TinyDB) -> None:
        for name, _field in self._metadata.fields.items():
            _field.before_insert(self[name], db, self)

    def after_insert(self, db: TinyDB) -> None:
        for name, _field in self._metadata.fields.items():
            _field.after_insert(db, self)

    def update(self, **data):
        for key, value in data.items():
            self[key] = value

    @property
    def reference(self):
        return Pointer.reference(self)

    @property
    def path(self):
        return Path(self._metadata.table) / self[self._metadata.primary_key]

    def __setattr__(self, key, value):
        if key in self:
            self[key] = value
        super().__setattr__(key, value)

    def __getattr__(self, attr_name):
        if attr_name in self:
            return self.get(attr_name)
        raise AttributeError(f"No such attribute: {attr_name}")

    def __hash__(self):
        return hash(str(dict(self)))

    def __repr__(self):
        return (
            f"{self.__class__.__name__}[{self.doc_id}]("
            + ", ".join([f"{key}={val}" for (key, val) in self.items()])
            + ")"
        )


@dataclass
class String(Field):
    pass


@dataclass
class Integer(Field):
    value_type = int
    default: int = 0

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        return int(value)


@dataclass
class Dict(Field):
    default: dict = field(default_factory=lambda: {})

    value_type = dict

    def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
        return dict((key, str(value)) for key, value in values.items())

    def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
        return values


@dataclass
class List(Field):
    default: list = field(default_factory=lambda: [])

    value_type = list

    def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
        return values

    def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
        return values


@dataclass
class DateTime(Field):
    default: datetime = datetime.utcfromtimestamp(0)

    value_type = datetime

    def serialize(self, value: value_type, record: Record | None = None) -> str:
        return (value - datetime.utcfromtimestamp(0)).total_seconds()

    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
        return datetime.utcfromtimestamp(int(value))

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        if not value:
            record[self.name] = datetime.utcnow().replace(microsecond=0)


@dataclass
class Timestamp(DateTime):
    value_type = datetime

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        super().before_insert(None, db, record)


@dataclass
class Password(Field):
    value_type = str
    default: str = None

    # Relatively weak. Consider using stronger initial values in production applications.
    salt_size = 4
    digest_size = 16

    @classmethod
    def is_digest(cls, passwd: str):
        if not passwd:
            return False
        offset = 2 * cls.salt_size  # each byte is 2 hex chars
        try:
            if passwd[offset] != ":":
                return False
            digest = passwd[(offset + 1) :]  # noqa
            if len(digest) != cls.digest_size * 2:
                return False
            return re.match(r"^[0-9a-f]+$", digest)
        except IndexError:
            return False

    @classmethod
    def get_digest(cls, passwd: str, salt: bytes = None):
        if not salt:
            salt = os.urandom(cls.salt_size)
        digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
        return digest, salt.hex()

    @classmethod
    def compare(cls, passwd: value_type, stored: value_type):
        stored_salt, stored_digest = stored.split(":")
        input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
        return hmac.compare_digest(input_digest, stored_digest)

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        if value and not self.__class__.is_digest(value):
            digest, salt = self.__class__.get_digest(value)
            record[self.name] = f"{salt}:{digest}"


@dataclass
class Pointer(Field):
    """
    Store a string reference to a record.
    """

    name: str = ""
    value_type: grung.types.Record = Record

    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
        return Pointer.reference(value)

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        return Pointer.dereference(value, db, recurse)

    @classmethod
    def reference(cls, value: Record | str):
        if isinstance(value, str):
            PointerReferenceValidator().validate_string(value)
            return value

        if value:
            return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"

        return None

    @classmethod
    def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
        if not value:
            return
        elif type(value) == str:
            table_name, pkey, pval = value.split("::")
            if pval:
                table = db.table(table_name)
                rec = table.get(where(pkey) == pval, recurse=recurse)
                if not rec:
                    raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
                return rec
        return value


@dataclass
class BackReference(Pointer):
    pass


@dataclass
class BinaryFilePointer(Field):
    """
    Write the contents of this field to disk and store the path in the db.
    """

    name: str
    extension: str = ".blob"

    value_type = bytes

    def relpath(self, record):
        return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"

    def reference(self, record):
        return f"/::{self.relpath(record)}"

    def dereference(self, reference, db):
        relpath = reference.replace("/::", "", 1)
        try:
            return (db.path / relpath).read_bytes()
        except FileNotFoundError:
            return None

    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
        return self.reference(record)

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        if not value:
            return None
        return self.dereference(value, db)

    def prepare(self, data: value_type):
        """
        Return bytes to be written to disk
        """
        if not data:
            return
        if not isinstance(data, self.value_type):
            return data.encode()
        return data

    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
        if not value:
            return
        relpath = self.relpath(record)
        path = db.path / relpath
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(self.prepare(value))


@dataclass
class TextFilePointer(BinaryFilePointer):
    """
    Write the contents of this field to disk and store the path in the db.
    """

    name: str
    extension: str = ".txt"

    value_type = str

    def prepare(self, data: value_type):
        if isinstance(data, bytes):
            return data
        return str(data).encode()

    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
        if not value:
            return None
        buf = super().deserialize(value, db)
        return buf.decode() if buf else None


@dataclass
class Collection(Field):
    """
    A collection of pointers.
    """

    default: typing.List[value_type] = field(default_factory=lambda: [])
    member_type: type = Record

    value_type = list

    def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
        return [Pointer.reference(val) for val in values]

    def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
        """
        Recursively deserialize the objects in this collection
        """
        recs = []
        if not recurse:
            return values
        for val in values:
            recs.append(Pointer.dereference(val, db=db, recurse=False))
        return recs

    def after_insert(self, db: TinyDB, record: Record) -> None:
        """
        Populate any backreferences in the members of this collection with the parent record's uid.
        """
        if not record[self.name]:
            return

        for member in record[self.name]:
            target = Pointer.dereference(member, db=db, recurse=False)
            for backref in target._metadata.backrefs(type(record)):
                target[backref.name] = record
                db.save(target)


@dataclass
class RecordDict(Field):
    default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
    member_type: type = Record

    value_type = dict

    def serialize(
        self, values: typing.Dict[(str, value_type)], record: Record | None = None
    ) -> typing.Dict[(str, str)]:
        return dict((key, Pointer.reference(val)) for (key, val) in values.items())

    def deserialize(
        self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
    ) -> typing.Dict[(str, value_type)]:
        if not recurse:
            return values
        return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())

    def after_insert(self, db: TinyDB, record: Record) -> None:
        """
        Populate any backreferences in the members of this mapping with the parent record's uid.
        """
        if not record[self.name]:
            return

        for key, pointer in record[self.name].items():
            target = Pointer.dereference(pointer, db=db, recurse=False)
            for backref in target._metadata.backrefs(type(record)):
                target[backref.name] = record
                db.save(target)
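A quick sketch of the reference-string format `Pointer` produces, using the `User` model from `grung.examples` ("alice" is a stand-in primary-key value). Validators only run on save, so constructing the record alone suffices here:

from grung.examples import User

user = User(name="alice")
# Pointer.reference renders TABLE_NAME::PRIMARY_KEY_NAME::PRIMARY_KEY_VALUE
assert user.reference == "User::name::alice"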
src/grung/types.py

@@ -1,420 +1,9 @@
-from __future__ import annotations
-
-import hashlib
-import hmac
-import os
-import re
 import typing
-from collections import namedtuple
-from dataclasses import dataclass, field
-from datetime import datetime
-from pathlib import Path
-
-import nanoid
-from tinydb import TinyDB, where
-
-from grung.exceptions import PointerReferenceError
-
-Metadata = namedtuple("Metadata", ["table", "fields", "backrefs", "primary_key"])


-@dataclass
 class Field:
+    pass
-    """
-    Represents a single field in a Record.
-    """
-
-    name: str
-    value_type: type = str
-    default: str = None
-    unique: bool = False
-    primary_key: bool = False
-
-    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
-        pass
-
-    def after_insert(self, db: TinyDB, record: Record) -> None:
-        pass
-
-    def serialize(self, value: value_type, record: Record | None = None) -> str:
-        if value is not None:
-            return str(value)
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
-        return value
-
-
-@dataclass
-class Integer(Field):
-    value_type = int
-    default: int = 0
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
-        return int(value)
-
-
-@dataclass
-class Dict(Field):
-    value_type: type = str
-    default: dict = field(default_factory=lambda: {})
-
-    def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
-        return dict((key, str(value)) for key, value in values.items())
-
-    def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
-        return values
-
-
-@dataclass
-class List(Field):
-    value_type: type = list
-    default: list = field(default_factory=lambda: [])
-
-    def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
-        return values
-
-    def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
-        return values
-
-
-@dataclass
-class DateTime(Field):
-    value_type: datetime
-    default: datetime = datetime.utcfromtimestamp(0)
-
-    def serialize(self, value: value_type, record: Record | None = None) -> str:
-        return (value - datetime.utcfromtimestamp(0)).total_seconds()
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
-        return datetime.utcfromtimestamp(int(value))
-
-    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
-        if not value:
-            record[self.name] = datetime.utcnow().replace(microsecond=0)
-
-
-@dataclass
-class Timestamp(DateTime):
-    value_type: datetime
-
-    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
-        super().before_insert(None, db, record)
-
-
-@dataclass
-class Password(Field):
-    value_type = str
-    default: str = None
-
-    # Relatively weak. Consider using stronger initial values in production applications.
-    salt_size = 4
-    digest_size = 16
-
-    @classmethod
-    def is_digest(cls, passwd: str):
-        if not passwd:
-            return False
-        offset = 2 * cls.salt_size  # each byte is 2 hex chars
-        try:
-            if passwd[offset] != ":":
-                return False
-            digest = passwd[(offset + 1) :]  # noqa
-            if len(digest) != cls.digest_size * 2:
-                return False
-            return re.match(r"^[0-9a-f]+$", digest)
-        except IndexError:
-            return False
-
-    @classmethod
-    def get_digest(cls, passwd: str, salt: bytes = None):
-        if not salt:
-            salt = os.urandom(cls.salt_size)
-        digest = hashlib.blake2b(passwd.encode(), digest_size=cls.digest_size, salt=salt).hexdigest()
-        return digest, salt.hex()
-
-    @classmethod
-    def compare(cls, passwd: value_type, stored: value_type):
-        stored_salt, stored_digest = stored.split(":")
-        input_digest, input_salt = cls.get_digest(passwd, bytes.fromhex(stored_salt))
-        return hmac.compare_digest(input_digest, stored_digest)
-
-    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
-        if value and not self.__class__.is_digest(value):
-            digest, salt = self.__class__.get_digest(value)
-            record[self.name] = f"{salt}:{digest}"
-
-
 class Record(typing.Dict[(str, Field)]):
+    pass
-    """
-    Base type for a single database record.
-    """
-
-    def __init__(self, raw_doc: dict = {}, doc_id: int = None, **params):
-        self.doc_id = doc_id
-        fields = self.__class__.fields()
-
-        pkey = [field for field in fields if field.primary_key]
-        if len(pkey) > 1:
-            raise Exception(f"Cannnot have more than one primary key: {pkey}")
-        elif pkey:
-            pkey = pkey[0]
-        else:
-            # 1% collision rate at ~2M records
-            pkey = Field("uid", default=nanoid.generate(size=8), primary_key=True)
-            fields.append(pkey)
-
-        pkey.unique = True
-
-        self._metadata = Metadata(
-            table=self.__class__.__name__,
-            primary_key=pkey.name,
-            fields={f.name: f for f in fields},
-            backrefs=lambda value_type: (
-                field for field in fields if type(field) == BackReference and field.value_type == value_type
-            ),
-        )
-        super().__init__(dict({field.name: field.default for field in fields}, **raw_doc, **params))
-
-    @classmethod
-    def fields(cls):
-        return []
-
-    def serialize(self):
-        """
-        Serialize every field on the record
-        """
-        rec = {}
-        for name, _field in self._metadata.fields.items():
-            rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
-        return self.__class__(rec, doc_id=self.doc_id)
-
-    def deserialize(self, db, recurse: bool = True):
-        """
-        Deserialize every field on the record
-        """
-        rec = {}
-        for name, _field in self._metadata.fields.items():
-            rec[name] = _field.deserialize(self[name], db, recurse=recurse)
-        return self.__class__(rec, doc_id=self.doc_id)
-
-    def before_insert(self, db: TinyDB) -> None:
-        for name, _field in self._metadata.fields.items():
-            _field.before_insert(self[name], db, self)
-
-    def after_insert(self, db: TinyDB) -> None:
-        for name, _field in self._metadata.fields.items():
-            _field.after_insert(db, self)
-
-    @property
-    def reference(self):
-        return Pointer.reference(self)
-
-    @property
-    def path(self):
-        return Path(self._metadata.table) / self[self._metadata.primary_key]
-
-    def __setattr__(self, key, value):
-        if key in self:
-            self[key] = value
-        super().__setattr__(key, value)
-
-    def __getattr__(self, attr_name):
-        if attr_name in self:
-            return self.get(attr_name)
-        raise AttributeError(f"No such attribute: {attr_name}")
-
-    def __hash__(self):
-        return hash(str(dict(self)))
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}[{self.doc_id}]("
-            + ", ".join([f"{key}={val}" for (key, val) in self.items()])
-            + ")"
-        )
-
-
-@dataclass
-class Pointer(Field):
-    """
-    Store a string reference to a record.
-    """
-
-    name: str = ""
-    value_type: type = Record
-
-    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
-        return Pointer.reference(value)
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
-        return Pointer.dereference(value, db, recurse)
-
-    @classmethod
-    def reference(cls, value: Record | str):
-        # XXX This could be smarter
-        if isinstance(value, str):
-            if "::" not in value:
-                raise PointerReferenceError("Value {value} does not look like a reference!")
-            return value
-        if value:
-            return f"{value._metadata.table}::{value._metadata.primary_key}::{value[value._metadata.primary_key]}"
-        return None
-
-    @classmethod
-    def dereference(cls, value: str, db: TinyDB, recurse: bool = True):
-        if not value:
-            return
-        elif type(value) == str:
-            table_name, pkey, pval = value.split("::")
-            if pval:
-                table = db.table(table_name)
-                rec = table.get(where(pkey) == pval, recurse=recurse)
-                if not rec:
-                    raise PointerReferenceError(f"Expected a {table_name} with {pkey}=={pval} but did not find one!")
-                return rec
-        return value
-
-
-@dataclass
-class BackReference(Pointer):
-    pass
-
-
-@dataclass
-class BinaryFilePointer(Field):
-    """
-    Write the contents of this field to disk and store the path in the db.
-    """
-
-    name: str
-    value_type: type = bytes
-    extension: str = ".blob"
-
-    def relpath(self, record):
-        return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
-
-    def reference(self, record):
-        return f"/::{self.relpath(record)}"
-
-    def dereference(self, reference, db):
-        relpath = reference.replace("/::", "", 1)
-        try:
-            return (db.path / relpath).read_bytes()
-        except FileNotFoundError:
-            return None
-
-    def serialize(self, value: value_type | str, record: Record | None = None) -> str:
-        return self.reference(record)
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
-        if not value:
-            return None
-        return self.dereference(value, db)
-
-    def prepare(self, data: value_type):
-        """
-        Return bytes to be written to disk
-        """
-        if not data:
-            return
-        if not isinstance(data, self.value_type):
-            return data.encode()
-        return data
-
-    def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
-        if not value:
-            return
-        relpath = self.relpath(record)
-        path = db.path / relpath
-        path.parent.mkdir(parents=True, exist_ok=True)
-        path.write_bytes(self.prepare(value))
-
-
-@dataclass
-class TextFilePointer(BinaryFilePointer):
-    """
-    Write the contents of this field to disk and store the path in the db.
-    """
-
-    name: str
-    value_type: type = str
-    extension: str = ".txt"
-
-    def prepare(self, data: value_type):
-        if isinstance(data, bytes):
-            return data
-        return str(data).encode()
-
-    def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
-        if not value:
-            return None
-        buf = super().deserialize(value, db)
-        return buf.decode() if buf else None
-
-
-@dataclass
-class Collection(Field):
-    """
-    A collection of pointers.
-    """
-
-    value_type: type = Record
-    default: typing.List[value_type] = field(default_factory=lambda: [])
-
-    def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
-        return [Pointer.reference(val) for val in values]
-
-    def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
-        """
-        Recursively deserialize the objects in this collection
-        """
-        recs = []
-        if not recurse:
-            return values
-        for val in values:
-            recs.append(Pointer.dereference(val, db=db, recurse=False))
-        return recs
-
-    def after_insert(self, db: TinyDB, record: Record) -> None:
-        """
-        Populate any backreferences in the members of this collection with the parent record's uid.
-        """
-        if not record[self.name]:
-            return
-
-        for member in record[self.name]:
-            target = Pointer.dereference(member, db=db, recurse=False)
-            for backref in target._metadata.backrefs(type(record)):
-                target[backref.name] = record
-                db.save(target)
-
-
-@dataclass
-class RecordDict(Field):
-    value_type: type = Record
-    default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
-
-    def serialize(
-        self, values: typing.Dict[(str, value_type)], record: Record | None = None
-    ) -> typing.Dict[(str, str)]:
-        return dict((key, Pointer.reference(val)) for (key, val) in values.items())
-
-    def deserialize(
-        self, values: typing.Dict[(str, str)], db: TinyDB, recurse: bool = False
-    ) -> typing.Dict[(str, value_type)]:
-        if not recurse:
-            return values
-        return dict((key, Pointer.dereference(val, db=db, recurse=False)) for (key, val) in values.items())
-
-    def after_insert(self, db: TinyDB, record: Record) -> None:
-        """
-        Populate any backreferences in the members of this mapping with the parent record's uid.
-        """
-        if not record[self.name]:
-            return
-
-        for key, pointer in record[self.name].items():
-            target = Pointer.dereference(pointer, db=db, recurse=False)
-            for backref in target._metadata.backrefs(type(record)):
-                target[backref.name] = record
-                db.save(target)
src/grung/validators.py (new file, 199 lines)

@@ -0,0 +1,199 @@
from __future__ import annotations

import re
from dataclasses import dataclass

from tinydb import Query, TinyDB

from grung.exceptions import (
    InvalidFieldTypeError,
    InvalidLengthError,
    InvalidSizeError,
    MalformedPointerError,
    PatternMatchError,
    UniqueConstraintError,
    ValidationError,
)
from grung.types import Field, Record


@dataclass
class Validator:
    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        raise ValidationError(record, field, self)


@dataclass
class TypeValidator:
    def validate_list(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        messages = []
        for i in range(len(record[field.name])):
            member = record[field.name][i]
            if isinstance(member, str):
                class_name = field.member_type.__name__
                try:
                    return PointerReferenceValidator().validate_string(member, class_name)
                except MalformedPointerError as e:
                    messages.append(str(e))
            elif not isinstance(member, field.member_type):
                messages.append(f"{field.name}[{i}] must be a {field.member_type}, not a {type(member)}.")
        if messages:
            raise InvalidFieldTypeError(record, field, self, messages=messages)

    def validate_dict(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        for key, member in record[field.name].items():
            if not isinstance(member, field.member_type):
                raise InvalidFieldTypeError(
                    record,
                    field,
                    self,
                    messages=[f"{field.name}[{key}] must be a {field.member_type}, not a {type(member)}."],
                )
        return True

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        if record[field.name] is None:
            return True

        if not isinstance(record[field.name], field.value_type):
            raise InvalidFieldTypeError(
                record,
                field,
                self,
                messages=[f"{field.name} must be a {field.value_type}, not a {type(record[field.name])}."],
            )
        if not hasattr(field, "member_type"):
            return True

        if field.value_type == dict:
            self.validate_dict(record, field, db)
        elif field.value_type == list:
            self.validate_list(record, field, db)
        else:
            raise RuntimeError("Expected a validation for iterable but didn't get one!")
        return True


@dataclass
class PointerReferenceValidator(Validator):
    """
    Verify that the Pointer is either a string reference to the correct member type,
    or a record instance of member_type that hasn't been serialized.
    """

    def validate_string(self, value: str, type_name: str = "") -> bool:
        (table, primary_key, val) = value.split("::")
        if type_name and table != type_name:
            raise MalformedPointerError(
                {"string": value},
                "",
                self,
                messages=[f"field should reference '{type_name}', not '{table}'."],
            )
        if not primary_key:
            raise MalformedPointerError(
                {"string": value},
                "",
                self,
                messages=["Pointers must specify the primary_key field name."],
            )

    def validate(self, record: Record | str, field: Field, db: TinyDB = None) -> bool:
        if record[field.name] is None:
            return True

        if isinstance(record, str):
            try:
                self.validate_string(record, field.value_type.__name__)
            except ValueError:
                raise MalformedPointerError({field.name: record}, field, self)
            return True

        if not isinstance(record, Record):
            raise MalformedPointerError(record, field, self)

        if not isinstance(record[field.name], field.value_type):
            raise MalformedPointerError(record, field, self)

        return True


@dataclass
class UniqueValidator(Validator):
    def validate(self, record: Record, field: Field, db: TinyDB) -> bool:
        """
        Returns true if the field's value is unique across all records in the table.
        """
        if record[field.name] is None:
            return True

        query = Query()[field.name].matches(f"^{record[field.name]}$", flags=re.IGNORECASE)
        table = db.table(record._metadata.table)
        matches = [dict(match) for match in table.search(query) if match.doc_id != record.doc_id]
        if matches != []:
            raise UniqueConstraintError(record, field, self, query=query, matches=matches)
        return True


@dataclass
class LengthValidator(Validator):
    min: int = 0
    max: int = 0

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """
        Returns True if the length of the field's value is between min and max, inclusive.
        """
        if record[field.name] is None:
            return True

        length = len(record[field.name])
        if length < self.min or length > self.max:
            raise InvalidLengthError(
                record,
                field,
                self,
                messages=[f"The field length must be between {self.min} and {self.max}, inclusive."],
            )
        return True


@dataclass
class MinMaxValidator(Validator):
    min: int = 0
    max: int = 0

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        """
        Returns True if the size of the field's integer value is between min and max, inclusive.
        """
        if record[field.name] is None:
            return True

        size = int(record[field.name])
        if size < self.min or size > self.max:
            raise InvalidSizeError(
                record,
                field,
                self,
                messages=[f"The field size must be between {self.min} and {self.max}, inclusive."],
            )
        return True


@dataclass
class PatternValidator(Validator):
    pattern: re.Pattern

    def validate(self, record: Record, field: Field, db: TinyDB = None) -> bool:
        if record[field.name] is None:
            return True

        if not self.pattern.match(record[field.name]):
            raise PatternMatchError(
                record,
                field,
                self,
                messages=[f"The field value must match the pattern {self.pattern}"],
            )
        return True
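Validators can also be exercised directly, outside `insert`. A sketch using the example `User` model (the three-character minimum comes from its `LengthValidator`):

from grung.examples import User
from grung.exceptions import InvalidLengthError
from grung.validators import LengthValidator

user = User(name="ab")  # too short for the min=3 constraint
try:
    LengthValidator(min=3, max=30).validate(user, user._metadata.fields["name"])
except InvalidLengthError as err:
    print(err)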
@@ -10,7 +10,14 @@ from tinydb.storages import MemoryStorage

 from grung import examples
 from grung.db import GrungDB
-from grung.exceptions import CircularReferenceError, UniqueConstraintError
+from grung.exceptions import (
+    CircularReferenceError,
+    InvalidFieldTypeError,
+    InvalidLengthError,
+    InvalidSizeError,
+    PatternMatchError,
+    UniqueConstraintError,
+)


 @pytest.fixture

@@ -82,7 +89,7 @@ def test_subgroups(db):

     # recursion!
     with pytest.raises(CircularReferenceError):
-        tos.members = [tos]
+        tos.groups = [tos]
         db.save(tos)

@@ -201,3 +208,29 @@ def test_file_pointers(db):

     location_on_disk = db.path / album._metadata.fields["review"].relpath(album)
     assert location_on_disk.read_text() == album.review
+
+
+@pytest.mark.parametrize(
+    "updates, expected",
+    [
+        ({"name": ""}, InvalidLengthError),
+        ({"name": "a name longer than 30 characters is what we have here"}, InvalidLengthError),
+        ({"name": 23}, InvalidFieldTypeError),
+        ({"number": -1}, InvalidSizeError),
+        ({"number": 256}, InvalidSizeError),
+        ({"email": "foo+alias@"}, PatternMatchError),
+    ],
+    ids=[
+        "name too short",
+        "name too long",
+        "name is not a string",
+        "number too small",
+        "number too big",
+        "invalid email address",
+    ],
+)
+def test_validators(updates, expected, db):
+    user = db.save(examples.User(name="john", email="john@foo", password="fnord", created=datetime.utcnow()))
+    with pytest.raises(expected):
+        user.update(**updates)
+        db.save(user)