Files
neon/test_runner/fixtures/neon_api.py
Mikhail 6689d6fd89 LFC prewarm perftest fixes: use existing staging project (#12651)
https://github.com/neondatabase/cloud/issues/19011

- Prewarm config changes are not publicly available.
  Correct the test by using a pre-filled 50 GB project on staging
- Create extension neon with schema neon to fix read performance tests
on staging, error example in
https://neon-github-public-dev.s3.amazonaws.com/reports/main/16483462789/index.html#suites/3d632da6dda4a70f5b4bd24904ab444c/919841e331089fc4/
- Don't create extra endpoint in LFC prewarm performance tests
2025-07-25 16:56:41 +00:00

480 lines
16 KiB
Python

from __future__ import annotations
import re
import time
from typing import TYPE_CHECKING, cast, final
import requests
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Literal
from fixtures.pg_version import PgVersion
def connstr_to_env(connstr: str) -> dict[str, str]:
    """Translate a PostgreSQL connection URI into libpq environment variables.

    Example input::

        postgresql://role:password@ep-example-pooler.us-east-2.aws.neon.build/neondb?sslmode=require

    The URI is parsed with ``urllib.parse`` rather than split on punctuation,
    so an explicit port, the ``postgres://`` scheme alias and percent-encoded
    credentials are all handled.  Query parameters are ignored.

    Args:
        connstr: Connection URI.  The scheme may be omitted
            (``role:password@host/database``).

    Returns:
        A dict with ``PGUSER``, ``PGPASSWORD``, ``PGHOST`` and ``PGDATABASE``;
        ``PGPORT`` is added only when the URI names a port explicitly.

    Raises:
        ValueError: If the user, password, host or database part is missing.
    """
    from urllib.parse import unquote, urlsplit

    # Without a scheme urlsplit() would treat "role" as the scheme, so add one.
    if "://" not in connstr:
        connstr = f"postgresql://{connstr}"
    uri = urlsplit(connstr)

    if uri.username is None or uri.password is None or uri.hostname is None or not uri.path:
        # Do not echo the connection string: it contains the password.
        raise ValueError("connection string must contain user, password, host and database")

    env = {
        "PGUSER": unquote(uri.username),
        "PGPASSWORD": unquote(uri.password),
        "PGHOST": uri.hostname,
        "PGDATABASE": unquote(uri.path.removeprefix("/")),
    }
    if uri.port is not None:
        env["PGPORT"] = str(uri.port)
    return env
def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]:
    """Map a connection-parameters dict onto libpq environment variables.

    Args:
        params: Mapping that must contain the keys ``host``, ``database``,
            ``role`` and ``password``; any other keys are ignored.

    Returns:
        A dict with ``PGHOST``, ``PGDATABASE``, ``PGUSER`` and ``PGPASSWORD``.

    Raises:
        KeyError: If one of the required keys is missing from ``params``.
    """
    env_name_by_param = (
        ("host", "PGHOST"),
        ("database", "PGDATABASE"),
        ("role", "PGUSER"),
        ("password", "PGPASSWORD"),
    )
    return {env_name: params[param] for param, env_name in env_name_by_param}
