Dynamically add vars to a State (#381)
This commit is contained in:
parent
526e417e8f
commit
b06f612a7d
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
from typing import Any, Dict, TypeVar
|
||||
|
||||
import pydantic
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
# Typevar to represent any class subclassing Base.
|
||||
PcType = TypeVar("PcType")
|
||||
@ -55,6 +56,25 @@ class Base(pydantic.BaseModel):
|
||||
"""
|
||||
return cls.__fields__
|
||||
|
||||
@classmethod
|
||||
def add_field(cls, var: Any, default_value: Any):
|
||||
"""Add a pydantic field after class definition.
|
||||
|
||||
Used by State.add_var() to correctly handle the new variable.
|
||||
|
||||
Args:
|
||||
var: The variable to add a pydantic field for.
|
||||
default_value: The default value of the field
|
||||
"""
|
||||
new_field = ModelField.infer(
|
||||
name=var.name,
|
||||
value=default_value,
|
||||
annotation=var.type_,
|
||||
class_validators=None,
|
||||
config=cls.__config__,
|
||||
)
|
||||
cls.__fields__.update({var.name: new_field})
|
||||
|
||||
def get_value(self, key: str) -> Any:
|
||||
"""Get the value of a field.
|
||||
|
||||
|
@ -109,9 +109,6 @@ class State(Base, ABC):
|
||||
|
||||
Args:
|
||||
**kwargs: The kwargs to pass to the pydantic init_subclass method.
|
||||
|
||||
Raises:
|
||||
TypeError: If the class has a var with an invalid type.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@ -146,16 +143,7 @@ class State(Base, ABC):
|
||||
|
||||
# Setup the base vars at the class level.
|
||||
for prop in cls.base_vars.values():
|
||||
if not utils.is_valid_var_type(prop.type_):
|
||||
raise TypeError(
|
||||
"State vars must be primitive Python types, "
|
||||
"Plotly figures, Pandas dataframes, "
|
||||
"or subclasses of pc.Base. "
|
||||
f'Found var "{prop.name}" with type {prop.type_}.'
|
||||
)
|
||||
cls._set_var(prop)
|
||||
cls._create_setter(prop)
|
||||
cls._set_default_value(prop)
|
||||
cls._init_var(prop)
|
||||
|
||||
# Set up the event handlers.
|
||||
events = {
|
||||
@ -261,6 +249,60 @@ class State(Base, ABC):
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
return getattr(substate, name)
|
||||
|
||||
@classmethod
|
||||
def _init_var(cls, prop: BaseVar):
|
||||
"""Initialize a variable.
|
||||
|
||||
Args:
|
||||
prop (BaseVar): The variable to initialize
|
||||
|
||||
Raises:
|
||||
TypeError: if the variable has an incorrect type
|
||||
"""
|
||||
if not utils.is_valid_var_type(prop.type_):
|
||||
raise TypeError(
|
||||
"State vars must be primitive Python types, "
|
||||
"Plotly figures, Pandas dataframes, "
|
||||
"or subclasses of pc.Base. "
|
||||
f'Found var "{prop.name}" with type {prop.type_}.'
|
||||
)
|
||||
cls._set_var(prop)
|
||||
cls._create_setter(prop)
|
||||
cls._set_default_value(prop)
|
||||
|
||||
@classmethod
|
||||
def add_var(cls, name: str, type_: Any, default_value: Any = None):
|
||||
"""Add dynamically a variable to the State.
|
||||
|
||||
The variable added this way can be used in the same way as a variable
|
||||
defined statically in the model.
|
||||
|
||||
Args:
|
||||
name (str): The name of the variable
|
||||
type_ (Any): The type of the variable
|
||||
default_value (Any): The default value of the variable
|
||||
|
||||
Raises:
|
||||
NameError: if a variable of this name already exists
|
||||
"""
|
||||
if name in cls.__fields__:
|
||||
raise NameError(
|
||||
f"The variable '{name}' already exist. Use a different name"
|
||||
)
|
||||
|
||||
# create the variable based on name and type
|
||||
var = BaseVar(name=name, type_=type_)
|
||||
var.set_state(cls)
|
||||
|
||||
# add the pydantic field dynamically (must be done before _init_var)
|
||||
cls.add_field(var, default_value)
|
||||
|
||||
cls._init_var(var)
|
||||
|
||||
# update the internal dicts so the new variable is correctly handled
|
||||
cls.base_vars.update({name: var})
|
||||
cls.vars.update({name: var})
|
||||
|
||||
@classmethod
|
||||
def _set_var(cls, prop: BaseVar):
|
||||
"""Set the var as a class member.
|
||||
|
@ -638,3 +638,20 @@ def test_get_query_params(test_state):
|
||||
test_state.router_data = {RouteVar.QUERY: params}
|
||||
|
||||
assert test_state.get_query_params() == params
|
||||
|
||||
|
||||
def test_add_var(test_state):
|
||||
test_state.add_var("dynamic_int", int, 42)
|
||||
assert test_state.dynamic_int == 42
|
||||
|
||||
test_state.add_var("dynamic_list", List[int], [5, 10])
|
||||
assert test_state.dynamic_list == [5, 10]
|
||||
assert getattr(test_state, "dynamic_list") == [5, 10]
|
||||
|
||||
# how to test that one?
|
||||
# test_state.dynamic_list.append(15)
|
||||
# assert test_state.dynamic_list == [5, 10, 15]
|
||||
|
||||
test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10})
|
||||
assert test_state.dynamic_dict == {"k1": 5, "k2": 10}
|
||||
assert getattr(test_state, "dynamic_dict") == {"k1": 5, "k2": 10}
|
||||
|
Loading…
Reference in New Issue
Block a user