diff --git a/reflex/state.py b/reflex/state.py index 20552249b..e055d77fe 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -151,6 +151,7 @@ RESERVED_BACKEND_VAR_NAMES = { "_substate_var_dependencies", "_always_dirty_computed_vars", "_always_dirty_substates", + "_abc_impl", # pydantic v2 "_was_touched", } @@ -281,6 +282,11 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): "See https://reflex.dev/docs/state/ for further information." ) kwargs["parent_state"] = parent_state + + for prop_name, prop in self.base_vars.items(): + if prop_name not in kwargs and self.model_fields[prop_name].is_required(): + kwargs[prop_name] = prop.get_default_value() + super().__init__(*args, **kwargs) # Setup the substates (for memory state manager only). @@ -334,7 +340,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): Returns: The string representation of the state. """ - return f"{self.__class__.__name__}({self.dict()})" + return f"{type(self).__name__}({self.dict()})" @classmethod def _get_computed_vars(cls) -> list[ComputedVar]: @@ -401,8 +407,8 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): computed_vars = cls._get_computed_vars() new_backend_vars = { - name: value - for name, value in cls.__dict__.items() + name: value.default + for name, value in cls.__private_attributes__.items() if types.is_backend_variable(name, cls) and name not in RESERVED_BACKEND_VAR_NAMES and name not in cls.inherited_backend_vars @@ -738,7 +744,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): ) cls._set_var(prop) cls._create_setter(prop) - cls._set_default_value(prop) @classmethod def add_var(cls, name: str, type_: Any, default_value: Any = None): @@ -802,28 +807,6 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): cls.event_handlers[setter_name] = event_handler setattr(cls, setter_name, event_handler) - @classmethod - def _set_default_value(cls, prop: BaseVar): - """Set the default value for the var. - - Args: - prop: The var to set the default value for. - """ - # Get the pydantic field for the var. - field = cls.get_fields()[prop._var_name] - if field.required: - default_value = prop.get_default_value() - if default_value is not None: - field.required = False - field.default = default_value - if ( - not field.required - and field.default is None - and not types.is_optional(prop._var_type) - ): - # Ensure frontend uses null coalescing when accessing. - prop._var_type = Optional[prop._var_type] - @staticmethod def _get_base_functions() -> dict[str, FunctionType]: """Get all functions of the state class excluding dunder methods. @@ -1006,6 +989,9 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): # If the state hasn't been initialized yet, return the default value. if not super().__getattribute__("__dict__"): return super().__getattribute__(name) + private_attrs = super().__getattribute__("__pydantic_private__") + if private_attrs is None: + return super().__getattribute__(name) inherited_vars = { **super().__getattribute__("inherited_vars"), @@ -1018,7 +1004,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow): if parent_state is not None: return getattr(parent_state, name) - backend_vars = super().__getattribute__("_backend_vars") + backend_vars = private_attrs["_backend_vars"] if name in backend_vars: value = backend_vars[name] else: