add special handling for infinity and nan (#3943)

* add special handling for infinity and nan

* use custom exception

* add test for inf and nan
This commit is contained in:
Khaleel Al-Adhami 2024-09-18 13:10:32 -07:00 committed by GitHub
parent a8734d7392
commit 91b50d713e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 2 deletions

View File

@ -107,3 +107,7 @@ class EventHandlerShadowsBuiltInStateMethod(ReflexError, NameError):
class GeneratedCodeHasNoFunctionDefs(ReflexError):
"""Raised when refactored code generated with flexgen has no functions defined."""
class PrimitiveUnserializableToJSON(ReflexError, ValueError):
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import dataclasses
import json
import math
import sys
from typing import (
TYPE_CHECKING,
@ -18,7 +19,7 @@ from typing import (
)
from reflex.constants.base import Dirs
from reflex.utils.exceptions import VarTypeError
from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
from reflex.utils.imports import ImportDict, ImportVar
from .base import (
@ -1040,7 +1041,14 @@ class LiteralNumberVar(LiteralVar, NumberVar):
Returns:
The JSON representation of the var.
Raises:
PrimitiveUnserializableToJSON: If the var is unserializable to JSON.
"""
if math.isinf(self._var_value) or math.isnan(self._var_value):
raise PrimitiveUnserializableToJSON(
f"No valid JSON representation for {self}"
)
return json.dumps(self._var_value)
def __hash__(self) -> int:
@ -1062,8 +1070,15 @@ class LiteralNumberVar(LiteralVar, NumberVar):
Returns:
The number var.
"""
if math.isinf(value):
js_expr = "Infinity" if value > 0 else "-Infinity"
elif math.isnan(value):
js_expr = "NaN"
else:
js_expr = str(value)
return cls(
_js_expr=str(value),
_js_expr=js_expr,
_var_type=type(value),
_var_data=_var_data,
_var_value=value,

View File

@ -9,6 +9,7 @@ from pandas import DataFrame
from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.state import BaseState
from reflex.utils.exceptions import PrimitiveUnserializableToJSON
from reflex.utils.imports import ImportVar
from reflex.vars import VarData
from reflex.vars.base import (
@ -989,6 +990,22 @@ def test_index_operation():
assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
@pytest.mark.parametrize(
"var, expected_js",
[
(Var.create(float("inf")), "Infinity"),
(Var.create(-float("inf")), "-Infinity"),
(Var.create(float("nan")), "NaN"),
],
)
def test_inf_and_nan(var, expected_js):
assert str(var) == expected_js
assert isinstance(var, NumberVar)
assert isinstance(var, LiteralVar)
with pytest.raises(PrimitiveUnserializableToJSON):
var.json()
def test_array_operations():
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])