Implement FilePointer
This commit is contained in:
parent
7e05915540
commit
8bc3f07b28
|
|
@ -2,9 +2,11 @@ import inspect
|
||||||
import re
|
import re
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from operator import ior
|
from operator import ior
|
||||||
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from tinydb import Query, TinyDB, table
|
from tinydb import Query, TinyDB, table
|
||||||
|
from tinydb.storages import MemoryStorage
|
||||||
from tinydb.table import Document
|
from tinydb.table import Document
|
||||||
|
|
||||||
from grung.exceptions import UniqueConstraintError
|
from grung.exceptions import UniqueConstraintError
|
||||||
|
|
@ -18,11 +20,11 @@ class RecordTable(table.Table):
|
||||||
|
|
||||||
def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
|
def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
|
||||||
self.document_class = document_class
|
self.document_class = document_class
|
||||||
self._db = db
|
self.db = db
|
||||||
super().__init__(db.storage, name, **kwargs)
|
super().__init__(db.storage, name, **kwargs)
|
||||||
|
|
||||||
def insert(self, document):
|
def insert(self, document):
|
||||||
document.before_insert(self._db)
|
document.before_insert(self.db)
|
||||||
doc = document.serialize()
|
doc = document.serialize()
|
||||||
self._check_constraints(doc)
|
self._check_constraints(doc)
|
||||||
|
|
||||||
|
|
@ -31,8 +33,8 @@ class RecordTable(table.Table):
|
||||||
else:
|
else:
|
||||||
last_insert_id = super().insert(dict(doc))
|
last_insert_id = super().insert(dict(doc))
|
||||||
doc.doc_id = last_insert_id
|
doc.doc_id = last_insert_id
|
||||||
doc.after_insert(self._db)
|
doc.after_insert(self.db)
|
||||||
return doc.deserialize(self._db)
|
return doc.deserialize(self.db)
|
||||||
|
|
||||||
def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
|
def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
@ -48,7 +50,7 @@ class RecordTable(table.Table):
|
||||||
if doc_id:
|
if doc_id:
|
||||||
document = super().get(doc_id=doc_id)
|
document = super().get(doc_id=doc_id)
|
||||||
if document:
|
if document:
|
||||||
return document.deserialize(self._db, recurse=recurse)
|
return document.deserialize(self.db, recurse=recurse)
|
||||||
|
|
||||||
matches = self.search(*args, recurse=recurse, **kwargs)
|
matches = self.search(*args, recurse=recurse, **kwargs)
|
||||||
if matches:
|
if matches:
|
||||||
|
|
@ -56,7 +58,7 @@ class RecordTable(table.Table):
|
||||||
|
|
||||||
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
|
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
|
||||||
results = super().search(*args, **kwargs)
|
results = super().search(*args, **kwargs)
|
||||||
return [r.deserialize(self._db, recurse=recurse) for r in results]
|
return [r.deserialize(self.db, recurse=recurse) for r in results]
|
||||||
|
|
||||||
def remove(self, document):
|
def remove(self, document):
|
||||||
if document.doc_id:
|
if document.doc_id:
|
||||||
|
|
@ -89,7 +91,10 @@ class GrungDB(TinyDB):
|
||||||
default_table_name = "Record"
|
default_table_name = "Record"
|
||||||
_tables = {}
|
_tables = {}
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, path: Path, *args, **kwargs):
|
||||||
|
self.path = path
|
||||||
|
if kwargs["storage"] != MemoryStorage:
|
||||||
|
args = (path,) + args
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.create_table(Record)
|
self.create_table(Record)
|
||||||
|
|
||||||
|
|
@ -122,8 +127,8 @@ class GrungDB(TinyDB):
|
||||||
return super().__getattr__(attr_name)
|
return super().__getattr__(attr_name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def with_schema(cls, schema_module, *args, **kwargs):
|
def with_schema(cls, schema_module, path: Path | None, *args, **kwargs):
|
||||||
db = GrungDB(*args, **kwargs)
|
db = GrungDB(path=path, *args, **kwargs)
|
||||||
for name, obj in inspect.getmembers(schema_module):
|
for name, obj in inspect.getmembers(schema_module):
|
||||||
if type(obj) == type and issubclass(obj, Record):
|
if type(obj) == type and issubclass(obj, Record):
|
||||||
db.create_table(obj)
|
db.create_table(obj)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from grung.types import (
|
||||||
DateTime,
|
DateTime,
|
||||||
Dict,
|
Dict,
|
||||||
Field,
|
Field,
|
||||||
|
FilePointer,
|
||||||
Integer,
|
Integer,
|
||||||
List,
|
List,
|
||||||
Password,
|
Password,
|
||||||
|
|
@ -49,6 +50,7 @@ class Album(Record):
|
||||||
Dict("credits"),
|
Dict("credits"),
|
||||||
List("tracks"),
|
List("tracks"),
|
||||||
BackReference("artist", Artist),
|
BackReference("artist", Artist),
|
||||||
|
FilePointer("cover", extension=".jpg"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -56,7 +58,4 @@ class Artist(User):
|
||||||
@classmethod
|
@classmethod
|
||||||
def fields(cls):
|
def fields(cls):
|
||||||
inherited = [f for f in super().fields() if f.name != "name"]
|
inherited = [f for f in super().fields() if f.name != "name"]
|
||||||
return inherited + [
|
return inherited + [Field("name"), RecordDict("albums", Album)]
|
||||||
Field("name"),
|
|
||||||
RecordDict("albums", Album),
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import typing
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import nanoid
|
import nanoid
|
||||||
from tinydb import TinyDB, where
|
from tinydb import TinyDB, where
|
||||||
|
|
@ -35,7 +36,7 @@ class Field:
|
||||||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def serialize(self, value: value_type) -> str:
|
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
@ -57,7 +58,7 @@ class Dict(Field):
|
||||||
value_type: type = str
|
value_type: type = str
|
||||||
default: dict = field(default_factory=lambda: {})
|
default: dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
def serialize(self, values: dict) -> Dict[(str, str)]:
|
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
|
||||||
return dict((key, str(value)) for key, value in values.items())
|
return dict((key, str(value)) for key, value in values.items())
|
||||||
|
|
||||||
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
||||||
|
|
@ -69,7 +70,7 @@ class List(Field):
|
||||||
value_type: type = list
|
value_type: type = list
|
||||||
default: list = field(default_factory=lambda: [])
|
default: list = field(default_factory=lambda: [])
|
||||||
|
|
||||||
def serialize(self, values: list) -> Dict[(str, str)]:
|
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||||
|
|
@ -81,7 +82,7 @@ class DateTime(Field):
|
||||||
value_type: datetime
|
value_type: datetime
|
||||||
default: datetime = datetime.utcfromtimestamp(0)
|
default: datetime = datetime.utcfromtimestamp(0)
|
||||||
|
|
||||||
def serialize(self, value: value_type) -> str:
|
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||||
return (value - datetime.utcfromtimestamp(0)).total_seconds()
|
return (value - datetime.utcfromtimestamp(0)).total_seconds()
|
||||||
|
|
||||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||||
|
|
@ -184,7 +185,7 @@ class Record(typing.Dict[(str, Field)]):
|
||||||
"""
|
"""
|
||||||
rec = {}
|
rec = {}
|
||||||
for name, _field in self._metadata.fields.items():
|
for name, _field in self._metadata.fields.items():
|
||||||
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
|
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
|
||||||
return self.__class__(rec, doc_id=self.doc_id)
|
return self.__class__(rec, doc_id=self.doc_id)
|
||||||
|
|
||||||
def deserialize(self, db, recurse: bool = True):
|
def deserialize(self, db, recurse: bool = True):
|
||||||
|
|
@ -208,6 +209,10 @@ class Record(typing.Dict[(str, Field)]):
|
||||||
def reference(self):
|
def reference(self):
|
||||||
return Pointer.reference(self)
|
return Pointer.reference(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def path(self):
|
||||||
|
return Path(self._metadata.table) / self[self._metadata.primary_key]
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
def __setattr__(self, key, value):
|
||||||
if key in self:
|
if key in self:
|
||||||
self[key] = value
|
self[key] = value
|
||||||
|
|
@ -238,7 +243,7 @@ class Pointer(Field):
|
||||||
name: str = ""
|
name: str = ""
|
||||||
value_type: type = Record
|
value_type: type = Record
|
||||||
|
|
||||||
def serialize(self, value: value_type | str) -> str:
|
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||||
return Pointer.reference(value)
|
return Pointer.reference(value)
|
||||||
|
|
||||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||||
|
|
@ -277,6 +282,32 @@ class BackReference(Pointer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FilePointer(Field):
|
||||||
|
"""
|
||||||
|
Write the contents of this field to disk and store the path in the db.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
value_type: type = bytes
|
||||||
|
extension: str = ".txt"
|
||||||
|
|
||||||
|
def relpath(self, record):
|
||||||
|
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
|
||||||
|
|
||||||
|
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
return (db.path / value).read_bytes()
|
||||||
|
|
||||||
|
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||||
|
relpath = self.relpath(record)
|
||||||
|
path = db.path / relpath
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_bytes(record[self.name])
|
||||||
|
record[self.name] = relpath
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Collection(Field):
|
class Collection(Field):
|
||||||
"""
|
"""
|
||||||
|
|
@ -286,7 +317,7 @@ class Collection(Field):
|
||||||
value_type: type = Record
|
value_type: type = Record
|
||||||
default: typing.List[value_type] = field(default_factory=lambda: [])
|
default: typing.List[value_type] = field(default_factory=lambda: [])
|
||||||
|
|
||||||
def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
|
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
|
||||||
return [Pointer.reference(val) for val in values]
|
return [Pointer.reference(val) for val in values]
|
||||||
|
|
||||||
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
||||||
|
|
@ -319,7 +350,9 @@ class RecordDict(Field):
|
||||||
value_type: type = Record
|
value_type: type = Record
|
||||||
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
|
def serialize(
|
||||||
|
self, values: typing.Dict[(str, value_type)], record: Record | None = None
|
||||||
|
) -> typing.Dict[(str, str)]:
|
||||||
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
||||||
|
|
||||||
def deserialize(
|
def deserialize(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
from pprint import pprint as print
|
from pprint import pprint as print
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
|
|
@ -13,9 +15,10 @@ from grung.exceptions import PointerReferenceError, UniqueConstraintError
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db():
|
def db():
|
||||||
_db = GrungDB.with_schema(examples, storage=MemoryStorage)
|
with tempfile.TemporaryDirectory() as path:
|
||||||
yield _db
|
_db = GrungDB.with_schema(examples, path=Path(path), storage=MemoryStorage)
|
||||||
print(_db)
|
yield _db
|
||||||
|
print(_db)
|
||||||
|
|
||||||
|
|
||||||
def test_crud(db):
|
def test_crud(db):
|
||||||
|
|
@ -153,10 +156,12 @@ def test_mapping(db):
|
||||||
name="The Impossible Kid",
|
name="The Impossible Kid",
|
||||||
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
|
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
|
||||||
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
|
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
|
||||||
|
cover=b"some jpg data",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert album.credits["Produced By"] == "Aesop Rock"
|
assert album.credits["Produced By"] == "Aesop Rock"
|
||||||
assert album.tracks[0] == "Mystery Fish"
|
assert album.tracks[0] == "Mystery Fish"
|
||||||
|
assert album.cover == b"some jpg data"
|
||||||
|
|
||||||
aes = db.save(
|
aes = db.save(
|
||||||
examples.Artist(
|
examples.Artist(
|
||||||
|
|
@ -170,3 +175,6 @@ def test_mapping(db):
|
||||||
assert album.name in aes.albums
|
assert album.name in aes.albums
|
||||||
assert aes.albums[album.name].uid == album.uid
|
assert aes.albums[album.name].uid == album.uid
|
||||||
assert "Kirby" in aes.albums[album.name].credits.values()
|
assert "Kirby" in aes.albums[album.name].credits.values()
|
||||||
|
|
||||||
|
location_on_disk = db.path / album._metadata.fields["cover"].relpath(album)
|
||||||
|
assert location_on_disk.read_bytes() == album.cover
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user