From 0812edd3a2cd9851691ed4b17ed29076e6dae3b5 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Wed, 11 Sep 2024 21:32:18 +0200 Subject: [PATCH] improve datatable hacks, also default to ObjectVar for Unions --- reflex/components/gridjs/datatable.py | 10 ++-------- reflex/ivars/base.py | 2 +- reflex/utils/types.py | 6 ++++++ tests/components/datadisplay/conftest.py | 2 +- tests/components/datadisplay/test_datatable.py | 12 ++++++------ 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/reflex/components/gridjs/datatable.py b/reflex/components/gridjs/datatable.py index cd1acf100..ce39ade63 100644 --- a/reflex/components/gridjs/datatable.py +++ b/reflex/components/gridjs/datatable.py @@ -116,14 +116,8 @@ class DataTable(Gridjs): if isinstance(self.data, ImmutableVar) and types.is_dataframe( self.data._var_type ): - self.columns = self.data._replace( - _var_name=f"{self.data._var_name}.columns", - _var_type=List[Any], - ) - self.data = self.data._replace( - _var_name=f"{self.data._var_name}.data", - _var_type=List[List[Any]], - ) + self.columns = self.data.columns + self.data = self.data.data if types.is_dataframe(type(self.data)): # If given a pandas df break up the data and columns data = serialize(self.data) diff --git a/reflex/ivars/base.py b/reflex/ivars/base.py index f9029ddac..18beaad61 100644 --- a/reflex/ivars/base.py +++ b/reflex/ivars/base.py @@ -482,7 +482,7 @@ class ImmutableVar(Var, Generic[VAR_TYPE]): if all(inspect.isclass(t) and issubclass(t, Base) for t in inner_types): return self.to(ObjectVar, self._var_type) - return self + return self.to(ObjectVar, self._var_type) if not inspect.isclass(fixed_type): raise TypeError(f"Unsupported type {var_type} for guess_type.") diff --git a/reflex/utils/types.py b/reflex/utils/types.py index f4463fa92..cd54ad59d 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -338,6 +338,12 @@ def get_attribute_access_type(cls: GenericType, name: str) -> GenericType | None return hints[name] except exceptions as e: console.warn(f"Failed to resolve ForwardRefs for {cls}.{name} due to {e}") + # hardcoded fallbacks for forward ref issues + if is_dataframe(cls): + if name == "columns": + return List[Any] + if name == "data": + return List[List[Any]] pass return None # Attribute is not accessible. diff --git a/tests/components/datadisplay/conftest.py b/tests/components/datadisplay/conftest.py index 13c571c8c..956185d72 100644 --- a/tests/components/datadisplay/conftest.py +++ b/tests/components/datadisplay/conftest.py @@ -10,7 +10,7 @@ from reflex.state import BaseState @pytest.fixture -def data_table_state(request): +def data_table_state(request: pytest.FixtureRequest): """Get a data table state. Args: diff --git a/tests/components/datadisplay/test_datatable.py b/tests/components/datadisplay/test_datatable.py index b3d31ea32..19eaa46b0 100644 --- a/tests/components/datadisplay/test_datatable.py +++ b/tests/components/datadisplay/test_datatable.py @@ -23,7 +23,7 @@ from reflex.utils.serializers import serialize, serialize_dataframe ], indirect=["data_table_state"], ) -def test_validate_data_table(data_table_state: rx.State, expected): +def test_validate_data_table(data_table_state: rx.State, expected: str) -> None: """Test the str/render function. Args: @@ -40,13 +40,13 @@ def test_validate_data_table(data_table_state: rx.State, expected): data_table_dict = data_table_component.render() - # prefix expected with state name - state_name = data_table_state.get_name() - expected = f"{state_name}.{expected}" if expected else state_name + var = data_table_state + if expected: + var = getattr(var, expected) assert data_table_dict["props"] == [ - f"columns={{{expected}.columns}}", - f"data={{{expected}.data}}", + f"columns={{{var.columns._var_name}}}", + f"data={{{var.data._var_name}}}", ]