From b06f612a7dc6afd724c317cd729d32ccf8a83bf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Brand=C3=A9ho?= Date: Sun, 29 Jan 2023 20:20:06 +0100 Subject: [PATCH] Dynamically add vars to a State (#381) --- pynecone/base.py | 20 +++++++++++++ pynecone/state.py | 68 ++++++++++++++++++++++++++++++++++++--------- tests/test_state.py | 17 ++++++++++++ 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/pynecone/base.py b/pynecone/base.py index 66818342d..034483bc3 100644 --- a/pynecone/base.py +++ b/pynecone/base.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import Any, Dict, TypeVar import pydantic +from pydantic.fields import ModelField # Typevar to represent any class subclassing Base. PcType = TypeVar("PcType") @@ -55,6 +56,25 @@ class Base(pydantic.BaseModel): """ return cls.__fields__ + @classmethod + def add_field(cls, var: Any, default_value: Any): + """Add a pydantic field after class definition. + + Used by State.add_var() to correctly handle the new variable. + + Args: + var: The variable to add a pydantic field for. + default_value: The default value of the field + """ + new_field = ModelField.infer( + name=var.name, + value=default_value, + annotation=var.type_, + class_validators=None, + config=cls.__config__, + ) + cls.__fields__.update({var.name: new_field}) + def get_value(self, key: str) -> Any: """Get the value of a field. diff --git a/pynecone/state.py b/pynecone/state.py index a5901e9f1..f356f9723 100644 --- a/pynecone/state.py +++ b/pynecone/state.py @@ -109,9 +109,6 @@ class State(Base, ABC): Args: **kwargs: The kwargs to pass to the pydantic init_subclass method. - - Raises: - TypeError: If the class has a var with an invalid type. """ super().__init_subclass__(**kwargs) @@ -146,16 +143,7 @@ class State(Base, ABC): # Setup the base vars at the class level. for prop in cls.base_vars.values(): - if not utils.is_valid_var_type(prop.type_): - raise TypeError( - "State vars must be primitive Python types, " - "Plotly figures, Pandas dataframes, " - "or subclasses of pc.Base. " - f'Found var "{prop.name}" with type {prop.type_}.' - ) - cls._set_var(prop) - cls._create_setter(prop) - cls._set_default_value(prop) + cls._init_var(prop) # Set up the event handlers. events = { @@ -261,6 +249,60 @@ class State(Base, ABC): raise ValueError(f"Invalid path: {path}") return getattr(substate, name) + @classmethod + def _init_var(cls, prop: BaseVar): + """Initialize a variable. + + Args: + prop (BaseVar): The variable to initialize + + Raises: + TypeError: if the variable has an incorrect type + """ + if not utils.is_valid_var_type(prop.type_): + raise TypeError( + "State vars must be primitive Python types, " + "Plotly figures, Pandas dataframes, " + "or subclasses of pc.Base. " + f'Found var "{prop.name}" with type {prop.type_}.' + ) + cls._set_var(prop) + cls._create_setter(prop) + cls._set_default_value(prop) + + @classmethod + def add_var(cls, name: str, type_: Any, default_value: Any = None): + """Add dynamically a variable to the State. + + The variable added this way can be used in the same way as a variable + defined statically in the model. + + Args: + name (str): The name of the variable + type_ (Any): The type of the variable + default_value (Any): The default value of the variable + + Raises: + NameError: if a variable of this name already exists + """ + if name in cls.__fields__: + raise NameError( + f"The variable '{name}' already exist. Use a different name" + ) + + # create the variable based on name and type + var = BaseVar(name=name, type_=type_) + var.set_state(cls) + + # add the pydantic field dynamically (must be done before _init_var) + cls.add_field(var, default_value) + + cls._init_var(var) + + # update the internal dicts so the new variable is correctly handled + cls.base_vars.update({name: var}) + cls.vars.update({name: var}) + @classmethod def _set_var(cls, prop: BaseVar): """Set the var as a class member. diff --git a/tests/test_state.py b/tests/test_state.py index 39112014e..ea4930872 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -638,3 +638,20 @@ def test_get_query_params(test_state): test_state.router_data = {RouteVar.QUERY: params} assert test_state.get_query_params() == params + + +def test_add_var(test_state): + test_state.add_var("dynamic_int", int, 42) + assert test_state.dynamic_int == 42 + + test_state.add_var("dynamic_list", List[int], [5, 10]) + assert test_state.dynamic_list == [5, 10] + assert getattr(test_state, "dynamic_list") == [5, 10] + + # how to test that one? + # test_state.dynamic_list.append(15) + # assert test_state.dynamic_list == [5, 10, 15] + + test_state.add_var("dynamic_dict", Dict[str, int], {"k1": 5, "k2": 10}) + assert test_state.dynamic_dict == {"k1": 5, "k2": 10} + assert getattr(test_state, "dynamic_dict") == {"k1": 5, "k2": 10}