diff --git a/pynecone/utils.py b/pynecone/utils.py index a0b4f00a9..98c7ea54f 100644 --- a/pynecone/utils.py +++ b/pynecone/utils.py @@ -46,6 +46,9 @@ join = os.linesep.join # Console for pretty printing. console = Console() +# Union of generic types. +GenericType = Union[Type, _GenericAlias] + def get_args(alias: _GenericAlias) -> Tuple[Type, ...]: """Get the arguments of a type alias. @@ -59,7 +62,59 @@ def get_args(alias: _GenericAlias) -> Tuple[Type, ...]: return alias.__args__ -def get_base_class(cls: Type) -> Type: +def is_generic_alias(cls: GenericType) -> bool: + """Check whether the class is a generic alias. + + Args: + cls: The class to check. + + Returns: + Whether the class is a generic alias. + """ + # For older versions of Python. + if isinstance(cls, _GenericAlias): + return True + + try: + from typing import _SpecialGenericAlias # type: ignore + + if isinstance(cls, _SpecialGenericAlias): + return True + except ImportError: + pass + + # For newer versions of Python. + try: + from types import GenericAlias # type: ignore + + return isinstance(cls, GenericAlias) + except ImportError: + return False + + +def is_union(cls: GenericType) -> bool: + """Check if a class is a Union. + + Args: + cls: The class to check. + + Returns: + Whether the class is a Union. + """ + try: + from typing import _UnionGenericAlias # type: ignore + + return isinstance(cls, _UnionGenericAlias) + except ImportError: + pass + + if is_generic_alias(cls): + return cls.__origin__ == Union + + return False + + +def get_base_class(cls: GenericType) -> Type: """Get the base class of a class. Args: @@ -68,37 +123,16 @@ def get_base_class(cls: Type) -> Type: Returns: The base class of the class. """ - # For newer versions of Python. - try: - from types import GenericAlias # type: ignore + if is_union(cls): + return tuple(get_base_class(arg) for arg in get_args(cls)) - if isinstance(cls, GenericAlias): - return get_base_class(cls.__origin__) - except: - pass - - # Check Union types first. - try: - from typing import _UnionGenericAlias # type: ignore - - if isinstance(cls, _UnionGenericAlias): - return tuple(get_base_class(arg) for arg in get_args(cls)) - except: - pass - - # Check other generic aliases. - if isinstance(cls, _GenericAlias): - if cls.__origin__ == Union: - return tuple(get_base_class(arg) for arg in get_args(cls)) + if is_generic_alias(cls): return get_base_class(cls.__origin__) - # This is the base class. return cls -def _issubclass( - cls: Union[Type, _GenericAlias], cls_check: Union[Type, _GenericAlias] -) -> bool: +def _issubclass(cls: GenericType, cls_check: GenericType) -> bool: """Check if a class is a subclass of another class. Args: @@ -118,7 +152,7 @@ def _issubclass( return cls_check_base == Any or issubclass(cls_base, cls_check_base) -def _isinstance(obj: Any, cls: Union[Type, _GenericAlias]) -> bool: +def _isinstance(obj: Any, cls: GenericType) -> bool: """Check if an object is an instance of a class. Args: diff --git a/pynecone/var.py b/pynecone/var.py index a8477b5fc..9218804f8 100644 --- a/pynecone/var.py +++ b/pynecone/var.py @@ -149,14 +149,14 @@ class Var(ABC): assert isinstance( i, utils.get_args(Union[int, Var]) ), "Index must be an integer." - if isinstance(self.type_, _GenericAlias): + if utils.is_generic_alias(self.type_): type_ = utils.get_args(self.type_)[0] else: type_ = Any elif utils._issubclass(self.type_, Dict) or utils.is_dataframe(self.type_): if isinstance(i, str): i = utils.wrap(i, '"') - if isinstance(self.type_, _GenericAlias): + if utils.is_generic_alias(self.type_): type_ = utils.get_args(self.type_)[1] else: type_ = Any diff --git a/pyproject.toml b/pyproject.toml index f49072afa..9b68914d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pynecone-io" -version = "0.1.3" +version = "0.1.6" description = "" authors = [ "Nikhil Rao ", diff --git a/tests/test_utils.py b/tests/test_utils.py index 3853080f7..ab6f010ad 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +import typing + import pytest from pynecone import utils @@ -181,3 +183,24 @@ def test_merge_imports(): assert set(d.keys()) == {"react", "react-dom"} assert set(d["react"]) == {"Component"} assert set(d["react-dom"]) == {"render"} + + +@pytest.mark.parametrize( + "cls,expected", + [ + (str, False), + (int, False), + (float, False), + (bool, False), + (typing.List, True), + (typing.List[int], True), + ], +) +def test_is_generic_alias(cls: type, expected: bool): + """Test checking if a class is a GenericAlias. + + Args: + cls: The class to check. + expected: Whether the class is a GenericAlias. + """ + assert utils.is_generic_alias(cls) == expected