Skip to content

Commit 7ac0d9e

Browse files
achimnolYaminyam
authored andcommitted
chore(BA-440): Upgrade mypy to 1.14.1 and ruff to 0.8.5 (#3354)
1 parent 3ad97a4 commit 7ac0d9e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+684
-737
lines changed

src/ai/backend/accelerator/cuda_open/nvidia.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import ctypes
22
import platform
33
from abc import ABCMeta, abstractmethod
4+
from collections.abc import MutableMapping, Sequence
45
from itertools import groupby
56
from operator import itemgetter
6-
from typing import Any, MutableMapping, NamedTuple, Tuple, TypeAlias
7+
from typing import Any, NamedTuple, TypeAlias, cast
78

89
# ref: https://developer.nvidia.com/cuda-toolkit-archive
910
TARGET_CUDA_VERSIONS = (
@@ -487,7 +488,7 @@ def load_library(cls):
487488
return None
488489

489490
@classmethod
490-
def get_version(cls) -> Tuple[int, int]:
491+
def get_version(cls) -> tuple[int, int]:
491492
if cls._version == (0, 0):
492493
raw_ver = ctypes.c_int()
493494
cls.invoke("cudaRuntimeGetVersion", ctypes.byref(raw_ver))
@@ -513,7 +514,9 @@ def get_device_props(cls, device_idx: int):
513514
props_struct = cudaDeviceProp()
514515
cls.invoke("cudaGetDeviceProperties", ctypes.byref(props_struct), device_idx)
515516
props: MutableMapping[str, Any] = {
516-
k: getattr(props_struct, k) for k, _ in props_struct._fields_
517+
# Treat each field as two-tuple assuming that we don't have bit-fields
518+
k: getattr(props_struct, k)
519+
for k, _ in cast(Sequence[tuple[str, Any]], props_struct._fields_)
517520
}
518521
pci_bus_id = b" " * 16
519522
cls.invoke("cudaDeviceGetPCIBusId", ctypes.c_char_p(pci_bus_id), 16, device_idx)

src/ai/backend/accelerator/cuda_open/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async def list_devices(self) -> Collection[CUDADevice]:
169169
if dev_id in self.device_mask:
170170
continue
171171
raw_info = libcudart.get_device_props(int(dev_id))
172-
sysfs_node_path = f"/sys/bus/pci/devices/{raw_info["pciBusID_str"].lower()}/numa_node"
172+
sysfs_node_path = f"/sys/bus/pci/devices/{raw_info['pciBusID_str'].lower()}/numa_node"
173173
node: Optional[int]
174174
try:
175175
node = int(Path(sysfs_node_path).read_text().strip())

src/ai/backend/accelerator/mock/plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ async def list_devices(self) -> Collection[MockDevice]:
297297
init_kwargs["is_mig_device"] = dev_info["is_mig_device"]
298298
if dev_info["is_mig_device"]:
299299
init_kwargs["device_id"] = DeviceId(
300-
f"MIG-{dev_info["mother_uuid"]}/{idx}/0"
300+
f"MIG-{dev_info['mother_uuid']}/{idx}/0"
301301
)
302302
device_cls = CUDADevice
303303
case _:
@@ -810,7 +810,7 @@ def get_metadata(self) -> AcceleratorMetadata:
810810

811811
device_format = self.device_formats[format_key]
812812
return {
813-
"slot_name": f"{self.mock_config["slot_name"]}.{format_key}",
813+
"slot_name": f"{self.mock_config['slot_name']}.{format_key}",
814814
"human_readable_name": device_format["human_readable_name"],
815815
"description": device_format["description"],
816816
"display_unit": device_format["display_unit"],

src/ai/backend/account_manager/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ async def server_main(
299299
try:
300300
ssl_ctx = None
301301
if am_cfg.ssl_enabled:
302-
assert (
303-
am_cfg.ssl_cert is not None
304-
), "Should set `account_manager.ssl-cert` in config file."
302+
assert am_cfg.ssl_cert is not None, (
303+
"Should set `account_manager.ssl-cert` in config file."
304+
)
305305
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
306306
ssl_ctx.load_cert_chain(
307307
str(am_cfg.ssl_cert),

src/ai/backend/agent/agent.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,7 +2035,7 @@ async def create_kernel(
20352035
if len(overlapping_services) > 0:
20362036
raise AgentError(
20372037
f"Port {port_no} overlaps with built-in service"
2038-
f" {overlapping_services[0]["name"]}"
2038+
f" {overlapping_services[0]['name']}"
20392039
)
20402040

20412041
preopen_sport: ServicePort = {
@@ -2377,7 +2377,7 @@ async def load_model_definition(
23772377

23782378
if not model_definition_path:
23792379
raise AgentError(
2380-
f"Model definition file ({" or ".join(model_definition_candidates)}) does not exist under vFolder"
2380+
f"Model definition file ({' or '.join(model_definition_candidates)}) does not exist under vFolder"
23812381
f" {model_folder.name} (ID {model_folder.vfid})",
23822382
)
23832383
try:
@@ -2408,11 +2408,11 @@ async def load_model_definition(
24082408
]
24092409
if len(overlapping_services) > 0:
24102410
raise AgentError(
2411-
f"Port {service["port"]} overlaps with built-in service"
2412-
f" {overlapping_services[0]["name"]}"
2411+
f"Port {service['port']} overlaps with built-in service"
2412+
f" {overlapping_services[0]['name']}"
24132413
)
24142414
service_ports.append({
2415-
"name": f"{model["name"]}-{service["port"]}",
2415+
"name": f"{model['name']}-{service['port']}",
24162416
"protocol": ServicePortProtocols.PREOPEN,
24172417
"container_ports": (service["port"],),
24182418
"host_ports": (None,),

src/ai/backend/agent/docker/agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ async def start_container(
933933
label for label in service_ports_label if label
934934
])
935935
update_nested_dict(container_config, self.computer_docker_args)
936-
kernel_name = f"kernel.{self.image_ref.name.split("/")[-1]}.{self.kernel_id}"
936+
kernel_name = f"kernel.{self.image_ref.name.split('/')[-1]}.{self.kernel_id}"
937937

938938
# optional local override of docker config
939939
extra_container_opts_name = "agent-docker-container-opts.json"
@@ -1202,7 +1202,7 @@ async def __ainit__(self) -> None:
12021202
{
12031203
"Cmd": [
12041204
f"UNIX-LISTEN:/ipc/{self.agent_sockpath.name},unlink-early,fork,mode=777",
1205-
f"TCP-CONNECT:127.0.0.1:{self.local_config["agent"]["agent-sock-port"]}",
1205+
f"TCP-CONNECT:127.0.0.1:{self.local_config['agent']['agent-sock-port']}",
12061206
],
12071207
"HostConfig": {
12081208
"Mounts": [
@@ -1449,7 +1449,7 @@ async def handle_agent_socket(self):
14491449
while True:
14501450
agent_sock = zmq_ctx.socket(zmq.REP)
14511451
try:
1452-
agent_sock.bind(f"tcp://127.0.0.1:{self.local_config["agent"]["agent-sock-port"]}")
1452+
agent_sock.bind(f"tcp://127.0.0.1:{self.local_config['agent']['agent-sock-port']}")
14531453
while True:
14541454
msg = await agent_sock.recv_multipart()
14551455
if not msg:

src/ai/backend/agent/docker/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
6464
raise
6565
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
6666
raise RuntimeError(
67-
f"An existing container named \"{c["Name"].lstrip("/")}\" is not a system container"
67+
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
6868
" spawned by Backend.AI. Please check and remove it."
6969
)
7070
return (

src/ai/backend/agent/kubernetes/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ async def check_krunner_pv_status(self):
887887
new_pv.label("backend.ai/backend-ai-scratch-volume", "hostPath")
888888
else:
889889
raise NotImplementedError(
890-
f'Scratch type {self.local_config["container"]["scratch-type"]} is not'
890+
f"Scratch type {self.local_config['container']['scratch-type']} is not"
891891
" supported",
892892
)
893893

src/ai/backend/agent/kubernetes/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ async def get_container_version_and_status(self) -> Tuple[int, bool]:
6363
raise
6464
if c["Config"].get("Labels", {}).get("ai.backend.system", "0") != "1":
6565
raise RuntimeError(
66-
f"An existing container named \"{c["Name"].lstrip("/")}\" is not a system container"
66+
f'An existing container named "{c["Name"].lstrip("/")}" is not a system container'
6767
" spawned by Backend.AI. Please check and remove it."
6868
)
6969
return (

src/ai/backend/agent/server.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import sys
1414
from collections import OrderedDict, defaultdict
1515
from datetime import datetime, timezone
16-
from ipaddress import _BaseAddress as BaseIPAddress
17-
from ipaddress import ip_network
16+
from ipaddress import IPv4Address, IPv6Address, ip_network
1817
from pathlib import Path
1918
from pprint import pformat, pprint
2019
from typing import (
@@ -972,7 +971,7 @@ async def server_main(
972971

973972
log.info("Preparing kernel runner environments...")
974973
kernel_mod = importlib.import_module(
975-
f"ai.backend.agent.{local_config["agent"]["backend"].value}.kernel",
974+
f"ai.backend.agent.{local_config['agent']['backend'].value}.kernel",
976975
)
977976
krunner_volumes = await kernel_mod.prepare_krunner_env(local_config) # type: ignore
978977
# TODO: merge k8s branch: nfs_mount_path = local_config['baistatic']['mounted-at']
@@ -992,8 +991,8 @@ async def server_main(
992991
}
993992
scope_prefix_map = {
994993
ConfigScopes.GLOBAL: "",
995-
ConfigScopes.SGROUP: f"sgroup/{local_config["agent"]["scaling-group"]}",
996-
ConfigScopes.NODE: f"nodes/agents/{local_config["agent"]["id"]}",
994+
ConfigScopes.SGROUP: f"sgroup/{local_config['agent']['scaling-group']}",
995+
ConfigScopes.NODE: f"nodes/agents/{local_config['agent']['id']}",
997996
}
998997
etcd = AsyncEtcd(
999998
local_config["etcd"]["addr"],
@@ -1155,7 +1154,9 @@ def main(
11551154
raise click.Abort()
11561155

11571156
rpc_host = cfg["agent"]["rpc-listen-addr"].host
1158-
if isinstance(rpc_host, BaseIPAddress) and (rpc_host.is_unspecified or rpc_host.is_link_local):
1157+
if isinstance(rpc_host, (IPv4Address, IPv6Address)) and (
1158+
rpc_host.is_unspecified or rpc_host.is_link_local
1159+
):
11591160
print(
11601161
"ConfigurationError: "
11611162
"Cannot use link-local or unspecified IP address as the RPC listening host.",

src/ai/backend/agent/watcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def main(
409409
fn = Path(cfg["logging"]["file"]["filename"])
410410
cfg["logging"]["file"]["filename"] = f"{fn.stem}-watcher{fn.suffix}"
411411

412-
setproctitle(f"backend.ai: watcher {cfg["etcd"]["namespace"]}")
412+
setproctitle(f"backend.ai: watcher {cfg['etcd']['namespace']}")
413413
with logger:
414414
log.info("Backend.AI Agent Watcher {0}", VERSION)
415415
log.info("runtime: {0}", utils.env_info())

src/ai/backend/cli/interaction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ def ask_string_in_array(prompt: str, choices: list, default: str) -> Optional[st
7373

7474
if default:
7575
question = (
76-
f"{prompt} (choices: {"/".join(choices)}, "
76+
f"{prompt} (choices: {'/'.join(choices)}, "
7777
f"if left empty, this will use default value: {default}): "
7878
)
7979
else:
8080
question = (
81-
f"{prompt} (choices: {"/".join(choices)}, if left empty, this will remove this key): "
81+
f"{prompt} (choices: {'/'.join(choices)}, if left empty, this will remove this key): "
8282
)
8383

8484
while True:
@@ -92,7 +92,7 @@ def ask_string_in_array(prompt: str, choices: list, default: str) -> Optional[st
9292
elif user_reply.lower() in choices:
9393
break
9494
else:
95-
print(f"Please answer in {"/".join(choices)}.")
95+
print(f"Please answer in {'/'.join(choices)}.")
9696
return user_reply
9797

9898

src/ai/backend/client/cli/admin/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def rescan_images_impl(registry: str) -> None:
6161
print_error(e)
6262
sys.exit(ExitCode.FAILURE)
6363
if not result["ok"]:
64-
print_fail(f"Failed to begin registry scanning: {result["msg"]}")
64+
print_fail(f"Failed to begin registry scanning: {result['msg']}")
6565
sys.exit(ExitCode.FAILURE)
6666
print_done("Started updating the image metadata from the configured registries.")
6767
bgtask_id = result["task_id"]

src/ai/backend/client/cli/pretty.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def format_error(exc: Exception):
127127
if matches:
128128
yield "\nCandidates (up to 10 recent entries):\n"
129129
for item in matches:
130-
yield f"- {item["id"]} ({item["name"]}, {item["status"]})\n"
130+
yield f"- {item['id']} ({item['name']}, {item['status']})\n"
131131
elif exc.data["type"].endswith("/session-already-exists"):
132132
existing_session_id = exc.data["data"].get("existingSessionId", None)
133133
if existing_session_id is not None:
@@ -144,7 +144,7 @@ def format_error(exc: Exception):
144144
if exc.data["type"].endswith("/graphql-error"):
145145
yield "\n\u279c Message:\n"
146146
for err_item in exc.data.get("data", []):
147-
yield f"{err_item["message"]}"
147+
yield f"{err_item['message']}"
148148
if err_path := err_item.get("path"):
149149
yield f" (path: {_format_gql_path(err_path)})"
150150
yield "\n"

src/ai/backend/client/cli/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def info(ctx: CLIContext, service_name_or_id: str):
113113
)
114114
print()
115115
for route in routes:
116-
print(f"Route {route["routing_id"]}: ")
116+
print(f"Route {route['routing_id']}: ")
117117
ctx.output.print_item(
118118
route,
119119
_default_routing_fields,
@@ -645,7 +645,7 @@ def generate_token(ctx: CLIContext, service_name_or_id: str, duration: str, quie
645645
if quiet:
646646
print(resp["token"])
647647
else:
648-
print_done(f"Generated API token {resp["token"]}")
648+
print_done(f"Generated API token {resp['token']}")
649649
except Exception as e:
650650
ctx.output.print_error(e)
651651
sys.exit(ExitCode.FAILURE)

src/ai/backend/client/cli/session/lifecycle.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ async def cmd_main() -> None:
977977
session = api_sess.ComputeSession.from_session_id(session_id)
978978
resp = await session.update(priority=priority)
979979
item = resp["item"]
980-
print_done(f"Session {item["name"]!r} priority is changed to {item["priority"]}.")
980+
print_done(f"Session {item['name']!r} priority is changed to {item['priority']}.")
981981

982982
try:
983983
asyncio.run(cmd_main())
@@ -1372,7 +1372,7 @@ def watch(
13721372
session_names = _fetch_session_names()
13731373
if not session_names:
13741374
if output == "json":
1375-
sys.stderr.write(f'{json.dumps({"ok": False, "reason": "No matching items."})}\n')
1375+
sys.stderr.write(f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n")
13761376
else:
13771377
print_fail("No matching items.")
13781378
sys.exit(ExitCode.FAILURE)
@@ -1394,7 +1394,7 @@ def watch(
13941394
else:
13951395
if output == "json":
13961396
sys.stderr.write(
1397-
f'{json.dumps({"ok": False, "reason": "No matching items."})}\n'
1397+
f"{json.dumps({'ok': False, 'reason': 'No matching items.'})}\n"
13981398
)
13991399
else:
14001400
print_fail("No matching items.")

src/ai/backend/client/cli/vfolder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def request_download(name, filename):
426426
with Session() as session:
427427
try:
428428
response = json.loads(session.VFolder(name).request_download(filename))
429-
print_done(f'Download token: {response["token"]}')
429+
print_done(f"Download token: {response['token']}")
430430
except Exception as e:
431431
print_error(e)
432432
sys.exit(ExitCode.FAILURE)

src/ai/backend/client/func/acl.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import textwrap
21
from typing import Sequence
32

4-
from ai.backend.client.output.fields import permission_fields
5-
from ai.backend.client.output.types import FieldSpec
6-
3+
from ..output.fields import permission_fields
4+
from ..output.types import FieldSpec
75
from ..session import api_session
6+
from ..utils import dedent as _d
87
from .base import BaseFunction, api_function
98

109
__all__ = ("Permission",)
@@ -24,13 +23,11 @@ async def list(
2423
2524
:param fields: Additional permission query fields to fetch.
2625
"""
27-
query = textwrap.dedent(
28-
"""\
26+
query = _d("""
2927
query {
30-
vfolder_host_permissions {$fields}
28+
vfolder_host_permissions { $fields }
3129
}
32-
"""
33-
)
30+
""")
3431
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
3532
data = await api_session.get().Admin._query(query)
3633
return data["vfolder_host_permissions"]

src/ai/backend/client/func/agent.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from __future__ import annotations
22

3-
import textwrap
43
from typing import Optional, Sequence
54

6-
from ai.backend.client.output.fields import agent_fields
7-
from ai.backend.client.output.types import FieldSpec, PaginatedResult
8-
from ai.backend.client.pagination import fetch_paginated_result
9-
from ai.backend.client.request import Request
10-
from ai.backend.client.session import api_session
11-
5+
from ..output.fields import agent_fields
6+
from ..output.types import FieldSpec, PaginatedResult
7+
from ..pagination import fetch_paginated_result
8+
from ..request import Request
9+
from ..session import api_session
10+
from ..utils import dedent as _d
1211
from .base import BaseFunction, api_function
1312

1413
__all__ = (
@@ -88,13 +87,11 @@ async def detail(
8887
agent_id: str,
8988
fields: Sequence[FieldSpec] = _default_detail_fields,
9089
) -> Sequence[dict]:
91-
query = textwrap.dedent(
92-
"""\
90+
query = _d("""
9391
query($agent_id: String!) {
9492
agent(agent_id: $agent_id) {$fields}
9593
}
96-
"""
97-
)
94+
""")
9895
query = query.replace("$fields", " ".join(f.field_ref for f in fields))
9996
variables = {"agent_id": agent_id}
10097
data = await api_session.get().Admin._query(query, variables)

0 commit comments

Comments
 (0)