"""Support for django rest framework symmetric serialization"""
__all__ = ["EnumField", "FlagField", "EnumFieldMixin"]
import inspect
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
from enum import Enum, Flag
from functools import reduce
from operator import or_
from typing import Any
from rest_framework.fields import (
CharField,
ChoiceField,
DateField,
DateTimeField,
DecimalField,
DurationField,
Field,
FloatField,
IntegerField,
MultipleChoiceField,
TimeField,
)
from rest_framework.serializers import ModelSerializer
from rest_framework.utils.field_mapping import get_field_kwargs
from django_enum.fields import EnumField as EnumModelField
from django_enum.fields import FlagField as FlagModelField
from django_enum.utils import (
choices,
decimal_params,
determine_primitive,
with_typehint,
)
class ClassLookupdict:
"""
A dict-like object that looks up values using the MRO of a class or
instance. Similar to DRF's ClassLookupdict but returns None instead
of raising KeyError and allows classes or object instances to be
used as lookup keys.
:param mapping: A dictionary containing a mapping of class types to
values.
"""
def __init__(self, mapping: dict[type[Any], Any]):
self.mapping = mapping
def __getitem__(self, key: Any) -> Any | None:
"""
Fetch the given object for the type or type of the given object.
:param key: An object instance or class type
:return: The mapped value to the object instance's class or the
passed class type. Inheritance is honored. None is returned
if no mapping is present.
"""
for cls in inspect.getmro(
getattr(
key,
"_proxy_class",
key if isinstance(key, type) else getattr(key, "__class__"),
)
):
if cls in self.mapping:
return self.mapping.get(cls, None)
return None
[docs]
class EnumField(ChoiceField):
    """
    A djangorestframework serializer field for Enumeration types. If
    unspecified, ModelSerializers will assign :class:`~django_enum.fields.EnumField`
    model field types to `ChoiceField
    <https://www.django-rest-framework.org/api-guide/fields/#choicefield>`_ which will
    not accept symmetrical values; this field will.

    :param enum: The type of the Enumeration of the field
    :param strict: If True (default) only values in the Enumeration type
        will be acceptable. If False, no errors will be thrown if other
        values of the same primitive type are used
    :param kwargs: Any other named arguments applicable to a ChoiceField
        will be passed up to the base classes.
    """

    enum: type[Enum]
    primitive: type[Any]
    strict: bool = True
    # Only populated for non-strict fields whose primitive type has a
    # matching DRF field class (see __init__).
    primitive_field: Field | None = None

    def __init__(self, enum: type[Enum], strict: bool = strict, **kwargs):
        self.enum = enum
        self.primitive = determine_primitive(enum)  # type: ignore
        assert self.primitive is not None, (
            f"Unable to determine primitive type for {enum}"
        )
        self.strict = strict
        # An explicit ``choices`` kwarg takes precedence over the enum's choices.
        self.choices = kwargs.pop("choices", choices(enum))
        # EnumFieldMixin passes these along; they are consumed here and never
        # forwarded to ChoiceField.
        field_name = kwargs.pop("field_name", None)
        model_field = kwargs.pop("model_field", None)
        if not self.strict:
            # If this field is not strict, we instantiate its primitive
            # field type so we can fall back to its to_internal_value
            # method if the value is not a valid enum value.
            primitive_field_cls = ClassLookupdict(
                {
                    str: CharField,
                    int: IntegerField,
                    float: FloatField,
                    date: DateField,
                    datetime: DateTimeField,
                    time: TimeField,
                    timedelta: DurationField,
                    Decimal: DecimalField,
                }
            )[self.primitive]
            if primitive_field_cls:
                # Kwargs derived from the model field (when both are known)
                # override the caller's kwargs; enum-only keys are dropped.
                model_kwargs = (
                    get_field_kwargs(field_name, model_field)
                    if field_name and model_field
                    else {}
                )
                field_kwargs = {
                    **kwargs,
                    **{
                        key: val
                        for key, val in model_kwargs.items()
                        if key not in ["model_field", "field_name", "choices"]
                    },
                }
                if primitive_field_cls is not CharField:
                    # allow_blank is only passed through to CharField.
                    field_kwargs.pop("allow_blank", None)
                if primitive_field_cls is DecimalField:
                    field_kwargs = {
                        **field_kwargs,
                        **decimal_params(
                            self.enum,
                            max_digits=field_kwargs.pop("max_digits", None),
                            decimal_places=field_kwargs.pop("decimal_places", None),
                        ),
                    }
                self.primitive_field = primitive_field_cls(**field_kwargs)
        super().__init__(choices=self.choices, **kwargs)

    def to_internal_value(self, data: Any) -> Enum | Any:  # type: ignore[override]
        """
        Transform the *incoming* primitive data into an enum instance.

        Conversion is attempted in order: by value, by value after coercing
        to the enum's primitive type, then by member name. If none succeed,
        a strict field fails with ``invalid_choice``; a non-strict field
        defers to its primitive field (if any) or returns the data as-is.

        :param data: The incoming value.
        :return: The enum instance or the primitive value if the enum
            instance could not be found.
        """
        if data == "" and self.allow_blank:
            return ""
        if isinstance(data, self.enum):
            return data

        # Keep the caller's input so error messages report what was sent,
        # not a partially coerced value.
        original = data

        try:
            return self.enum(data)
        except (TypeError, ValueError):
            pass

        try:
            # If coercion succeeds but the enum lookup fails, the coerced
            # value is kept for the remaining attempts.
            data = self.primitive(data)
            return self.enum(data)
        except (TypeError, ValueError, DecimalException):
            pass

        try:
            return self.enum[data]
        except (KeyError, TypeError):
            # TypeError: unhashable input (e.g. a list or dict) cannot be
            # used as a member-name key; treat it as an invalid choice
            # rather than letting it escape as an unhandled error.
            pass

        if self.strict:
            self.fail("invalid_choice", input=original)
        elif self.primitive_field:
            return self.primitive_field.to_internal_value(data)
        return data

    def to_representation(self, value: Any) -> Any:
        """
        Transform the *outgoing* enum value into its primitive value.

        Objects without a ``value`` attribute are returned unchanged.
        """
        return getattr(value, "value", value)
