Add cors to backend server (#524)

This commit is contained in:
Nikhil Rao 2023-02-12 13:38:30 -08:00 committed by GitHub
parent 0bd8f6acfc
commit fb9b8a8c83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 0 deletions

View File

@ -3,6 +3,7 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware import cors
from socketio import ASGIApp, AsyncNamespace, AsyncServer from socketio import ASGIApp, AsyncNamespace, AsyncServer
from pynecone import constants, utils from pynecone import constants, utils
@ -74,6 +75,8 @@ class App(Base):
# Set up the API. # Set up the API.
self.api = FastAPI() self.api = FastAPI()
self.add_cors(config.cors_allowed_origins)
self.add_default_endpoints()
# Set up CORS options. # Set up CORS options.
cors_allowed_origins = config.cors_allowed_origins cors_allowed_origins = config.cors_allowed_origins
@ -116,6 +119,26 @@ class App(Base):
""" """
return self.api return self.api
def add_default_endpoints(self):
"""Add the default endpoints."""
# To test the server.
self.api.get(str(constants.Endpoint.PING))(ping)
def add_cors(self, allowed_origins: Optional[List[str]] = None):
"""Add CORS middleware to the app.
Args:
allowed_origins: A list of allowed origins.
"""
allowed_origins = allowed_origins or ["*"]
self.api.add_middleware(
cors.CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def preprocess(self, state: State, event: Event) -> Optional[Delta]: def preprocess(self, state: State, event: Event) -> Optional[Delta]:
"""Preprocess the event. """Preprocess the event.
@ -392,6 +415,15 @@ async def process(
return update return update
async def ping() -> str:
"""Test API endpoint.
Returns:
The response.
"""
return "pong"
class EventNamespace(AsyncNamespace): class EventNamespace(AsyncNamespace):
"""The event namespace.""" """The event namespace."""

View File

@ -159,6 +159,7 @@ class LogLevel(str, Enum):
class Endpoint(Enum): class Endpoint(Enum):
"""Endpoints for the pynecone backend API.""" """Endpoints for the pynecone backend API."""
PING = "ping"
EVENT = "event" EVENT = "event"
def __str__(self) -> str: def __str__(self) -> str: