diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 125dc2615..c290f7a64 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -30,6 +30,26 @@ from reflex import constants from reflex.base import Base from reflex.utils import serializers +# Potential GenericAlias types for isinstance checks. +GenericAliasTypes = [_GenericAlias] + +with contextlib.suppress(ImportError): + # For newer versions of Python. + from types import GenericAlias # type: ignore + + GenericAliasTypes.append(GenericAlias) + +with contextlib.suppress(ImportError): + # For older versions of Python. + from typing import _SpecialGenericAlias # type: ignore + + GenericAliasTypes.append(_SpecialGenericAlias) + +GenericAliasTypes = tuple(GenericAliasTypes) + +# Potential Union types for isinstance checks (UnionType added in py3.10). +UnionTypes = (Union, types.UnionType) if hasattr(types, "UnionType") else (Union,) + # Union of generic types. GenericType = Union[Type, _GenericAlias] @@ -75,22 +95,7 @@ def is_generic_alias(cls: GenericType) -> bool: Returns: Whether the class is a generic alias. """ - # For older versions of Python. - if isinstance(cls, _GenericAlias): - return True - - with contextlib.suppress(ImportError): - from typing import _SpecialGenericAlias # type: ignore - - if isinstance(cls, _SpecialGenericAlias): - return True - # For newer versions of Python. - try: - from types import GenericAlias # type: ignore - - return isinstance(cls, GenericAlias) - except ImportError: - return False + return isinstance(cls, GenericAliasTypes) def is_union(cls: GenericType) -> bool: @@ -102,11 +107,7 @@ def is_union(cls: GenericType) -> bool: Returns: Whether the class is a Union. """ - # UnionType added in py3.10 - if not hasattr(types, "UnionType"): - return get_origin(cls) is Union - - return get_origin(cls) in [Union, types.UnionType] + return get_origin(cls) in UnionTypes def is_literal(cls: GenericType) -> bool: diff --git a/reflex/vars.py b/reflex/vars.py index eed60a946..9b1725cc4 100644 --- a/reflex/vars.py +++ b/reflex/vars.py @@ -229,6 +229,13 @@ def _encode_var(value: Var) -> str: return str(value) +# Compile regex for finding reflex var tags. +_decode_var_pattern_re = ( + rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" +) +_decode_var_pattern = re.compile(_decode_var_pattern_re, flags=re.DOTALL) + + def _decode_var(value: str) -> tuple[VarData | None, str]: """Decode the state name from a formatted var. @@ -240,6 +247,10 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: """ var_datas = [] if isinstance(value, str): + # fast path if there is no encoded VarData + if constants.REFLEX_VAR_OPENING_TAG not in value: + return None, value + offset = 0 # Initialize some methods for reading json. @@ -251,12 +262,8 @@ def _decode_var(value: str) -> tuple[VarData | None, str]: except json.decoder.JSONDecodeError: return var_data_config.json_loads(var_data_config.json_loads(f'"{s}"')) - # Compile regex for finding reflex var tags. - pattern_re = rf"{constants.REFLEX_VAR_OPENING_TAG}(.*?){constants.REFLEX_VAR_CLOSING_TAG}" - pattern = re.compile(pattern_re, flags=re.DOTALL) - # Find all tags. - while m := pattern.search(value): + while m := _decode_var_pattern.search(value): start, end = m.span() value = value[:start] + value[end:]