Improve event processing performance (#153)

This commit is contained in:
Nikhil Rao 2022-12-21 15:18:04 -08:00 committed by GitHub
parent f445febdc9
commit 57e278ae1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 35 deletions

6
poetry.lock generated
View File

@ -604,14 +604,14 @@ plugins = ["importlib-metadata"]
[[package]]
name = "pyright"
version = "1.1.284"
version = "1.1.285"
description = "Command line wrapper for pyright"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyright-1.1.284-py3-none-any.whl", hash = "sha256:e3bfbd33c20af48eed9d20138767265161ba8a4b55c740476a36ce822bd482d1"},
{file = "pyright-1.1.284.tar.gz", hash = "sha256:ef7c0e46e38be95687f5a0633e55c5171ca166048b9560558168a976162e287c"},
{file = "pyright-1.1.285-py3-none-any.whl", hash = "sha256:8a6b60b3ff0d000c549621c367cdf0013abdaf24d09e6f0b4b95031b357cc4b1"},
{file = "pyright-1.1.285.tar.gz", hash = "sha256:ecd28e8556352e2c7eb5f412c6841ec768d25e8a6136326d4a6a67d94370eba1"},
]
[package.dependencies]

View File

@ -123,8 +123,16 @@ export const updateState = async (state, result, setResult, router, socket) => {
* @param setResult The function to set the result.
* @param endpoint The endpoint to connect to.
*/
export const connect = async (socket, state, setResult, endpoint) => {
export const connect = async (socket, state, result, setResult, router, endpoint) => {
// Create the socket.
socket.current = new WebSocket(endpoint);
// Once the socket is open, hydrate the page.
socket.current.onopen = () => {
updateState(state, result, setResult, router, socket.current)
}
// On each received message, apply the delta and set the result.
socket.current.onmessage = function (update) {
update = JSON.parse(update.data);
applyDelta(state, update.delta);

View File

@ -144,7 +144,7 @@ USE_EFFECT = join(
[
"useEffect(() => {{",
f" if (!{SOCKET}.current) {{{{",
f" connect({SOCKET}, {{state}}, {SET_RESULT}, {EVENT_ENDPOINT})",
f" connect({SOCKET}, {{state}}, {RESULT}, {SET_RESULT}, {ROUTER}, {EVENT_ENDPOINT})",
" }}",
" const update = async () => {{",
f" if ({RESULT}.{STATE} != null) {{{{",

View File

@ -257,6 +257,17 @@ class State(Base, ABC):
field.required = False
field.default = default_value
def getattr(self, name: str) -> Any:
"""Get a non-prop attribute.
Args:
name: The name of the attribute.
Returns:
The attribute.
"""
return super().__getattribute__(name)
def __getattribute__(self, name: str) -> Any:
"""Get the attribute.
@ -287,17 +298,20 @@ class State(Base, ABC):
name: The name of the attribute.
value: The value of the attribute.
"""
if name != "inherited_vars" and name in self.inherited_vars:
setattr(self.parent_state, name, value)
# NOTE: We use super().__getattribute__ for performance reasons.
if name != "inherited_vars" and name in super().__getattribute__(
"inherited_vars"
):
setattr(super().__getattribute__("parent_state"), name, value)
return
# Set the attribute.
super().__setattr__(name, value)
# Add the var to the dirty list.
if name in self.vars:
self.dirty_vars.add(name)
self.mark_dirty()
if name in super().__getattribute__("vars"):
super().__getattribute__("dirty_vars").add(name)
super().__getattribute__("mark_dirty")()
def reset(self):
"""Reset all the base vars to their default values."""
@ -344,10 +358,11 @@ class State(Base, ABC):
Returns:
The state update after processing the event.
"""
# NOTE: We use super().__getattribute__ for performance reasons.
# Get the event handler.
path = event.name.split(".")
path, name = path[:-1], path[-1]
substate = self.get_substate(path)
substate = super().__getattribute__("get_substate")(path)
handler = getattr(substate, name)
# Process the event.
@ -368,10 +383,10 @@ class State(Base, ABC):
events = utils.fix_events(events, event.token)
# Get the delta after processing the event.
delta = self.get_delta()
delta = super().__getattribute__("get_delta")()
# Reset the dirty vars.
self.clean()
super().__getattribute__("clean")()
# Return the state update.
return StateUpdate(delta=delta, events=events)
@ -382,19 +397,22 @@ class State(Base, ABC):
Returns:
The delta for the state.
"""
# NOTE: We use super().__getattribute__ for performance reasons.
delta = {}
# Return the dirty vars, as well as all computed vars.
subdelta = {
prop: getattr(self, prop)
for prop in self.dirty_vars | set(self.computed_vars.keys())
for prop in super().__getattribute__("dirty_vars")
| set(super().__getattribute__("computed_vars").keys())
}
if len(subdelta) > 0:
delta[self.get_full_name()] = subdelta
delta[super().__getattribute__("get_full_name")()] = subdelta
# Recursively find the substate deltas.
for substate in self.dirty_substates:
delta.update(self.substates[substate].get_delta())
substates = super().__getattribute__("substates")
for substate in super().__getattribute__("dirty_substates"):
delta.update(substates[substate].getattr("get_delta")())
# Format the delta.
delta = utils.format_state(delta)
@ -410,13 +428,14 @@ class State(Base, ABC):
def clean(self):
"""Reset the dirty vars."""
# NOTE: We use super().__getattribute__ for performance reasons.
# Recursively clean the substates.
for substate in self.dirty_substates:
self.substates[substate].clean()
for substate in super().__getattribute__("dirty_substates"):
super().__getattribute__("substates")[substate].getattr("clean")()
# Clean this state.
self.dirty_vars = set()
self.dirty_substates = set()
super().__setattr__("dirty_vars", set())
super().__setattr__("dirty_substates", set())
def dict(self, include_computed: bool = True, **kwargs) -> Dict[str, Any]:
"""Convert the object to a dictionary.

View File

@ -851,8 +851,16 @@ def format_state(value: Any) -> Dict:
Raises:
TypeError: If the given value is not a valid state.
"""
# Handle dicts.
if isinstance(value, dict):
return {k: format_state(v) for k, v in value.items()}
# Return state vars as is.
if isinstance(value, StateBases):
return value
# Convert plotly figures to JSON.
if _isinstance(value, go.Figure):
if isinstance(value, go.Figure):
return json.loads(to_json(value))["data"]
# Convert pandas dataframes to JSON.
@ -862,19 +870,11 @@ def format_state(value: Any) -> Dict:
"data": value.values.tolist(),
}
# Handle dicts.
if _isinstance(value, dict):
return {k: format_state(v) for k, v in value.items()}
# Make sure the value is JSON serializable.
if not _isinstance(value, StateVar):
raise TypeError(
"State vars must be primitive Python types, "
"or subclasses of pc.Base. "
f"Got var of type {type(value)}."
)
return value
raise TypeError(
"State vars must be primitive Python types, "
"or subclasses of pc.Base. "
f"Got var of type {type(value)}."
)
def get_event(state, event):
@ -1069,3 +1069,7 @@ def get_redis() -> Optional[Redis]:
redis_url, redis_port = config.redis_url.split(":")
print("Using redis at", config.redis_url)
return Redis(host=redis_url, port=int(redis_port), db=0)
# Store this here for performance.
StateBases = get_base_class(StateVar)