[docs]
class FlagField(MultipleChoiceField):
    """
    A djangorestframework serializer field for :class:`~enum.Flag` types.

    By default ModelSerializers map :class:`~django_enum.fields.FlagField`
    model fields to a `ChoiceField
    <https://www.django-rest-framework.org/api-guide/fields/#choicefield>`_,
    which does not combine composite flag values correctly. This field does,
    and it also accepts symmetric values (e.g. labels or names instead of
    values).

    **You should add** :class:`~django_enum.drf.EnumFieldMixin` **to your
    serializer to automatically use this field.**

    :param enum: The flag type of the field
    :param strict: If True (default) only values in the flag type
        will be acceptable. If False, no errors will be thrown if other
        values of the same primitive type are used
    :param kwargs: Any other named arguments applicable to a ChoiceField
        will be passed up to the base classes.
    """

    enum: type[Flag]
    strict: bool = True

    def __init__(self, enum: type[Flag], strict: bool = strict, **kwargs):
        self.enum = enum
        self.strict = strict
        self.choices = kwargs.pop("choices", choices(enum))
        # Drop the bookkeeping kwargs EnumFieldMixin supplies; they are not
        # handed to the MultipleChoiceField base class.
        for mixin_only_kwarg in ("field_name", "model_field"):
            kwargs.pop(mixin_only_kwarg, None)
        super().__init__(choices=self.choices, **kwargs)

    def to_internal_value(self, data: Any) -> Enum | Any:  # type: ignore[override]
        """
        Convert *incoming* data into a composite flag value.

        Accepts a flag instance, a composite flag value, a member name, or a
        list/tuple of values and/or names. List elements are converted
        individually and combined with bitwise OR.

        Falsy input becomes ``None`` when the field allows null and the input
        is ``None`` or ``""``; otherwise it becomes the empty flag.

        :param data: The incoming value.
        :return: A composite flag value.
        """
        if not data:
            if self.allow_null and (data is None or data == ""):
                return None
            return self.enum(0)

        if isinstance(data, self.enum):
            return data

        def lookup_member(item: Any) -> Flag:
            # Prefer interpreting the item as a value; fall back to a name.
            try:
                return self.enum(item)
            except (TypeError, ValueError):
                return self.enum[item]

        try:
            return self.enum(data)
        except (TypeError, ValueError):
            pass

        try:
            if isinstance(data, str):
                return self.enum[data]
            if isinstance(data, (list, tuple)):
                return reduce(or_, map(lookup_member, data))
        except (TypeError, ValueError, KeyError):
            pass

        self.fail("invalid_choice", input=data)
        # Only reached if fail() returns without raising.
        return data

    def to_representation(self, value: Any) -> Any:
        """
        Convert an *outgoing* flag into its primitive composite value (most
        likely an integer); objects without ``value`` pass through unchanged.
        """
        return getattr(value, "value", value)
[docs]
class EnumFieldMixin(with_typehint(ModelSerializer)):  # type: ignore
    """
    ModelSerializer mixin that automatically maps enum model fields to the
    enum-aware serializer fields in this module.
    """

    def build_standard_field(
        self, field_name: str, model_field: EnumModelField
    ) -> tuple[type[Field], dict[str, Any]]:
        """
        Choose serializer fields for :class:`~django_enum.fields.FlagField`
        and :class:`~django_enum.fields.EnumField` model fields.

        ModelSerializer's default implementation turns any field that has
        choices into a ChoiceField. For enum model fields this override
        returns :class:`FlagField` / :class:`EnumField` instead, adding the
        ``enum``, ``strict``, ``field_name`` and ``model_field`` kwargs.
        Other model fields are handled by the base implementation unchanged.

        Place the mixin before ModelSerializer in the class hierarchy:

        .. code-block:: python

            from django_enum.drf import EnumFieldMixin
            from rest_framework.serializers import ModelSerializer

            class MySerializer(EnumFieldMixin, ModelSerializer):

                class Meta:
                    model = MyModel
                    fields = '__all__'

        :param field_name: The name of the field on the serializer
        :param model_field: The Field instance on the model
        :return: A 2-tuple, the first element is the field class, the
            second is the kwargs for the field
        """
        serializer_field = ClassLookupdict(
            {FlagModelField: FlagField, EnumModelField: EnumField}
        )[model_field]
        if not serializer_field:
            return super().build_standard_field(field_name, model_field)

        enum_kwargs = {
            "enum": model_field.enum,
            "strict": model_field.strict,
            "field_name": field_name,
            "model_field": model_field,
        }
        # The base kwargs are merged last, so they win on any key collision.
        base_kwargs = super().build_standard_field(field_name, model_field)[1]
        return serializer_field, {**enum_kwargs, **base_kwargs}