Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 86 additions & 22 deletions src/httpfile/assertion_checker.zig
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ const http = std.http;
const HttpParser = @import("./parser.zig");
const Client = @import("./http_client.zig");

fn extractHeaderName(key: []const u8) ![]const u8 {
// Expects key in the form header["..."]
const start_quote = std.mem.indexOfScalar(u8, key, '"') orelse return error.InvalidAssertionKey;
const end_quote = std.mem.lastIndexOfScalar(u8, key, '"') orelse return error.InvalidAssertionKey;
if (end_quote <= start_quote) return error.InvalidAssertionKey;
return key[start_quote + 1 .. end_quote];
}

pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !void {
const stderr = std.io.getStdErr().writer();
for (request.assertions.items) |assertion| {
Expand All @@ -63,8 +71,7 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.BodyContentMismatch;
}
} else if (std.mem.startsWith(u8, assertion.key, "header[\"")) {
// Extract the header name from the assertion key
const header_name = assertion.key[8 .. assertion.key.len - 2];
const header_name = try extractHeaderName(assertion.key);
const actual_value = response.headers.get(header_name);
if (actual_value == null or !std.ascii.eqlIgnoreCase(actual_value.?, assertion.value)) {
stderr.print("[Fail] Expected header \"{s}\" to be \"{s}\", got \"{s}\"\n", .{ header_name, assertion.value, actual_value orelse "null" }) catch {};
Expand All @@ -88,8 +95,7 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.BodyContentMatchesButShouldnt;
}
} else if (std.mem.startsWith(u8, assertion.key, "header[\"")) {
// Extract the header name from the assertion key
const header_name = assertion.key[8 .. assertion.key.len - 2];
const header_name = try extractHeaderName(assertion.key);
const actual_value = response.headers.get(header_name);
if (actual_value != null and std.ascii.eqlIgnoreCase(actual_value.?, assertion.value)) {
stderr.print("[Fail] Expected header \"{s}\" to NOT equal \"{s}\", got \"{s}\"\n", .{ header_name, assertion.value, actual_value orelse "null" }) catch {};
Expand All @@ -100,20 +106,6 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.InvalidAssertionKey;
}
},

// .header => {
// // assertion.key is header[""] so we need to
// // parse it out of the quotes
// const tokens = std.mem.splitScalar(u8, assertion.key, '\"');
// const expected_header = tokens.next() orelse return error.InvalidHeaderFormat;
// if (expected_header.len != 2) {
// return error.InvalidHeaderFormat;
// }
// const actual_value = response.headers.get(expected_header);
// if (actual_value == null or actual_value.* != expected_header.value) {
// return error.HeaderMismatch;
// }
// },
.contains => {
if (std.ascii.eqlIgnoreCase(assertion.key, "status")) {
var status_buf: [3]u8 = undefined;
Expand All @@ -129,8 +121,7 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.BodyContentNotContains;
}
} else if (std.mem.startsWith(u8, assertion.key, "header[\"")) {
// Extract the header name from the assertion key
const header_name = assertion.key[8 .. assertion.key.len - 2];
const header_name = try extractHeaderName(assertion.key);
const actual_value = response.headers.get(header_name);
if (actual_value == null or std.mem.indexOf(u8, actual_value.?, assertion.value) == null) {
stderr.print("[Fail] Expected header \"{s}\" to contain \"{s}\", got \"{s}\"\n", .{ header_name, assertion.value, actual_value orelse "null" }) catch {};
Expand All @@ -156,8 +147,7 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.BodyContentContainsButShouldnt;
}
} else if (std.mem.startsWith(u8, assertion.key, "header[\"")) {
// Extract the header name from the assertion key
const header_name = assertion.key[8 .. assertion.key.len - 2];
const header_name = try extractHeaderName(assertion.key);
const actual_value = response.headers.get(header_name);
if (actual_value != null and std.mem.indexOf(u8, actual_value.?, assertion.value) != null) {
stderr.print("[Fail] Expected header \"{s}\" to NOT contain \"{s}\", got \"{s}\"\n", .{ header_name, assertion.value, actual_value orelse "null" }) catch {};
Expand All @@ -168,6 +158,32 @@ pub fn check(request: *HttpParser.HttpRequest, response: Client.HttpResponse) !v
return error.InvalidAssertionKey;
}
},
.starts_with => {
if (std.ascii.eqlIgnoreCase(assertion.key, "status")) {
var status_buf: [3]u8 = undefined;
const status_code = @intFromEnum(response.status.?);
const status_str = std.fmt.bufPrint(&status_buf, "{}", .{status_code}) catch return error.StatusCodeFormat;
if (!std.mem.startsWith(u8, status_str, assertion.value)) {
stderr.print("[Fail] Expected status code to start with \"{s}\", got \"{s}\"\n", .{ assertion.value, status_str }) catch {};
return error.StatusCodeNotStartsWith;
}
} else if (std.ascii.eqlIgnoreCase(assertion.key, "body")) {
if (!std.mem.startsWith(u8, response.body, assertion.value)) {
stderr.print("[Fail] Expected body content to start with \"{s}\", got \"{s}\"\n", .{ assertion.value, response.body }) catch {};
return error.BodyContentNotStartsWith;
}
} else if (std.mem.startsWith(u8, assertion.key, "header[\"")) {
const header_name = try extractHeaderName(assertion.key);
const actual_value = response.headers.get(header_name);
if (actual_value == null or !std.mem.startsWith(u8, actual_value.?, assertion.value)) {
stderr.print("[Fail] Expected header \"{s}\" to start with \"{s}\", got \"{s}\"\n", .{ header_name, assertion.value, actual_value orelse "null" }) catch {};
return error.HeaderNotStartsWith;
}
} else {
stderr.print("[Fail] Invalid assertion key for starts_with: {s}\n", .{assertion.key}) catch {};
return error.InvalidAssertionKey;
}
},
else => {},
}
}
Expand Down Expand Up @@ -276,3 +292,51 @@ test "HttpParser handles NotEquals" {

try check(&request, response);
}

test "HttpParser supports starts_with for status, body, and header" {
const allocator = std.testing.allocator;
var assertions = std.ArrayList(HttpParser.Assertion).init(allocator);
defer assertions.deinit();

// Status starts with "2"
try assertions.append(HttpParser.Assertion{
.key = "status",
.value = "2",
.assertion_type = .starts_with,
});
// Body starts with "Hello"
try assertions.append(HttpParser.Assertion{
.key = "body",
.value = "Hello",
.assertion_type = .starts_with,
});
// Header starts with "application"
try assertions.append(HttpParser.Assertion{
.key = "header[\"content-type\"]",
.value = "application",
.assertion_type = .starts_with,
});

var request = HttpParser.HttpRequest{
.method = .GET,
.url = "https://api.example.com",
.headers = std.ArrayList(http.Header).init(allocator),
.assertions = assertions,
.body = null,
};

var response_headers = std.StringHashMap([]const u8).init(allocator);
try response_headers.put("content-type", "application/json");
defer response_headers.deinit();

const body = try allocator.dupe(u8, "Hello world!");
defer allocator.free(body);
const response = Client.HttpResponse{
.status = http.Status.ok,
.headers = response_headers,
.body = body,
.allocator = allocator,
};

try check(&request, response);
}