add special handling for infinity and nan (#3943)

* add special handling for infinity and nan

* use custom exception

* add test for inf and nan
This commit is contained in:
Khaleel Al-Adhami 2024-09-18 13:10:32 -07:00 committed by GitHub
parent a8734d7392
commit 91b50d713e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 2 deletions

View File

@ -107,3 +107,7 @@ class EventHandlerShadowsBuiltInStateMethod(ReflexError, NameError):
class GeneratedCodeHasNoFunctionDefs(ReflexError): class GeneratedCodeHasNoFunctionDefs(ReflexError):
"""Raised when refactored code generated with flexgen has no functions defined.""" """Raised when refactored code generated with flexgen has no functions defined."""
class PrimitiveUnserializableToJSON(ReflexError, ValueError):
"""Raised when a primitive type is unserializable to JSON. Usually with NaN and Infinity."""

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import dataclasses import dataclasses
import json import json
import math
import sys import sys
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -18,7 +19,7 @@ from typing import (
) )
from reflex.constants.base import Dirs from reflex.constants.base import Dirs
from reflex.utils.exceptions import VarTypeError from reflex.utils.exceptions import PrimitiveUnserializableToJSON, VarTypeError
from reflex.utils.imports import ImportDict, ImportVar from reflex.utils.imports import ImportDict, ImportVar
from .base import ( from .base import (
@ -1040,7 +1041,14 @@ class LiteralNumberVar(LiteralVar, NumberVar):
Returns: Returns:
The JSON representation of the var. The JSON representation of the var.
Raises:
PrimitiveUnserializableToJSON: If the var is unserializable to JSON.
""" """
if math.isinf(self._var_value) or math.isnan(self._var_value):
raise PrimitiveUnserializableToJSON(
f"No valid JSON representation for {self}"
)
return json.dumps(self._var_value) return json.dumps(self._var_value)
def __hash__(self) -> int: def __hash__(self) -> int:
@ -1062,8 +1070,15 @@ class LiteralNumberVar(LiteralVar, NumberVar):
Returns: Returns:
The number var. The number var.
""" """
if math.isinf(value):
js_expr = "Infinity" if value > 0 else "-Infinity"
elif math.isnan(value):
js_expr = "NaN"
else:
js_expr = str(value)
return cls( return cls(
_js_expr=str(value), _js_expr=js_expr,
_var_type=type(value), _var_type=type(value),
_var_data=_var_data, _var_data=_var_data,
_var_value=value, _var_value=value,

View File

@ -9,6 +9,7 @@ from pandas import DataFrame
from reflex.base import Base from reflex.base import Base
from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG from reflex.constants.base import REFLEX_VAR_CLOSING_TAG, REFLEX_VAR_OPENING_TAG
from reflex.state import BaseState from reflex.state import BaseState
from reflex.utils.exceptions import PrimitiveUnserializableToJSON
from reflex.utils.imports import ImportVar from reflex.utils.imports import ImportVar
from reflex.vars import VarData from reflex.vars import VarData
from reflex.vars.base import ( from reflex.vars.base import (
@ -989,6 +990,22 @@ def test_index_operation():
assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)" assert str(array_var[0].to(NumberVar) + 9) == "([1, 2, 3, 4, 5].at(0) + 9)"
@pytest.mark.parametrize(
"var, expected_js",
[
(Var.create(float("inf")), "Infinity"),
(Var.create(-float("inf")), "-Infinity"),
(Var.create(float("nan")), "NaN"),
],
)
def test_inf_and_nan(var, expected_js):
assert str(var) == expected_js
assert isinstance(var, NumberVar)
assert isinstance(var, LiteralVar)
with pytest.raises(PrimitiveUnserializableToJSON):
var.json()
def test_array_operations(): def test_array_operations():
array_var = LiteralArrayVar.create([1, 2, 3, 4, 5]) array_var = LiteralArrayVar.create([1, 2, 3, 4, 5])