Implement FilePointer
This commit is contained in:
parent
7e05915540
commit
8bc3f07b28
|
|
@ -2,9 +2,11 @@ import inspect
|
|||
import re
|
||||
from functools import reduce
|
||||
from operator import ior
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from tinydb import Query, TinyDB, table
|
||||
from tinydb.storages import MemoryStorage
|
||||
from tinydb.table import Document
|
||||
|
||||
from grung.exceptions import UniqueConstraintError
|
||||
|
|
@ -18,11 +20,11 @@ class RecordTable(table.Table):
|
|||
|
||||
def __init__(self, name: str, db: TinyDB, document_class: Document = Record, **kwargs):
|
||||
self.document_class = document_class
|
||||
self._db = db
|
||||
self.db = db
|
||||
super().__init__(db.storage, name, **kwargs)
|
||||
|
||||
def insert(self, document):
|
||||
document.before_insert(self._db)
|
||||
document.before_insert(self.db)
|
||||
doc = document.serialize()
|
||||
self._check_constraints(doc)
|
||||
|
||||
|
|
@ -31,8 +33,8 @@ class RecordTable(table.Table):
|
|||
else:
|
||||
last_insert_id = super().insert(dict(doc))
|
||||
doc.doc_id = last_insert_id
|
||||
doc.after_insert(self._db)
|
||||
return doc.deserialize(self._db)
|
||||
doc.after_insert(self.db)
|
||||
return doc.deserialize(self.db)
|
||||
|
||||
def get(self, *args, doc_id: int = None, recurse: bool = False, **kwargs):
|
||||
"""
|
||||
|
|
@ -48,7 +50,7 @@ class RecordTable(table.Table):
|
|||
if doc_id:
|
||||
document = super().get(doc_id=doc_id)
|
||||
if document:
|
||||
return document.deserialize(self._db, recurse=recurse)
|
||||
return document.deserialize(self.db, recurse=recurse)
|
||||
|
||||
matches = self.search(*args, recurse=recurse, **kwargs)
|
||||
if matches:
|
||||
|
|
@ -56,7 +58,7 @@ class RecordTable(table.Table):
|
|||
|
||||
def search(self, *args, recurse: bool = False, **kwargs) -> List[Record]:
|
||||
results = super().search(*args, **kwargs)
|
||||
return [r.deserialize(self._db, recurse=recurse) for r in results]
|
||||
return [r.deserialize(self.db, recurse=recurse) for r in results]
|
||||
|
||||
def remove(self, document):
|
||||
if document.doc_id:
|
||||
|
|
@ -89,7 +91,10 @@ class GrungDB(TinyDB):
|
|||
default_table_name = "Record"
|
||||
_tables = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, path: Path, *args, **kwargs):
|
||||
self.path = path
|
||||
if kwargs["storage"] != MemoryStorage:
|
||||
args = (path,) + args
|
||||
super().__init__(*args, **kwargs)
|
||||
self.create_table(Record)
|
||||
|
||||
|
|
@ -122,8 +127,8 @@ class GrungDB(TinyDB):
|
|||
return super().__getattr__(attr_name)
|
||||
|
||||
@classmethod
|
||||
def with_schema(cls, schema_module, *args, **kwargs):
|
||||
db = GrungDB(*args, **kwargs)
|
||||
def with_schema(cls, schema_module, path: Path | None, *args, **kwargs):
|
||||
db = GrungDB(path=path, *args, **kwargs)
|
||||
for name, obj in inspect.getmembers(schema_module):
|
||||
if type(obj) == type and issubclass(obj, Record):
|
||||
db.create_table(obj)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from grung.types import (
|
|||
DateTime,
|
||||
Dict,
|
||||
Field,
|
||||
FilePointer,
|
||||
Integer,
|
||||
List,
|
||||
Password,
|
||||
|
|
@ -49,6 +50,7 @@ class Album(Record):
|
|||
Dict("credits"),
|
||||
List("tracks"),
|
||||
BackReference("artist", Artist),
|
||||
FilePointer("cover", extension=".jpg"),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -56,7 +58,4 @@ class Artist(User):
|
|||
@classmethod
|
||||
def fields(cls):
|
||||
inherited = [f for f in super().fields() if f.name != "name"]
|
||||
return inherited + [
|
||||
Field("name"),
|
||||
RecordDict("albums", Album),
|
||||
]
|
||||
return inherited + [Field("name"), RecordDict("albums", Album)]
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import typing
|
|||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nanoid
|
||||
from tinydb import TinyDB, where
|
||||
|
|
@ -35,7 +36,7 @@ class Field:
|
|||
def after_insert(self, db: TinyDB, record: Record) -> None:
|
||||
pass
|
||||
|
||||
def serialize(self, value: value_type) -> str:
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ class Dict(Field):
|
|||
value_type: type = str
|
||||
default: dict = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(self, values: dict) -> Dict[(str, str)]:
|
||||
def serialize(self, values: dict, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return dict((key, str(value)) for key, value in values.items())
|
||||
|
||||
def deserialize(self, values: dict, db: TinyDB, recurse: bool = False) -> Dict[(str, str)]:
|
||||
|
|
@ -69,7 +70,7 @@ class List(Field):
|
|||
value_type: type = list
|
||||
default: list = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: list) -> Dict[(str, str)]:
|
||||
def serialize(self, values: list, record: Record | None = None) -> Dict[(str, str)]:
|
||||
return values
|
||||
|
||||
def deserialize(self, values: list, db: TinyDB, recurse: bool = False) -> typing.List[str]:
|
||||
|
|
@ -81,7 +82,7 @@ class DateTime(Field):
|
|||
value_type: datetime
|
||||
default: datetime = datetime.utcfromtimestamp(0)
|
||||
|
||||
def serialize(self, value: value_type) -> str:
|
||||
def serialize(self, value: value_type, record: Record | None = None) -> str:
|
||||
return (value - datetime.utcfromtimestamp(0)).total_seconds()
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = False) -> value_type:
|
||||
|
|
@ -184,7 +185,7 @@ class Record(typing.Dict[(str, Field)]):
|
|||
"""
|
||||
rec = {}
|
||||
for name, _field in self._metadata.fields.items():
|
||||
rec[name] = _field.serialize(self[name]) if isinstance(_field, Field) else _field
|
||||
rec[name] = _field.serialize(self[name], record=self) if isinstance(_field, Field) else _field
|
||||
return self.__class__(rec, doc_id=self.doc_id)
|
||||
|
||||
def deserialize(self, db, recurse: bool = True):
|
||||
|
|
@ -208,6 +209,10 @@ class Record(typing.Dict[(str, Field)]):
|
|||
def reference(self):
|
||||
return Pointer.reference(self)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return Path(self._metadata.table) / self[self._metadata.primary_key]
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key in self:
|
||||
self[key] = value
|
||||
|
|
@ -238,7 +243,7 @@ class Pointer(Field):
|
|||
name: str = ""
|
||||
value_type: type = Record
|
||||
|
||||
def serialize(self, value: value_type | str) -> str:
|
||||
def serialize(self, value: value_type | str, record: Record | None = None) -> str:
|
||||
return Pointer.reference(value)
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
|
|
@ -277,6 +282,32 @@ class BackReference(Pointer):
|
|||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilePointer(Field):
|
||||
"""
|
||||
Write the contents of this field to disk and store the path in the db.
|
||||
"""
|
||||
|
||||
name: str
|
||||
value_type: type = bytes
|
||||
extension: str = ".txt"
|
||||
|
||||
def relpath(self, record):
|
||||
return Path(record._metadata.table) / record[record._metadata.primary_key] / f"{self.name}{self.extension}"
|
||||
|
||||
def deserialize(self, value: str, db: TinyDB, recurse: bool = True) -> value_type:
|
||||
if not value:
|
||||
return None
|
||||
return (db.path / value).read_bytes()
|
||||
|
||||
def before_insert(self, value: value_type, db: TinyDB, record: Record) -> None:
|
||||
relpath = self.relpath(record)
|
||||
path = db.path / relpath
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(record[self.name])
|
||||
record[self.name] = relpath
|
||||
|
||||
|
||||
@dataclass
|
||||
class Collection(Field):
|
||||
"""
|
||||
|
|
@ -286,7 +317,7 @@ class Collection(Field):
|
|||
value_type: type = Record
|
||||
default: typing.List[value_type] = field(default_factory=lambda: [])
|
||||
|
||||
def serialize(self, values: typing.List[value_type]) -> typing.List[str]:
|
||||
def serialize(self, values: typing.List[value_type], record: Record | None = None) -> typing.List[str]:
|
||||
return [Pointer.reference(val) for val in values]
|
||||
|
||||
def deserialize(self, values: typing.List[str], db: TinyDB, recurse: bool = False) -> typing.List[value_type]:
|
||||
|
|
@ -319,7 +350,9 @@ class RecordDict(Field):
|
|||
value_type: type = Record
|
||||
default: typing.Dict[(str, Record)] = field(default_factory=lambda: {})
|
||||
|
||||
def serialize(self, values: typing.Dict[(str, value_type)]) -> typing.Dict[(str, str)]:
|
||||
def serialize(
|
||||
self, values: typing.Dict[(str, value_type)], record: Record | None = None
|
||||
) -> typing.Dict[(str, str)]:
|
||||
return dict((key, Pointer.reference(val)) for (key, val) in values.items())
|
||||
|
||||
def deserialize(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from pprint import pprint as print
|
||||
from time import sleep
|
||||
|
||||
|
|
@ -13,7 +15,8 @@ from grung.exceptions import PointerReferenceError, UniqueConstraintError
|
|||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
_db = GrungDB.with_schema(examples, storage=MemoryStorage)
|
||||
with tempfile.TemporaryDirectory() as path:
|
||||
_db = GrungDB.with_schema(examples, path=Path(path), storage=MemoryStorage)
|
||||
yield _db
|
||||
print(_db)
|
||||
|
||||
|
|
@ -153,10 +156,12 @@ def test_mapping(db):
|
|||
name="The Impossible Kid",
|
||||
credits={"Produced By": "Aesop Rock", "Lyrics By": "Aesop Rock", "Puke in the MeowMix By": "Kirby"},
|
||||
tracks=["Mystery Fish", "Rings", "Lotta Years", "Dorks"],
|
||||
cover=b"some jpg data",
|
||||
)
|
||||
)
|
||||
assert album.credits["Produced By"] == "Aesop Rock"
|
||||
assert album.tracks[0] == "Mystery Fish"
|
||||
assert album.cover == b"some jpg data"
|
||||
|
||||
aes = db.save(
|
||||
examples.Artist(
|
||||
|
|
@ -170,3 +175,6 @@ def test_mapping(db):
|
|||
assert album.name in aes.albums
|
||||
assert aes.albums[album.name].uid == album.uid
|
||||
assert "Kirby" in aes.albums[album.name].credits.values()
|
||||
|
||||
location_on_disk = db.path / album._metadata.fields["cover"].relpath(album)
|
||||
assert location_on_disk.read_bytes() == album.cover
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user