# Some API calls are not yet implemented.
# You may want to copy the not-yet-implemented methods from the PR https://github.com/neondatabase/neon/pull/11305
@final
class NeonAPI:
    """Minimal HTTP client for the Neon API.

    Each public method sends one request, authenticated with a bearer token,
    and returns the decoded JSON body of the response.  Retrying of known
    transient failures is opt-in through ``retry_if_possible``.
    """

    def __init__(self, neon_api_key: str, neon_api_base_url: str):
        """Store the credentials and initialise retry settings and counters.

        Args:
            neon_api_key: API key sent as ``Authorization: Bearer <key>``.
            neon_api_base_url: Base URL that endpoint paths are appended to.
        """
        self.__neon_api_key = neon_api_key
        # Endpoint paths start with "/", so the base URL must not end with one.
        self.__neon_api_base_url = neon_api_base_url.strip("/")
        # When False, every HTTP error is raised on the first attempt.
        self.retry_if_possible = False
        # Maximum number of requests sent per call when retrying is enabled.
        self.attempts = 10
        # Seconds to sleep between two attempts.
        self.sleep_before_retry = 1
        # Number of retries performed so far, split by cause.
        self.retries524 = 0
        self.retries4xx = 0

    @staticmethod
    def __error_message(resp: requests.Response) -> str | None:
        """Return the ``message`` field of a JSON error body, or None.

        None is returned when the body is not JSON, is not an object, or has
        no string ``message``, so callers never see a decode or key error.
        """
        try:
            body = resp.json()
        except ValueError:
            # All JSON decode errors raised by requests derive from ValueError.
            return None
        if not isinstance(body, dict):
            return None
        message = body.get("message")
        return message if isinstance(message, str) else None

    def __should_retry(self, resp: requests.Response, retry404: bool) -> bool:
        """Decide whether a response is a known transient failure.

        Increments ``retries524`` or ``retries4xx`` when the answer is True.
        """
        status = resp.status_code
        if status < 400:
            return False
        if status == 524:
            log.info("The request was timed out")
            self.retries524 += 1
            return True

        if status == 404:
            transient = retry404
        elif status == 422:
            transient = self.__error_message(resp) == "branch not ready yet"
        elif status == 423:
            transient = self.__error_message(resp) in {
                "endpoint is in some transitive state, could not suspend",
                "project already has running conflicting operations, scheduling of new ones is prohibited",
            }
        else:
            transient = False

        if transient:
            self.retries4xx += 1
        return transient

    def __request(
        self, method: str | bytes, endpoint: str, retry404: bool = False, **kwargs: Any
    ) -> requests.Response:
        """Send an authenticated request, retrying transient failures if enabled.

        Args:
            method: HTTP method.
            endpoint: Path appended to the base URL; expected to start with "/".
            retry404: Treat a 404 as transient (only with ``retry_if_possible``).
            **kwargs: Passed through to ``requests.request``.

        Returns:
            The successful response.

        Raises:
            requests.HTTPError: For an error response that is not retried.
            RuntimeError: If every attempt ended in a retryable failure.
        """
        # Build a fresh dict so the caller's headers mapping is left untouched.
        kwargs["headers"] = {
            **(kwargs.get("headers") or {}),
            "Authorization": f"Bearer {self.__neon_api_key}",
        }
        url = f"{self.__neon_api_base_url}{endpoint}"

        for attempt in range(self.attempts):
            resp = requests.request(method, url, **kwargs)
            if resp.status_code >= 400:
                log.error(
                    "%s %s returned a %d: %s",
                    method,
                    endpoint,
                    resp.status_code,
                    resp.text if resp.status_code != 524 else "CloudFlare error page",
                )
            else:
                log.debug("%s %s returned a %d: %s", method, endpoint, resp.status_code, resp.text)

            if not self.retry_if_possible or not self.__should_retry(resp, retry404):
                resp.raise_for_status()
                return resp

            # Nothing follows the last attempt, so do not announce or wait for a retry.
            if attempt + 1 < self.attempts:
                log.info("Retrying, attempt %s/%s", attempt + 1, self.attempts)
                time.sleep(self.sleep_before_retry)

        raise RuntimeError("Max retry count is reached")

    def create_project(
        self,
        pg_version: PgVersion | None = None,
        name: str | None = None,
        branch_name: str | None = None,
        branch_role_name: str | None = None,
        branch_database_name: str | None = None,
    ) -> dict[str, Any]:
        """Create a project (``POST /projects``) and return the response body.

        Only the arguments that are truthy are included in the request body.
        """
        branch: dict[str, Any] = {}
        if branch_name:
            branch["name"] = branch_name
        if branch_role_name:
            branch["role_name"] = branch_role_name
        if branch_database_name:
            branch["database_name"] = branch_database_name

        # The "branch" object is always sent, even when it is empty.
        project: dict[str, Any] = {"branch": branch}
        if name:
            project["name"] = name
        if pg_version:
            project["pg_version"] = int(pg_version)

        resp = self.__request(
            "POST",
            "/projects",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
            json={"project": project},
        )
        return cast("dict[str, Any]", resp.json())

    def get_project_details(self, project_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def get_project_limits(self, project_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/limits``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/limits",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def delete_project(
        self,
        project_id: str,
    ) -> dict[str, Any]:
        """Delete a project (``DELETE /projects/{project_id}``)."""
        resp = self.__request(
            "DELETE",
            f"/projects/{project_id}",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def create_branch(
        self,
        project_id: str,
        branch_name: str | None = None,
        parent_id: str | None = None,
        parent_lsn: str | None = None,
        parent_timestamp: str | None = None,
        protected: bool | None = None,
        archived: bool | None = None,
        init_source: str | None = None,
        add_endpoint: bool = True,
    ) -> dict[str, Any]:
        """Create a branch (``POST /projects/{project_id}/branches``).

        ``parent_id`` and ``branch_name`` are sent when truthy; the remaining
        branch options are sent whenever they are not None.  With
        ``add_endpoint`` a read-write endpoint is requested as well.
        """
        branch: dict[str, Any] = {}
        if parent_id:
            branch["parent_id"] = parent_id
        if branch_name:
            branch["name"] = branch_name
        optional_fields = {
            "parent_lsn": parent_lsn,
            "parent_timestamp": parent_timestamp,
            "protected": protected,
            "init_source": init_source,
            "archived": archived,
        }
        branch.update({key: value for key, value in optional_fields.items() if value is not None})

        data: dict[str, Any] = {}
        if add_endpoint:
            data["endpoints"] = [{"type": "read_write"}]
        # An empty "branch" object is left out of the request entirely.
        if branch:
            data["branch"] = branch

        resp = self.__request(
            "POST",
            f"/projects/{project_id}/branches",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
            json=data,
        )
        return cast("dict[str, Any]", resp.json())

    def get_branch_details(self, project_id: str, branch_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/branches/{branch_id}``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/branches/{branch_id}",
            # XXX Retry get parent details to work around the issue
            # https://databricks.atlassian.net/browse/LKB-279
            retry404=True,
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def delete_branch(self, project_id: str, branch_id: str) -> dict[str, Any]:
        """Delete a branch (``DELETE /projects/{project_id}/branches/{branch_id}``)."""
        resp = self.__request(
            "DELETE",
            f"/projects/{project_id}/branches/{branch_id}",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def reset_to_parent(self, project_id: str, branch_id: str) -> dict[str, Any]:
        """Call ``POST .../branches/{branch_id}/reset_to_parent`` and return the body."""
        resp = self.__request(
            "POST",
            f"/projects/{project_id}/branches/{branch_id}/reset_to_parent",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def restore_branch(
        self,
        project_id: str,
        branch_id: str,
        source_branch_id: str,
        source_lsn: str | None,
        source_timestamp: str | None,
        preserve_under_name: str | None,
    ) -> dict[str, Any]:
        """Call ``POST .../branches/{branch_id}/restore`` and return the body.

        ``source_lsn``, ``source_timestamp`` and ``preserve_under_name`` are
        only included in the request when they are truthy.
        """
        data: dict[str, str] = {"source_branch_id": source_branch_id}
        if source_lsn:
            data["source_lsn"] = source_lsn
        if source_timestamp:
            data["source_timestamp"] = source_timestamp
        if preserve_under_name:
            data["preserve_under_name"] = preserve_under_name
        log.info("Data: %s", data)
        resp = self.__request(
            "POST",
            f"/projects/{project_id}/branches/{branch_id}/restore",
            headers={
                "Accept": "application/json",
            },
            json=data,
        )
        return cast("dict[str, Any]", resp.json())

    def start_endpoint(
        self,
        project_id: str,
        endpoint_id: str,
    ) -> dict[str, Any]:
        """Call ``POST .../endpoints/{endpoint_id}/start`` and return the body."""
        resp = self.__request(
            "POST",
            f"/projects/{project_id}/endpoints/{endpoint_id}/start",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def suspend_endpoint(
        self,
        project_id: str,
        endpoint_id: str,
    ) -> dict[str, Any]:
        """Call ``POST .../endpoints/{endpoint_id}/suspend`` and return the body."""
        resp = self.__request(
            "POST",
            f"/projects/{project_id}/endpoints/{endpoint_id}/suspend",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def restart_endpoint(
        self,
        project_id: str,
        endpoint_id: str,
    ) -> dict[str, Any]:
        """Call ``POST .../endpoints/{endpoint_id}/restart`` and return the body."""
        resp = self.__request(
            "POST",
            f"/projects/{project_id}/endpoints/{endpoint_id}/restart",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def create_endpoint(
        self,
        project_id: str,
        branch_id: str,
        endpoint_type: Literal["read_write", "read_only"],
        settings: dict[str, Any],
    ) -> dict[str, Any]:
        """Create an endpoint (``POST /projects/{project_id}/endpoints``).

        Args:
            project_id: Project to create the endpoint in.
            branch_id: Branch the endpoint is attached to.
            endpoint_type: Sent as the endpoint ``type`` when truthy.
            settings: Endpoint settings; sent only when non-empty.  The dict
                passed in is not modified.
        """
        endpoint: dict[str, Any] = {"branch_id": branch_id}
        if endpoint_type:
            endpoint["type"] = endpoint_type
        if settings:
            # otherwise we get 400 "settings must not be nil"
            # TODO(myrrc): fix on cplane side
            # Merge into a new dict so the default is not written into the caller's settings.
            endpoint["settings"] = {"pg_settings": {}, **settings}

        resp = self.__request(
            "POST",
            f"/projects/{project_id}/endpoints",
            headers={
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
            json={"endpoint": endpoint},
        )
        return cast("dict[str, Any]", resp.json())

    def delete_endpoint(self, project_id: str, endpoint_id: str) -> dict[str, Any]:
        """Delete an endpoint (``DELETE .../endpoints/{endpoint_id}``)."""
        resp = self.__request("DELETE", f"/projects/{project_id}/endpoints/{endpoint_id}")
        return cast("dict[str, Any]", resp.json())

    def get_connection_uri(
        self,
        project_id: str,
        branch_id: str | None = None,
        endpoint_id: str | None = None,
        database_name: str = "neondb",
        role_name: str = "neondb_owner",
        pooled: bool = True,
    ) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/connection_uri``.

        All arguments except ``project_id`` are passed as query parameters.
        """
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/connection_uri",
            params={
                "branch_id": branch_id,
                "endpoint_id": endpoint_id,
                "database_name": database_name,
                "role_name": role_name,
                "pooled": pooled,
            },
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def get_branches(self, project_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/branches``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/branches",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def get_endpoints(self, project_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/endpoints``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/endpoints",
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def get_operations(self, project_id: str) -> dict[str, Any]:
        """Return the response body of ``GET /projects/{project_id}/operations``."""
        resp = self.__request(
            "GET",
            f"/projects/{project_id}/operations",
            # The Authorization header is added by the request helper itself.
            headers={
                "Accept": "application/json",
            },
        )
        return cast("dict[str, Any]", resp.json())

    def wait_for_operation_to_finish(self, project_id: str):
        """Block until the project has no scheduling, running or cancelling operation.

        The operations list is polled every 0.5 seconds; there is no timeout.
        """
        unfinished = {"scheduling", "running", "cancelling"}
        while True:
            operations = self.get_operations(project_id)["operations"]
            if not any(op["status"] in unfinished for op in operations):
                return
            time.sleep(0.5)
@final
class NeonApiEndpoint:
    """One endpoint of a Neon project together with its connection details.

    With ``project_id=None`` a new project is created and its first endpoint
    is used; otherwise the first endpoint listed for the existing project is
    used.
    """

    def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
        """Create or look up the project and collect its connection details.

        Args:
            neon_api: API client used for all requests.
            pg_version: Version for a new project, or the version an existing
                project is required to have.
            project_id: Existing project to use, or None to create one.

        Raises:
            Exception: If the existing project has a different PostgreSQL version.
            IndexError: If the existing project has no endpoints.
        """
        self.neon_api = neon_api
        self.project_id: str
        self.endpoint_id: str
        self.connstr: str
        # Environment variables (PGHOST, PGDATABASE, PGUSER, PGPASSWORD) for libpq clients.
        self.pgbench_env: dict[str, str]
        # True when the project was created by this constructor.
        self.is_new: bool

        if project_id is None:
            project = neon_api.create_project(pg_version)
            neon_api.wait_for_operation_to_finish(cast("str", project["project"]["id"]))
            self.project_id = project["project"]["id"]
            self.endpoint_id = project["endpoints"][0]["id"]
            self.connstr = project["connection_uris"][0]["connection_uri"]
            self.pgbench_env = connection_parameters_to_env(
                cast("dict[str, str]", project["connection_uris"][0]["connection_parameters"])
            )
            self.is_new = True
        else:
            from urllib.parse import unquote, urlsplit

            project = neon_api.get_project_details(project_id)
            if int(project["project"]["pg_version"]) != int(pg_version):
                raise Exception(
                    f"A project with the provided ID exists, but it's not of the specified version (expected {pg_version}, got {project['project']['pg_version']})"
                )
            self.project_id = project_id
            eps = neon_api.get_endpoints(project_id)["endpoints"]
            if not eps:
                # Same exception type as indexing the empty list, but with a usable message.
                raise IndexError(f"Project {project_id} has no endpoints")
            self.endpoint_id = eps[0]["id"]
            self.connstr = neon_api.get_connection_uri(
                project_id, endpoint_id=self.endpoint_id, pooled=False
            )["uri"]
            # Credentials in a URI are percent-encoded; PGPASSWORD needs the decoded value.
            pw = unquote(urlsplit(self.connstr).password or "")
            self.pgbench_env = {
                "PGHOST": eps[0]["host"],
                "PGDATABASE": "neondb",
                "PGUSER": "neondb_owner",
                "PGPASSWORD": pw,
            }
            self.is_new = False

    def restart(self):
        """Restart the endpoint and wait until the project's operations finish."""
        self.neon_api.restart_endpoint(self.project_id, self.endpoint_id)
        self.neon_api.wait_for_operation_to_finish(self.project_id)

    def get_synthetic_storage_size(self) -> int:
        """Return the project's ``synthetic_storage_size`` as an int."""
        return int(
            self.neon_api.get_project_details(self.project_id)["project"]["synthetic_storage_size"]
        )