diff --git a/src/repository.zig b/src/repository.zig index 343bfe9..767fa59 100644 --- a/src/repository.zig +++ b/src/repository.zig @@ -2,6 +2,7 @@ const std = @import("std"); const pg = @import("pg"); const zollections = @import("zollections"); const _sql = @import("sql.zig"); +const _conditions = @import("conditions.zig"); const query = @import("query.zig"); const insert = @import("insert.zig"); const update = @import("update.zig"); @@ -34,6 +35,41 @@ pub fn RepositoryConfiguration(comptime Model: type, comptime TableShape: type) }; } +/// Build the type of a model key, based on the given configuration. +pub fn ModelKeyType(comptime Model: type, comptime TableShape: type, comptime config: RepositoryConfiguration(Model, TableShape)) type { + if (config.key.len == 0) { + // Get the type of the simple key. + return std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, config.key[0]).?].type; + } else { + // Build the type of the composite key. + + // Build key fields. + var fields: [config.key.len]std.builtin.Type.StructField = undefined; + inline for (config.key, &fields) |keyName, *field| { + // Build NULL-terminated key name as field name. + var fieldName: [keyName.len:0]u8 = undefined; + @memcpy(fieldName[0..keyName.len], keyName); + + field.* = .{ + .name = &fieldName, + .type = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type, + .default_value = null, + .is_comptime = false, + .alignment = 0, + }; + } + + return @Type(.{ + .Struct = std.builtin.Type.Struct{ + .layout = std.builtin.Type.ContainerLayout.auto, + .fields = &fields, + .decls = &[_]std.builtin.Type.Declaration{}, + .is_tuple = false, + }, + }); + } +} + /// Repository of structures of a certain type. pub fn Repository(comptime Model: type, comptime TableShape: type, comptime config: RepositoryConfiguration(Model, TableShape)) type { return struct { @@ -42,6 +78,9 @@ pub fn Repository(comptime Model: type, comptime TableShape: type, comptime conf pub const Query: type = query.RepositoryQuery(Model, TableShape, config); pub const Insert: type = insert.RepositoryInsert(Model, TableShape, config, config.insertShape); + /// Type of one model key. + pub const KeyType = ModelKeyType(Model, TableShape, config); + pub fn InsertCustom(comptime InsertShape: type) type { return insert.RepositoryInsert(Model, TableShape, config, InsertShape); } @@ -51,34 +90,70 @@ pub fn Repository(comptime Model: type, comptime TableShape: type, comptime conf } /// Try to find the requested model. + /// For simple keys: modelKey type must match the type of its corresponding field. + /// modelKey can be an array / slice of keys. + /// For composite keys: modelKey must be a struct with all the keys, matching the type of their corresponding field. + /// modelKey can be an array / slice of these structs. pub fn find(allocator: std.mem.Allocator, database: *pg.Pool, modelKey: anytype) !RepositoryResult(Model) { // Initialize a new query. var modelQuery = Self.Query.init(allocator, database, .{}); defer modelQuery.deinit(); if (config.key.len == 1) { - // Add a simple condition. - try modelQuery.whereValue(std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, config.key[0]).?].type, config.key[0], "=", modelKey); - } else { - // Add conditions for all keys in the composite key. - var conditions: [config.key.len]_sql.SqlParams = undefined; + // Find key name and its type. + const keyName = config.key[0]; + const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; - inline for (config.key, &conditions) |keyName, *condition| { - if (std.meta.fieldIndex(@TypeOf(modelKey), keyName)) |_| { - // The field exists in the key structure, create its condition. - condition.* = try modelQuery.newCondition().value( - std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type, - keyName, "=", - @field(modelKey, keyName), - ); - } else { - // The field doesn't exist, compilation error. - @compileError("The key structure must include a field for " ++ keyName); - } + // Accept arrays / slices of keys, and simple keys. + switch (@typeInfo(@TypeOf(modelKey))) { + .Pointer => |ptr| { + switch (ptr.size) { + .One => { + switch (@typeInfo(ptr.child)) { + // Add a whereIn with the array. + .Array => try modelQuery.whereIn(keyType, keyName, modelKey), + // Add a simple condition with the pointed value. + else => try modelQuery.whereValue(keyType, keyName, "=", modelKey.*), + } + }, + // Add a whereIn with the slice. + else => try modelQuery.whereIn(keyType, keyName, modelKey), + } + }, + // Add a simple condition with the given value. + else => try modelQuery.whereValue(keyType, keyName, "=", modelKey), + } + } else { + // Accept arrays / slices of keys, and simple keys. + // Uniformize modelKey parameter to a slice. + const modelKeysList: []const Self.KeyType = switch (@typeInfo(@TypeOf(modelKey))) { + .Pointer => |ptr| switch (ptr.size) { + .One => switch (@typeInfo(ptr.child)) { + // Already an array. + .Array => @as([]const Self.KeyType, modelKey), + // Convert the pointer to an array. + else => &[1]Self.KeyType{@as(Self.KeyType, modelKey.*)}, + }, + // Already a slice. + else => @as([]const Self.KeyType, modelKey), + }, + // Convert the value to an array. + else => &[1]Self.KeyType{@as(Self.KeyType, modelKey)}, + }; + + // Initialize keys conditions list. + const conditions: []_sql.SqlParams = try allocator.alloc(_sql.SqlParams, modelKeysList.len); + defer allocator.free(conditions); + + // For each model key, add its conditions. + for (modelKeysList, conditions) |_modelKey, *condition| { + condition.* = try modelQuery.newCondition().@"and"( + &try buildCompositeKeysConditions(TableShape, config.key, modelQuery.newCondition(), _modelKey) + ); } - // Set WHERE conditions in the query. - modelQuery.where(try modelQuery.newCondition().@"and"(&conditions)); + // Set WHERE conditions in the query with all keys conditions. + modelQuery.where(try modelQuery.newCondition().@"or"(conditions)); } // Execute query and return its result. @@ -179,3 +254,24 @@ pub fn RepositoryResult(comptime Model: type) type { } }; } + +/// Build conditions for given composite keys, with a model key structure. +pub fn buildCompositeKeysConditions(comptime TableShape: type, comptime keys: []const []const u8, conditionsBuilder: _conditions.Builder, modelKey: anytype) ![keys.len]_sql.SqlParams { + // Conditions list for all keys in the composite key. + var conditions: [keys.len]_sql.SqlParams = undefined; + + inline for (keys, &conditions) |keyName, *condition| { + const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; + + if (std.meta.fieldIndex(@TypeOf(modelKey), keyName)) |_| { + // The field exists in the key structure, create its condition. + condition.* = try conditionsBuilder.value(keyType, keyName, "=", @field(modelKey, keyName)); + } else { + // The field doesn't exist, compilation error. + @compileError("The key structure must include a field for " ++ keyName); + } + } + + // Return conditions for the current model key. + return conditions; +} diff --git a/tests/composite.zig b/tests/composite.zig index 2322edd..7fe85aa 100644 --- a/tests/composite.zig +++ b/tests/composite.zig @@ -136,4 +136,24 @@ test "composite model create, save and find" { try std.testing.expectEqual(1, result4.models.len); try std.testing.expectEqualDeep(newModel, result4.first().?.*); + + + // Try to find multiple models at once. + var result5 = try CompositeModelRepository.find(std.testing.allocator, database, &[_]CompositeModelRepository.KeyType{ + .{ + .firstcol = newModel.firstcol, + .secondcol = newModel.secondcol, + }, + .{ + .firstcol = result3.first().?.firstcol, + .secondcol = result3.first().?.secondcol, + }, + }); + defer result5.deinit(); + + try std.testing.expectEqual(2, result5.models.len); + try std.testing.expectEqual(newModel.firstcol, result5.models[0].firstcol); + try std.testing.expectEqualStrings(newModel.secondcol, result5.models[0].secondcol); + try std.testing.expectEqual(result3.first().?.firstcol, result5.models[1].firstcol); + try std.testing.expectEqualStrings(result3.first().?.secondcol, result5.models[1].secondcol); } diff --git a/tests/repository.zig b/tests/repository.zig index cc45ec2..73bdb0a 100644 --- a/tests/repository.zig +++ b/tests/repository.zig @@ -310,4 +310,13 @@ test "model create, save and find" { defer result4.deinit(); // Will clear some values in newModel. try std.testing.expectEqualDeep(newModel, result4.first().?.*); + + + // Try to find multiple models at once. + var result5 = try MyModelRepository.find(std.testing.allocator, database, &[_]i32{1, newModel.id}); + defer result5.deinit(); + + try std.testing.expectEqual(2, result5.models.len); + try std.testing.expectEqual(1, result5.models[0].id); + try std.testing.expectEqual(newModel.id, result5.models[1].id); }