diff --git a/docs/docs/concepts/gateways.md b/docs/docs/concepts/gateways.md
index 728077addb..7402bc5762 100644
--- a/docs/docs/concepts/gateways.md
+++ b/docs/docs/concepts/gateways.md
@@ -56,6 +56,27 @@ You can create gateways with the `aws`, `azure`, `gcp`, or `kubernetes` backends
Gateways in `kubernetes` backend require an external load balancer. Managed Kubernetes solutions usually include a load balancer.
For self-hosted Kubernetes, you must provide a load balancer by yourself.
+### Instance type
+
+By default, `dstack` provisions a small, low-cost instance for the gateway. If you expect to run high-traffic services, you can configure a larger instance type using the `instance_type` property.
+
+
+
+```yaml
+type: gateway
+name: example-gateway
+
+backend: aws
+region: eu-west-1
+
+# (Optional) Override the gateway instance type
+instance_type: t3.large
+
+domain: example.com
+```
+
+
+
### Router
By default, the gateway uses its own load balancer to route traffic between replicas. However, you can delegate this responsibility to a specific router by setting the `router` property. Currently, the only supported external router is `sglang`.
diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py
index 67a7e409dd..48720bb316 100644
--- a/src/dstack/_internal/core/backends/aws/compute.py
+++ b/src/dstack/_internal/core/backends/aws/compute.py
@@ -76,6 +76,7 @@
logger = get_logger(__name__)
# gp2 volumes can be 1GB-16TB, dstack AMIs are 100GB
CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("100GB"), max=Memory.parse("16TB"))
+DEFAULT_GATEWAY_INSTANCE_TYPE = "t3.micro"
class AWSGatewayBackendData(CoreModel):
@@ -454,22 +455,27 @@ def create_gateway(
project_id=configuration.project_name,
vpc_id=vpc_id,
)
- response = ec2_resource.create_instances(
- **aws_resources.create_instances_struct(
- disk_size=10,
- image_id=aws_resources.get_gateway_image_id(ec2_client),
- instance_type="t3.micro",
- iam_instance_profile=None,
- user_data=get_gateway_user_data(
- configuration.ssh_key_pub, router=configuration.router
- ),
- tags=tags,
- security_group_id=security_group_id,
- spot=False,
- subnet_id=subnet_id,
- allocate_public_ip=configuration.public_ip,
- )
+ instance_struct = aws_resources.create_instances_struct(
+ disk_size=10,
+ image_id=aws_resources.get_gateway_image_id(ec2_client),
+ instance_type=configuration.instance_type or DEFAULT_GATEWAY_INSTANCE_TYPE,
+ iam_instance_profile=None,
+ user_data=get_gateway_user_data(
+ configuration.ssh_key_pub, router=configuration.router
+ ),
+ tags=tags,
+ security_group_id=security_group_id,
+ spot=False,
+ subnet_id=subnet_id,
+ allocate_public_ip=configuration.public_ip,
)
+ try:
+ response = ec2_resource.create_instances(**instance_struct)
+ except botocore.exceptions.ClientError as e:
+ msg = f"AWS Error: {e.response['Error']['Code']}"
+ if e.response["Error"].get("Message"):
+ msg += f": {e.response['Error']['Message']}"
+ raise ComputeError(msg)
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address
diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py
index 0089f5e478..74e585d631 100644
--- a/src/dstack/_internal/core/backends/azure/compute.py
+++ b/src/dstack/_internal/core/backends/azure/compute.py
@@ -79,6 +79,7 @@
logger = get_logger(__name__)
# OS disks can be 1GB-4095GB, dstack images are 30GB
CONFIGURABLE_DISK_SIZE = Range[Memory](min=Memory.parse("30GB"), max=Memory.parse("4095GB"))
+DEFAULT_GATEWAY_INSTANCE_TYPE = "Standard_B1ms"
class AzureCompute(
@@ -230,6 +231,13 @@ def create_gateway(
self,
configuration: GatewayComputeConfiguration,
) -> GatewayProvisioningData:
+ if configuration.instance_type is not None:
+ # TODO: support instance_type. Requires selecting a VM image to avoid errors like this:
+ # > The selected VM size 'Standard_E4s_v6' cannot boot Hypervisor Generation '1'
+ raise ComputeError(
+ "The `azure` backend does not support the `instance_type`"
+ " gateway configuration property"
+ )
logger.info(
"Launching %s gateway instance in %s...",
configuration.instance_name,
@@ -275,7 +283,7 @@ def create_gateway(
managed_identity_name=None,
managed_identity_resource_group=None,
image_reference=_get_gateway_image_ref(),
- vm_size="Standard_B1ms",
+ vm_size=DEFAULT_GATEWAY_INSTANCE_TYPE,
instance_name=instance_name,
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py
index 8e5c36fccb..76a394bc7a 100644
--- a/src/dstack/_internal/core/backends/gcp/compute.py
+++ b/src/dstack/_internal/core/backends/gcp/compute.py
@@ -88,6 +88,7 @@
)
RESOURCE_NAME_PATTERN = re.compile(r"[a-z0-9-]+")
TPU_VERSIONS = [tpu.name for tpu in KNOWN_TPUS]
+DEFAULT_GATEWAY_INSTANCE_TYPE = "e2-medium"
class GCPOfferBackendData(CoreModel):
@@ -596,7 +597,7 @@ def create_gateway(
request.instance_resource = gcp_resources.create_instance_struct(
disk_size=10,
image_id=_get_gateway_image_id(),
- machine_type="e2-medium",
+ machine_type=configuration.instance_type or DEFAULT_GATEWAY_INSTANCE_TYPE,
accelerators=[],
spot=False,
user_data=get_gateway_user_data(
@@ -612,8 +613,14 @@ def create_gateway(
subnetwork=subnetwork,
allocate_public_ip=configuration.public_ip,
)
- operation = self.instances_client.insert(request=request)
- gcp_resources.wait_for_extended_operation(operation, "instance creation")
+ try:
+ operation = self.instances_client.insert(request=request)
+ gcp_resources.wait_for_extended_operation(operation, "instance creation")
+ except (
+ google.api_core.exceptions.ServiceUnavailable,
+ google.api_core.exceptions.ClientError,
+ ) as e:
+ raise ComputeError(f"GCP error: {e.message}")
instance = self.instances_client.get(
project=self.config.project_id, zone=zone, instance=instance_name
)
diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py
index da5e125b96..e46b99d9d7 100644
--- a/src/dstack/_internal/core/backends/kubernetes/compute.py
+++ b/src/dstack/_internal/core/backends/kubernetes/compute.py
@@ -370,6 +370,11 @@ def create_gateway(
# TODO: By default EKS creates a Classic Load Balancer for Load Balancer services.
# Consider deploying an NLB. It seems it requires some extra configuration on the cluster:
# https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html
+ if configuration.instance_type is not None:
+ raise ComputeError(
+ "The `kubernetes` backend does not support the `instance_type`"
+ " gateway configuration property"
+ )
instance_name = generate_unique_gateway_instance_name(configuration)
commands = _get_gateway_commands(
authorized_keys=[configuration.ssh_key_pub], router=configuration.router
diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py
index ace68e5429..2dfeb5b181 100644
--- a/src/dstack/_internal/core/models/gateways.py
+++ b/src/dstack/_internal/core/models/gateways.py
@@ -51,6 +51,16 @@ class GatewayConfiguration(CoreModel):
default: Annotated[bool, Field(description="Make the gateway default")] = False
backend: Annotated[BackendType, Field(description="The gateway backend")]
region: Annotated[str, Field(description="The gateway region")]
+ instance_type: Annotated[
+ Optional[str],
+ Field(
+ description=(
+ "Backend-specific instance type to use for the gateway instance."
+ " Omit to use the backend's default, which is typically a small non-GPU instance"
+ ),
+ min_length=1,
+ ),
+ ] = None
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
@@ -115,6 +125,7 @@ class GatewayComputeConfiguration(CoreModel):
instance_name: str
backend: BackendType
region: str
+ instance_type: Optional[str] = None
public_ip: bool
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 273b1fb894..682feaf31b 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -104,6 +104,7 @@ async def create_gateway_compute(
instance_name=configuration.name,
backend=configuration.backend,
region=configuration.region,
+ instance_type=configuration.instance_type,
public_ip=configuration.public_ip,
ssh_key_pub=gateway_ssh_public_key,
certificate=configuration.certificate,
diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py
index 0bee1e6f06..b909c7d729 100644
--- a/src/tests/_internal/server/routers/test_gateways.py
+++ b/src/tests/_internal/server/routers/test_gateways.py
@@ -70,6 +70,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient):
"name": gateway.name,
"backend": backend.type.value,
"region": gateway.region,
+ "instance_type": None,
"router": None,
"domain": gateway.wildcard_domain,
"default": False,
@@ -122,6 +123,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient):
"name": gateway.name,
"backend": backend.type.value,
"region": gateway.region,
+ "instance_type": None,
"router": None,
"domain": gateway.wildcard_domain,
"default": False,
@@ -203,6 +205,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn
"name": "test",
"backend": backend.type.value,
"region": "us",
+ "instance_type": None,
"router": None,
"domain": None,
"default": True,
@@ -256,6 +259,7 @@ async def test_create_gateway_without_name(
"name": "random-name",
"backend": backend.type.value,
"region": "us",
+ "instance_type": None,
"router": None,
"domain": None,
"default": True,
@@ -359,6 +363,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
"name": gateway.name,
"backend": backend.type.value,
"region": gateway.region,
+ "instance_type": None,
"router": None,
"domain": gateway.wildcard_domain,
"default": True,
@@ -482,6 +487,7 @@ def get_backend(project, backend_type):
"name": gateway_gcp.name,
"backend": backend_gcp.type.value,
"region": gateway_gcp.region,
+ "instance_type": None,
"router": None,
"domain": gateway_gcp.wildcard_domain,
"default": False,
@@ -552,6 +558,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client:
"name": gateway.name,
"backend": backend.type.value,
"region": gateway.region,
+ "instance_type": None,
"router": None,
"domain": "test.com",
"default": False,