Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public static void before(
}
if (uri != null) {
// Report the URL :
URLCollector.report(uri.toURL());
URLCollector.report(uri.toURL(), "org.apache.http.HttpClient.execute");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public static void before(
// Run report with "argument"
for (Method method2: clazz.getMethods()) {
if(method2.getName().equals("report")) {
method2.invoke(null, url);
method2.invoke(null, url, "HttpUrlConnection");
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static void before(
URL url = (URL) toUrlMethod.invoke(urlObject);

// Report the URL
URLCollector.report(url);
URLCollector.report(url, "okhttp3.OkHttpClient.newCall");
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.aikido.agent_api.background.cloud.api;

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.storage.service_configuration.Domain;

import java.util.List;

Expand All @@ -11,6 +12,8 @@ public record APIResponse(
List<Endpoint> endpoints,
List<String> blockedUserIds,
List<String> allowedIPAddresses,
boolean blockNewOutgoingRequests,
List<Domain> domains,
boolean receivedAnyStats,
boolean block
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private static APIResponse getUnsuccessfulAPIResponse(String error) {
return new APIResponse(
false, // Success
error,
0, null, null, null, false, false // Unimportant values.
0, null, null, null, false, null, false, false // Unimportant values.
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.statistics.OperationKind;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;
import dev.aikido.agent_api.vulnerabilities.Attack;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ public final class RedirectCollector {
private static final Logger logger = LogManager.getLogger(RedirectCollector.class);

private RedirectCollector() {}
public static void report(URL origin, URL dest) {
public static void report(URL origin, URL dest, String operation) {
logger.trace("Redirect detected: [Origin]<%s> -> [Destination]<%s>", origin, dest);
ContextObject context = Context.get();
// Report destination URL :
URLCollector.report(dest);
URLCollector.report(dest, operation);

// Add as a node :
List<RedirectNode> redirectStarterNodes = context.getRedirectStartNodes();
Expand All @@ -41,4 +41,4 @@ public static void report(URL origin, URL dest) {
context.addRedirectNode(starterNode);
Context.set(context); // Update context.
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,73 @@

import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.context.User;
import dev.aikido.agent_api.storage.HostnamesStore;
import dev.aikido.agent_api.helpers.logging.LogManager;
import dev.aikido.agent_api.helpers.logging.Logger;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.vulnerabilities.Attack;
import dev.aikido.agent_api.vulnerabilities.Vulnerabilities;
import dev.aikido.agent_api.vulnerabilities.ssrf.SSRFException;

import java.net.URL;
import java.util.Map;

import static dev.aikido.agent_api.helpers.ShouldBlockHelper.shouldBlock;
import static dev.aikido.agent_api.helpers.StackTrace.getCurrentStackTrace;
import static dev.aikido.agent_api.helpers.url.PortParser.getPortFromURL;
import static dev.aikido.agent_api.storage.AttackQueue.attackDetected;

public final class URLCollector {
private static final Logger logger = LogManager.getLogger(URLCollector.class);

private URLCollector() {}
public static void report(URL url) {
public static void report(URL url, String operation) {
if(url != null) {
if (!url.getProtocol().startsWith("http")) {
return; // Non-HTTP(S) URL
return; // Non-HTTP(S) URL
}
logger.trace("Adding a new URL to the cache: %s", url);
int port = getPortFromURL(url);

// We store hostname and port in two places, HostnamesStore and Context. HostnamesStore is for reporting
// outbound domains. Context is to have a map of hostnames with used port numbers to detect SSRF attacks.

// hostname blocking :
String hostname = url.getHost();
if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) {
ContextObject ctx = Context.get();

User currentUser = null;
if (ctx != null) {
currentUser = ctx.getUser();
}

Attack attack = new Attack(
operation,
Vulnerabilities.SSRF,
"",
"",
Map.of(),
/* payload */ hostname,
getCurrentStackTrace(),
currentUser
);

attackDetected(attack, ctx);
if (shouldBlock()) {
logger.debug("Blocking request to domain: %s", hostname);
throw SSRFException.get();
}
};

// Store (new) hostname hits
HostnamesStore.incrementHits(url.getHost(), port);
HostnamesStore.incrementHits(hostname, port);

// Add to context :
ContextObject context = Context.get();
if (context != null) {
context.getHostnames().add(url.getHost(), port);
context.getHostnames().add(hostname, port);
Context.set(context);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,13 @@ public static void setMiddlewareInstalled(boolean middlewareInstalled) {
mutex.writeLock().unlock();
}
}

public static boolean shouldBlockOutgoingRequest(String hostname) {
mutex.readLock().lock();
try {
return config.shouldBlockOutgoingRequest(hostname);
} finally {
mutex.readLock().unlock();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import dev.aikido.agent_api.background.cloud.api.APIResponse;
import dev.aikido.agent_api.background.cloud.api.ReportingApi;
import dev.aikido.agent_api.helpers.net.IPList;
import dev.aikido.agent_api.storage.service_configuration.Domain;
import dev.aikido.agent_api.storage.service_configuration.ParsedFirewallLists;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.*;

import static dev.aikido.agent_api.helpers.IPListBuilder.createIPList;
import static dev.aikido.agent_api.vulnerabilities.ssrf.IsPrivateIP.isPrivateIp;
Expand All @@ -26,6 +25,8 @@ public class ServiceConfiguration {
private IPList bypassedIPs = new IPList();
private HashSet<String> blockedUserIDs = new HashSet<>();
private List<Endpoint> endpoints = new ArrayList<>();
private Map<String, Domain> domains = new HashMap<>();
private boolean blockNewOutgoingRequests = false;

public ServiceConfiguration() {
this.receivedAnyStats = true; // true by default, waiting for the startup event
Expand All @@ -46,6 +47,15 @@ public void updateConfig(APIResponse apiResponse) {
if (apiResponse.endpoints() != null) {
this.endpoints = apiResponse.endpoints();
}
if (apiResponse.domains() != null) {
for (Domain domain : apiResponse.domains()) {
if (this.domains.get(domain.hostname()) != null) {
continue; // use first provided domain value
}
this.domains.put(domain.hostname(), domain);
}
}
this.blockNewOutgoingRequests = apiResponse.blockNewOutgoingRequests();
this.receivedAnyStats = apiResponse.receivedAnyStats();
}

Expand Down Expand Up @@ -127,4 +137,18 @@ public boolean isBlockedUserAgent(String userAgent) {

public record BlockedResult(boolean blocked, String description) {
}

public boolean shouldBlockOutgoingRequest(String hostname) {
Domain matchingDomain = this.domains.get(hostname);
if (matchingDomain == null) {
return false;
}

boolean isDomainBlocked = matchingDomain.isBlockingMode();
if (this.blockNewOutgoingRequests) {
return isDomainBlocked;
}

return isDomainBlocked;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.aikido.agent_api.storage.service_configuration;

public record Domain(String hostname, String mode) {
public boolean isBlockingMode() {
// mode can either be "allow" or "block"
return this.mode.equals("block");
}
}
4 changes: 2 additions & 2 deletions agent_api/src/test/java/ShouldBlockRequestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void testUserSet() throws SQLException {
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), List.of(),
/* blockedUserIds */ List.of("ID1", "ID2", "ID3"), List.of(),
false, true
false, null, false, true
));
var res2 = ShouldBlockRequest.shouldBlockRequest();
assertTrue(res2.block());
Expand Down Expand Up @@ -227,7 +227,7 @@ public void testBlockedUserWithMultipleEndpoints() throws SQLException {
);
List<String> blockedUserIds = List.of("ID1");
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), endpoints, blockedUserIds, List.of(), true, false
true, "", getUnixTimeMS(), endpoints, blockedUserIds, List.of(), false, null,true, false
));

// Call the method
Expand Down
18 changes: 9 additions & 9 deletions agent_api/src/test/java/collectors/URLCollectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ private void setContextAndLifecycle(String url) {
@Test
public void testNewUrlConnectionWithPort() throws IOException {
setContextAndLifecycle("");
URLCollector.report(new URL("http://localhost:8080"));

URLCollector.report(new URL("http://localhost:8080"), "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(1, hostnameArray.length);
assertEquals(8080, hostnameArray[0].getPort());
Expand All @@ -49,7 +49,7 @@ public void testNewUrlConnectionWithPort() throws IOException {
@Test
public void testNewUrlConnectionWithHttp() throws IOException {
setContextAndLifecycle("");
URLCollector.report(new URL("http://app.local.aikido.io"));
URLCollector.report(new URL("http://app.local.aikido.io"), "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(1, hostnameArray.length);
assertEquals(80, hostnameArray[0].getPort());
Expand All @@ -64,7 +64,7 @@ public void testNewUrlConnectionWithHttp() throws IOException {
@Test
public void testNewUrlConnectionHttps() throws IOException {
setContextAndLifecycle("");
URLCollector.report(new URL("https://aikido.dev"));
URLCollector.report(new URL("https://aikido.dev"), "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
Expand All @@ -79,7 +79,7 @@ public void testNewUrlConnectionHttps() throws IOException {
@Test
public void testNewUrlConnectionFaultyProtocol() throws IOException {
setContextAndLifecycle("");
URLCollector.report(new URL("ftp://localhost:8080"));
URLCollector.report(new URL("ftp://localhost:8080"), "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(0, hostnameArray.length);
Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
Expand All @@ -89,7 +89,7 @@ public void testNewUrlConnectionFaultyProtocol() throws IOException {
@Test
public void testWithNullURL() throws IOException {
setContextAndLifecycle("");
URLCollector.report(null);
URLCollector.report(null, "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(0, hostnameArray.length);
Hostnames.HostnameEntry[] hostnameArray2 = Context.get().getHostnames().asArray();
Expand All @@ -100,7 +100,7 @@ public void testWithNullURL() throws IOException {
public void testWithNullContext() throws IOException {
setContextAndLifecycle("");
Context.reset();
URLCollector.report(new URL("https://aikido.dev"));
URLCollector.report(new URL("https://aikido.dev"), "test");
Hostnames.HostnameEntry[] hostnameArray = HostnamesStore.getHostnamesAsList();
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
Expand All @@ -112,10 +112,10 @@ public void testWithNullContext() throws IOException {
public void testOnlyContext() throws IOException {
setContextAndLifecycle("");
HostnamesStore.clear();
URLCollector.report(new URL("https://aikido.dev"));
URLCollector.report(new URL("https://aikido.dev"), "test");
Hostnames.HostnameEntry[] hostnameArray = Context.get().getHostnames().asArray();
assertEquals(1, hostnameArray.length);
assertEquals(443, hostnameArray[0].getPort());
assertEquals("aikido.dev", hostnameArray[0].getHostname());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ void testReport_userAgentBlocked_Ip_Bypassed() {

List<String> bypassedIps = List.of("192.168.1.1");
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false
));


Expand All @@ -231,7 +231,7 @@ void testReport_ipBlockedUsingLists_Ip_Bypassed() {

List<String> bypassedIps = List.of("192.168.1.1");
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false
));

WebRequestCollector.Res response = WebRequestCollector.report(contextObject);
Expand All @@ -251,7 +251,7 @@ void testReport_ipNotAllowedUsingLists_Ip_Bypassed() {

List<String> bypassedIps = List.of("192.168.1.1");
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, true, false
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false
));


Expand Down
Loading
Loading