From ab4fd41e55742326a984a799a09254d6241abe43 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Fri, 25 Oct 2024 17:34:47 -0700 Subject: [PATCH] make vardata merge not use classmethod (#4245) * make vardata merge not use classmethod * add clarifying comment * use simple cases for small values * add possible None * allow zero values to be given to var data * dang it darglint --- reflex/vars/base.py | 47 ++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 2007bc091..2f26e9170 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -151,31 +151,41 @@ class VarData: """ return dict((k, list(v)) for k, v in self.imports) - @classmethod - def merge(cls, *others: VarData | None) -> VarData | None: + def merge(*all: VarData | None) -> VarData | None: """Merge multiple var data objects. Args: - *others: The var data objects to merge. + *all: The var data objects to merge. Returns: The merged var data object. + + # noqa: DAR102 *all """ - state = "" - field_name = "" - _imports = {} - hooks = {} - for var_data in others: - if var_data is None: - continue - state = state or var_data.state - field_name = field_name or var_data.field_name - _imports = imports.merge_imports(_imports, var_data.imports) - hooks.update( - var_data.hooks - if isinstance(var_data.hooks, dict) - else {k: None for k in var_data.hooks} - ) + all_var_datas = list(filter(None, all)) + + if not all_var_datas: + return None + + if len(all_var_datas) == 1: + return all_var_datas[0] + + # Get the first non-empty field name or default to empty string. + field_name = next( + (var_data.field_name for var_data in all_var_datas if var_data.field_name), + "", + ) + + # Get the first non-empty state or default to empty string. + state = next( + (var_data.state for var_data in all_var_datas if var_data.state), "" + ) + + hooks = {hook: None for var_data in all_var_datas for hook in var_data.hooks} + + _imports = imports.merge_imports( + *(var_data.imports for var_data in all_var_datas) + ) if state or _imports or hooks or field_name: return VarData( @@ -184,6 +194,7 @@ class VarData: imports=_imports, hooks=hooks, ) + return None def __bool__(self) -> bool: