Add case where get_var_value gets something that's not a var

This commit is contained in:
Masen Furer 2024-12-19 16:45:04 -08:00
parent a52479290a
commit c93119a901
No known key found for this signature in database
GPG Key ID: 2AE2BD5531FF94F4

View File

@ -145,7 +145,7 @@ HANDLED_PICKLE_ERRORS = (
)
# For BaseState.get_var_value
VAR_TYPE = TypeVar("VAR_TYPE", bound=Var)
VAR_TYPE = TypeVar("VAR_TYPE")
def _no_chain_background_task(
@ -1613,9 +1613,14 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
UnretrievableVarValueError: If the var does not have a literal value
or associated state.
"""
# Oopsie case: you didn't give me a Var... so get what you give.
if not isinstance(var, Var):
return var # type: ignore
# Fast case: this is a literal var and the value is known.
if hasattr(var, "_var_value"):
return var._var_value
var_data = var._get_all_var_data()
if var_data is None or not var_data.state:
raise UnretrievableVarValueError(
@ -1624,6 +1629,7 @@ class BaseState(Base, ABC, extra=pydantic.Extra.allow):
# Fastish case: this var belongs to this state
if var_data.state == self.get_full_name():
return getattr(self, var_data.field_name)
# Slow case: this var belongs to another state
other_state = await self.get_state(
self._get_root_state().get_class_substate(var_data.state)