diff --git a/build.zig.zon b/build.zig.zon index 959aea3..baec0de 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -10,16 +10,10 @@ .hash = "12200fe147879d72381633e6f44d76db2c8a603cda1969b4e474c15c31052dbb24b7", }, .pg = .{ - .url = "git+https://github.com/karlseguin/pg.zig?ref=zig-0.13#239a4468163a49d8c0d03285632eabe96003e9e2", - .hash = "1220a1d7e51e2fa45e547c76a9e099c09d06e14b0b9bfc6baa89367f56f1ded399a0", + .url = "git+https://github.com/Madeorsk/pg.zig?ref=expose-mapper-zig-0.13#11c91a714858c517539069fca620918c77a771c7", + .hash = "122047c505cca855c4fb2c831bc6056527175bbd111fe2dc77488799b576da213075", }, }, - .paths = .{ - "build.zig", - "build.zig.zon", - "src", - "README.md", - "LICENSE" - }, + .paths = .{ "build.zig", "build.zig.zon", "src", "README.md", "LICENSE" }, } diff --git a/src/conditions.zig b/src/conditions.zig index 5b1fdfa..c8cea4d 100644 --- a/src/conditions.zig +++ b/src/conditions.zig @@ -5,7 +5,7 @@ const errors = @import("errors.zig"); const Static = @This(); /// Create a value condition on a column. -pub fn value(comptime ValueType: type, allocator: std.mem.Allocator, comptime _column: []const u8, comptime operator: []const u8, _value: ValueType) !_sql.SqlParams { +pub fn value(comptime ValueType: type, allocator: std.mem.Allocator, comptime _column: []const u8, comptime operator: []const u8, _value: ValueType) !_sql.RawQuery { // Initialize the SQL condition string. var comptimeSql: [_column.len + 1 + operator.len + 1 + 1]u8 = undefined; @memcpy(comptimeSql[0.._column.len], _column); @@ -18,8 +18,8 @@ pub fn value(comptime ValueType: type, allocator: std.mem.Allocator, comptime _c std.mem.copyForwards(u8, sqlBuf, &comptimeSql); // Initialize parameters array. - const params = try allocator.alloc(_sql.QueryParameter, 1); - params[0] = try _sql.QueryParameter.fromValue(_value); + const params = try allocator.alloc(_sql.RawQueryParameter, 1); + params[0] = try _sql.RawQueryParameter.fromValue(_value); // Return the built SQL condition. return .{ @@ -29,7 +29,7 @@ pub fn value(comptime ValueType: type, allocator: std.mem.Allocator, comptime _c } /// Create a column condition on a column. -pub fn column(allocator: std.mem.Allocator, comptime _column: []const u8, comptime operator: []const u8, comptime valueColumn: []const u8) !_sql.SqlParams { +pub fn column(allocator: std.mem.Allocator, comptime _column: []const u8, comptime operator: []const u8, comptime valueColumn: []const u8) !_sql.RawQuery { // Initialize the SQL condition string. var comptimeSql: [_column.len + 1 + operator.len + 1 + valueColumn.len]u8 = undefined; @memcpy(comptimeSql[0.._column.len], _column); @@ -45,19 +45,19 @@ pub fn column(allocator: std.mem.Allocator, comptime _column: []const u8, compti // Return the built SQL condition. return .{ .sql = sqlBuf, - .params = &[0]_sql.QueryParameter{}, + .params = &[0]_sql.RawQueryParameter{}, }; } /// Create an IN condition on a column. -pub fn in(comptime ValueType: type, allocator: std.mem.Allocator, _column: []const u8, _value: []const ValueType) !_sql.SqlParams { +pub fn in(comptime ValueType: type, allocator: std.mem.Allocator, _column: []const u8, _value: []const ValueType) !_sql.RawQuery { // Generate parameters SQL. const parametersSql = try _sql.generateParametersSql(allocator, _value.len); // Get all query parameters from given values. - var valueParameters: []_sql.QueryParameter = try allocator.alloc(_sql.QueryParameter, _value.len); + var valueParameters: []_sql.RawQueryParameter = try allocator.alloc(_sql.RawQueryParameter, _value.len); for (0.._value.len) |i| { // Convert every given value to a query parameter. - valueParameters[i] = try _sql.QueryParameter.fromValue(_value[i]); + valueParameters[i] = try _sql.RawQueryParameter.fromValue(_value[i]); } // Initialize the SQL condition string. @@ -75,7 +75,7 @@ pub fn in(comptime ValueType: type, allocator: std.mem.Allocator, _column: []con } /// Generic conditions combiner generator. -fn conditionsCombiner(comptime keyword: []const u8, allocator: std.mem.Allocator, subconditions: []const _sql.SqlParams) !_sql.SqlParams { +fn conditionsCombiner(comptime keyword: []const u8, allocator: std.mem.Allocator, subconditions: []const _sql.RawQuery) !_sql.RawQuery { if (subconditions.len == 0) { // At least one condition is required. return errors.ZrmError.AtLeastOneConditionRequired; @@ -97,7 +97,7 @@ fn conditionsCombiner(comptime keyword: []const u8, allocator: std.mem.Allocator // Initialize the SQL condition string. var sqlBuf = try allocator.alloc(u8, sqlSize); // Initialize the query parameters array. - var parameters = try allocator.alloc(_sql.QueryParameter, queryParametersCount); + var parameters = try allocator.alloc(_sql.RawQueryParameter, queryParametersCount); var sqlBufCursor: usize = 0; var parametersCursor: usize = 0; // Add first parenthesis. @@ -117,7 +117,7 @@ fn conditionsCombiner(comptime keyword: []const u8, allocator: std.mem.Allocator } // Add query parameters to the array. - std.mem.copyForwards(_sql.QueryParameter, parameters[parametersCursor..parametersCursor+subcondition.params.len], subcondition.params); + std.mem.copyForwards(_sql.RawQueryParameter, parameters[parametersCursor..parametersCursor+subcondition.params.len], subcondition.params); parametersCursor += subcondition.params.len; } @@ -132,12 +132,12 @@ fn conditionsCombiner(comptime keyword: []const u8, allocator: std.mem.Allocator } /// Create an AND condition between multiple sub-conditions. -pub fn @"and"(allocator: std.mem.Allocator, subconditions: []const _sql.SqlParams) !_sql.SqlParams { +pub fn @"and"(allocator: std.mem.Allocator, subconditions: []const _sql.RawQuery) !_sql.RawQuery { return conditionsCombiner("AND", allocator, subconditions); } /// Create an OR condition between multiple sub-conditions. -pub fn @"or"(allocator: std.mem.Allocator, subconditions: []const _sql.SqlParams) !_sql.SqlParams { +pub fn @"or"(allocator: std.mem.Allocator, subconditions: []const _sql.RawQuery) !_sql.RawQuery { return conditionsCombiner("OR", allocator, subconditions); } @@ -148,27 +148,27 @@ pub const Builder = struct { allocator: std.mem.Allocator, /// Create a value condition on a column. - pub fn value(self: Self, comptime ValueType: type, comptime _column: []const u8, comptime operator: []const u8, _value: ValueType) !_sql.SqlParams { + pub fn value(self: Self, comptime ValueType: type, comptime _column: []const u8, comptime operator: []const u8, _value: ValueType) !_sql.RawQuery { return Static.value(ValueType, self.allocator, _column, operator, _value); } /// Create a column condition on a column. - pub fn column(self: Self, comptime _column: []const u8, comptime operator: []const u8, comptime valueColumn: []const u8) !_sql.SqlParams { + pub fn column(self: Self, comptime _column: []const u8, comptime operator: []const u8, comptime valueColumn: []const u8) !_sql.RawQuery { return Static.column(self.allocator, _column, operator, valueColumn); } /// Create an IN condition on a column. - pub fn in(self: Self, comptime ValueType: type, _column: []const u8, _value: []const ValueType) !_sql.SqlParams { + pub fn in(self: Self, comptime ValueType: type, _column: []const u8, _value: []const ValueType) !_sql.RawQuery { return Static.in(ValueType, self.allocator, _column, _value); } /// Create an AND condition between multiple sub-conditions. - pub fn @"and"(self: Self, subconditions: []const _sql.SqlParams) !_sql.SqlParams { + pub fn @"and"(self: Self, subconditions: []const _sql.RawQuery) !_sql.RawQuery { return Static.@"and"(self.allocator, subconditions); } /// Create an OR condition between multiple sub-conditions. - pub fn @"or"(self: Self, subconditions: []const _sql.SqlParams) !_sql.SqlParams { + pub fn @"or"(self: Self, subconditions: []const _sql.RawQuery) !_sql.RawQuery { return Static.@"or"(self.allocator, subconditions); } diff --git a/src/insert.zig b/src/insert.zig index b5c772c..0d251e7 100644 --- a/src/insert.zig +++ b/src/insert.zig @@ -49,7 +49,7 @@ pub fn Insertable(comptime StructType: type) type { pub fn RepositoryInsertConfiguration(comptime InsertShape: type) type { return struct { values: []const Insertable(InsertShape) = undefined, - returning: ?_sql.SqlParams = null, + returning: ?_sql.RawQuery = null, }; } @@ -187,7 +187,7 @@ pub fn RepositoryInsert(comptime Model: type, comptime TableShape: type, comptim } /// Set selected columns for RETURNING clause. - pub fn returning(self: *Self, _select: _sql.SqlParams) void { + pub fn returning(self: *Self, _select: _sql.RawQuery) void { self.insertConfig.returning = _select; } @@ -200,7 +200,7 @@ pub fn RepositoryInsert(comptime Model: type, comptime TableShape: type, comptim self.returning(.{ // Join selected columns. .sql = std.mem.join(self.arena.allocator(), ", ", _select), - .params = &[_]_sql.QueryParameter{}, // No parameters. + .params = &[_]_sql.RawQueryParameter{}, // No parameters. }); } @@ -208,7 +208,7 @@ pub fn RepositoryInsert(comptime Model: type, comptime TableShape: type, comptim pub fn returningAll(self: *Self) void { self.returning(.{ .sql = "*", - .params = &[_]_sql.QueryParameter{}, // No parameters. + .params = &[_]_sql.RawQueryParameter{}, // No parameters. }); } @@ -236,56 +236,53 @@ pub fn RepositoryInsert(comptime Model: type, comptime TableShape: type, comptim ) else 0; // Initialize SQL buffer. - const sqlBuf = try self.arena.allocator().alloc(u8, fixedSqlSize + valuesSqlSize + returningSize); + var sqlBuf = try std.ArrayList(u8).initCapacity(self.arena.allocator(), fixedSqlSize + valuesSqlSize + returningSize); + defer sqlBuf.deinit(); // Append initial "INSERT INTO table VALUES ". - @memcpy(sqlBuf[0..sqlBase.len],sqlBase); - var sqlBufCursor: usize = sqlBase.len; + try sqlBuf.appendSlice(sqlBase); // Start parameter counter at 1. var currentParameter: usize = 1; if (self.insertConfig.values.len == 0) { // No values, output an empty values set. - std.mem.copyForwards(u8, sqlBuf[sqlBufCursor..sqlBufCursor+2], "()"); - sqlBufCursor += 2; + try sqlBuf.appendSlice("()"); } else { // Build values set. for (self.insertConfig.values) |_| { // Add the first '('. - sqlBuf[sqlBufCursor] = '('; sqlBufCursor += 1; + try sqlBuf.append('('); inline for (columns) |_| { // Create the parameter string and append it to the SQL buffer. - const paramSize = 1 + try _sql.computeRequiredSpaceForParameter(currentParameter) + 1; - _ = try std.fmt.bufPrint(sqlBuf[sqlBufCursor..sqlBufCursor+paramSize], "${d},", .{currentParameter}); - sqlBufCursor += paramSize; + try sqlBuf.writer().print("${d},", .{currentParameter}); // Increment parameter count. currentParameter += 1; } // Replace the final ',' with a ')'. - sqlBuf[sqlBufCursor - 1] = ')'; + sqlBuf.items[sqlBuf.items.len - 1] = ')'; // Add the final ','. - sqlBuf[sqlBufCursor] = ','; sqlBufCursor += 1; + try sqlBuf.append(','); } - sqlBufCursor -= 1; + + // Remove the last ','. + _ = sqlBuf.pop(); } // Append RETURNING clause, if there is one defined. if (self.insertConfig.returning) |_returning| { - @memcpy(sqlBuf[sqlBufCursor..sqlBufCursor+(1 + returningClause.len + 1)], " " ++ returningClause ++ " "); + try sqlBuf.appendSlice(" " ++ returningClause ++ " "); // Copy RETURNING clause content and replace parameters, if there are some. try _sql.copyAndReplaceSqlParameters(¤tParameter, - _returning.params.len, - sqlBuf[sqlBufCursor+(1+returningClause.len+1)..sqlBufCursor+returningSize], _returning.sql + _returning.params.len, sqlBuf.writer(), _returning.sql ); - sqlBufCursor += returningSize; } // ";" to end the query. - sqlBuf[sqlBufCursor] = ';'; sqlBufCursor += 1; + try sqlBuf.append(';'); // Save built SQL query. - self.sql = sqlBuf; + self.sql = try sqlBuf.toOwnedSlice(); } /// Execute the insert query. diff --git a/src/postgresql.zig b/src/postgresql.zig index 5b9344d..cbb9198 100644 --- a/src/postgresql.zig +++ b/src/postgresql.zig @@ -5,6 +5,7 @@ const global = @import("global.zig"); const errors = @import("errors.zig"); const database = @import("database.zig"); const _sql = @import("sql.zig"); +const _relations = @import("relations.zig"); const repository = @import("repository.zig"); /// PostgreSQL query error details. @@ -14,7 +15,7 @@ pub const PostgresqlError = struct { }; /// Try to bind query parameters to the statement. -pub fn bindQueryParameters(statement: *pg.Stmt, parameters: []const _sql.QueryParameter) !void { +pub fn bindQueryParameters(statement: *pg.Stmt, parameters: []const _sql.RawQueryParameter) !void { for (parameters) |parameter| { // Try to bind each parameter in the slice. try bindQueryParameter(statement, parameter); @@ -22,7 +23,7 @@ pub fn bindQueryParameters(statement: *pg.Stmt, parameters: []const _sql.QueryPa } /// Try to bind a query parameter to the statement. -pub fn bindQueryParameter(statement: *pg.Stmt, parameter: _sql.QueryParameter) !void { +pub fn bindQueryParameter(statement: *pg.Stmt, parameter: _sql.RawQueryParameter) !void { switch (parameter) { .integer => |integer| try statement.bind(integer), .number => |number| try statement.bind(number), @@ -57,6 +58,27 @@ pub fn handleRawPostgresqlError(err: anyerror, connection: *pg.Conn) anyerror { } } +/// Make a PostgreSQL result mapper with the given prefix, if there is one. +pub fn makeMapper(comptime T: type, result: *pg.Result, allocator: std.mem.Allocator, optionalPrefix: ?[]const u8) !pg.Mapper(T) { + var column_indexes: [std.meta.fields(T).len]?usize = undefined; + + inline for (std.meta.fields(T), 0..) |field, i| { + if (optionalPrefix) |prefix| { + const fullName = try std.fmt.allocPrint(allocator, "{s}" ++ field.name, .{prefix}); + defer allocator.free(fullName); + column_indexes[i] = result.columnIndex(fullName); + } else { + column_indexes[i] = result.columnIndex(field.name); + } + } + + return .{ + .result = result, + .allocator = allocator, + .column_indexes = column_indexes, + }; +} + /// Generic query results mapping. pub fn mapResults(comptime Model: type, comptime TableShape: type, repositoryConfig: repository.RepositoryConfiguration(Model, TableShape), @@ -66,7 +88,7 @@ pub fn mapResults(comptime Model: type, comptime TableShape: type, // Create an arena for mapper data. var mapperArena = std.heap.ArenaAllocator.init(allocator); // Get result mapper. - const mapper = queryResult.mapper(TableShape, .{ .allocator = mapperArena.allocator() }); + const mapper = try makeMapper(TableShape, queryResult, mapperArena.allocator(), null); // Initialize models list. var models = std.ArrayList(*Model).init(allocator); diff --git a/src/query.zig b/src/query.zig index d2a6cf1..2863c60 100644 --- a/src/query.zig +++ b/src/query.zig @@ -5,26 +5,31 @@ const errors = @import("errors.zig"); const database = @import("database.zig"); const postgresql = @import("postgresql.zig"); const _sql = @import("sql.zig"); -const conditions = @import("conditions.zig"); +const _conditions = @import("conditions.zig"); +const relations = @import("relations.zig"); const repository = @import("repository.zig"); +const InlineRelationsResult = struct { + +}; + /// Repository query configuration structure. pub const RepositoryQueryConfiguration = struct { - select: ?_sql.SqlParams = null, - join: ?_sql.SqlParams = null, - where: ?_sql.SqlParams = null, + select: ?_sql.RawQuery = null, + join: ?_sql.RawQuery = null, + where: ?_sql.RawQuery = null, + with: ?[]const relations.Eager = null, }; /// Repository models query manager. /// Manage query string build and its execution. pub fn RepositoryQuery(comptime Model: type, comptime TableShape: type, comptime repositoryConfig: repository.RepositoryConfiguration(Model, TableShape)) type { // Pre-compute SQL buffer size. - const selectClause = "SELECT"; - const fromClause = "FROM"; - const whereClause = "WHERE"; - // SELECT ? FROM {repositoryConfig.table}??; - const fixedSqlSize = selectClause.len + 1 + 0 + 1 + fromClause.len + 1 + repositoryConfig.table.len + 0 + 0 + 1; - const defaultSelectSql = "*"; + const fromClause = " FROM \"" ++ repositoryConfig.table ++ "\""; + const defaultSelectSql = "\"" ++ repositoryConfig.table ++ "\".*"; + + // Model key type. + const KeyType = repository.ModelKeyType(Model, TableShape, repositoryConfig); return struct { const Self = @This(); @@ -34,10 +39,14 @@ pub fn RepositoryQuery(comptime Model: type, comptime TableShape: type, comptime connection: *database.Connection = undefined, queryConfig: RepositoryQueryConfiguration, + /// List of loaded inline relations. + inlineRelations: []relations.Eager = undefined, + + query: ?_sql.RawQuery = null, sql: ?[]const u8 = null, /// Set selected columns. - pub fn select(self: *Self, _select: _sql.SqlParams) void { + pub fn select(self: *Self, _select: _sql.RawQuery) void { self.queryConfig.select = _select; } @@ -50,131 +59,236 @@ pub fn RepositoryQuery(comptime Model: type, comptime TableShape: type, comptime self.select(.{ // Join selected columns. .sql = std.mem.join(self.arena.allocator(), ", ", _select), - .params = &[_]_sql.QueryParameter{}, // No parameters. + .params = &[_]_sql.RawQueryParameter{}, // No parameters. }); } /// Set JOIN clause. - pub fn join(self: *Self, _join: _sql.SqlParams) void { + pub fn join(self: *Self, _join: _sql.RawQuery) void { self.queryConfig.join = _join; } /// Set WHERE conditions. - pub fn where(self: *Self, _where: _sql.SqlParams) void { + pub fn where(self: *Self, _where: _sql.RawQuery) void { self.queryConfig.where = _where; } /// Create a new condition builder. - pub fn newCondition(self: *Self) conditions.Builder { - return conditions.Builder.init(self.arena.allocator()); + pub fn newCondition(self: *Self) _conditions.Builder { + return _conditions.Builder.init(self.arena.allocator()); } /// Set a WHERE value condition. pub fn whereValue(self: *Self, comptime ValueType: type, comptime _column: []const u8, comptime operator: []const u8, _value: ValueType) !void { self.where( - try conditions.value(ValueType, self.arena.allocator(), _column, operator, _value) + try _conditions.value(ValueType, self.arena.allocator(), _column, operator, _value) ); } /// Set a WHERE column condition. pub fn whereColumn(self: *Self, comptime _column: []const u8, comptime operator: []const u8, comptime _valueColumn: []const u8) !void { self.where( - try conditions.column(self.arena.allocator(), _column, operator, _valueColumn) + try _conditions.column(self.arena.allocator(), _column, operator, _valueColumn) ); } /// Set a WHERE IN condition. pub fn whereIn(self: *Self, comptime ValueType: type, comptime _column: []const u8, _value: []const ValueType) !void { self.where( - try conditions.in(ValueType, self.arena.allocator(), _column, _value) + try _conditions.in(ValueType, self.arena.allocator(), _column, _value) ); } + /// Set a WHERE from model key(s). + /// For simple keys: modelKey type must match the type of its corresponding field. + /// modelKey can be an array / slice of keys. + /// For composite keys: modelKey must be a struct with all the keys, matching the type of their corresponding field. + /// modelKey can be an array / slice of these structs. + pub fn whereKey(self: *Self, modelKey: anytype) !void { + if (repositoryConfig.key.len == 1) { + // Find key name and its type. + const keyName = repositoryConfig.key[0]; + const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; + + // Accept arrays / slices of keys, and simple keys. + switch (@typeInfo(@TypeOf(modelKey))) { + .Pointer => |ptr| { + switch (ptr.size) { + .One => { + switch (@typeInfo(ptr.child)) { + // Add a whereIn with the array. + .Array => { + if (ptr.child == u8) + // If the child is a string, use it as a simple value. + try self.whereValue(KeyType, keyName, "=", modelKey) + else + // Otherwise, use it as an array. + try self.whereIn(keyType, keyName, modelKey); + }, + // Add a simple condition with the pointed value. + else => try self.whereValue(keyType, keyName, "=", modelKey.*), + } + }, + // Add a whereIn with the slice. + else => { + if (ptr.child == u8) + // If the child is a string, use it as a simple value. + try self.whereValue(KeyType, keyName, "=", modelKey) + else + // Otherwise, use it as an array. + try self.whereIn(keyType, keyName, modelKey); + }, + } + }, + // Add a simple condition with the given value. + else => try self.whereValue(keyType, keyName, "=", modelKey), + } + } else { + // Accept arrays / slices of keys, and simple keys. + // Uniformize modelKey parameter to a slice. + const modelKeysList: []const KeyType = switch (@typeInfo(@TypeOf(modelKey))) { + .Pointer => |ptr| switch (ptr.size) { + .One => switch (@typeInfo(ptr.child)) { + // Already an array. + .Array => @as([]const KeyType, modelKey), + // Convert the pointer to an array. + else => &[1]KeyType{@as(KeyType, modelKey.*)}, + }, + // Already a slice. + else => @as([]const KeyType, modelKey), + }, + // Convert the value to an array. + else => &[1]KeyType{@as(KeyType, modelKey)}, + }; + + // Initialize keys conditions list. + const conditions: []_sql.RawQuery = try self.arena.allocator().alloc(_sql.RawQuery, modelKeysList.len); + defer self.arena.allocator().free(conditions); + + // For each model key, add its conditions. + for (modelKeysList, conditions) |_modelKey, *condition| { + condition.* = try self.newCondition().@"and"( + &try buildCompositeKeysConditions(TableShape, repositoryConfig.key, self.newCondition(), _modelKey) + ); + } + + // Set WHERE conditions in the query with all keys conditions. + self.where(try self.newCondition().@"or"(conditions)); + } + } + + /// Set relations to eager load. + pub fn with(self: *Self, relation: relations.ModelRelation) !void { + // Take an array of eager relations (which can have subrelations). + const allocator = self.arena.allocator(); + + // Make a relation instance. + const relationInstance = try allocator.create(relation.relation); + + // Add the new relation to a newly allocated array, with one more space. + const newPos = if (self.queryConfig.with) |_with| _with.len else 0; + var newWith = try allocator.alloc(relations.Eager, newPos + 1); + newWith[newPos] = .{ + .field = relation.field, + .relation = relationInstance.*.relation(), + .with = &[0]relations.Eager{}, //TODO handle subrelations with dotted syntax + }; + + if (self.queryConfig.with) |_with| { + // Copy existing relations. + @memcpy(newWith[0..newPos], _with); + // Free previous array. + allocator.free(_with); + } + + // Save the newly allocated array. + self.queryConfig.with = newWith; + } + + /// Build inline relations query part. + fn buildInlineRelations(self: *Self) !?struct{ + select: []const u8, + join: _sql.RawQuery, + } { + if (self.queryConfig.with) |_with| { + // Initialize an ArrayList of query parts for relations. + var inlineRelations = try std.ArrayList(_sql.RawQuery).initCapacity(self.arena.allocator(), _with.len); + defer inlineRelations.deinit(); + var inlineRelationsSelect = try std.ArrayList([]const u8).initCapacity(self.arena.allocator(), _with.len); + defer inlineRelationsSelect.deinit(); + + // Initialize an ArrayList to store all loaded inline relations. + var loadedRelations = std.ArrayList(relations.Eager).init(self.arena.allocator()); + defer loadedRelations.deinit(); + + for (_with) |_relation| { + // Append each inline relation to the ArrayList. + if (_relation.relation.inlineMapping()) { + try loadedRelations.append(_relation); // Store the loaded inline relation. + + // Get an allocator for local allocations. + const localAllocator = self.arena.allocator(); + + // Build table alias and fields prefix. + const tableAlias = try std.fmt.allocPrint(localAllocator, "relations.{s}", .{_relation.field}); + defer localAllocator.free(tableAlias); + const prefix = try std.fmt.allocPrint(localAllocator, "{s}.", .{tableAlias}); + defer localAllocator.free(prefix); + + // Alter query to get relation fields. + try inlineRelations.append(try _relation.relation.genJoin(self.arena.allocator(), tableAlias)); + const relationSelect = try _relation.relation.genSelect(localAllocator, tableAlias, prefix); + try inlineRelationsSelect.append(relationSelect); + } + } + + self.inlineRelations = try loadedRelations.toOwnedSlice(); + + // Return the inline relations query part. + return .{ + .select = try std.mem.join(self.arena.allocator(), ", ", inlineRelationsSelect.items), + .join = try _sql.RawQuery.fromConcat(self.arena.allocator(), inlineRelations.items), + }; + } else { + // Nothing. + return null; + } + } + /// Build SQL query. pub fn buildSql(self: *Self) !void { - // Start parameter counter at 1. - var currentParameter: usize = 1; + // Build inline relations query part. + const inlineRelations = try self.buildInlineRelations(); + defer if (inlineRelations) |_inlineRelations| self.arena.allocator().free(_inlineRelations.join.sql); + defer if (inlineRelations) |_inlineRelations| self.arena.allocator().free(_inlineRelations.join.params); + defer if (inlineRelations) |_inlineRelations| self.arena.allocator().free(_inlineRelations.select); - // Compute SELECT size. - var selectSize: usize = defaultSelectSql.len; - if (self.queryConfig.select) |_select| { - selectSize = _select.sql.len + _sql.computeRequiredSpaceForParametersNumbers(_select.params.len, currentParameter - 1); - currentParameter += _select.params.len; - } - - // Compute JOIN size. - var joinSize: usize = 0; - if (self.queryConfig.join) |_join| { - joinSize = 1 + _join.sql.len + _sql.computeRequiredSpaceForParametersNumbers(_join.params.len, currentParameter - 1); - currentParameter += _join.params.len; - } - - // Compute WHERE size. - var whereSize: usize = 0; - if (self.queryConfig.where) |_where| { - whereSize = 1 + whereClause.len + _where.sql.len + 1 + _sql.computeRequiredSpaceForParametersNumbers(_where.params.len, currentParameter - 1); - currentParameter += _where.params.len; - } - - // Allocate SQL buffer from computed size. - const sqlBuf = try self.arena.allocator().alloc(u8, fixedSqlSize - + (selectSize) - + (joinSize) - + (whereSize) - ); - - // Fill SQL buffer. - - // Restart parameter counter at 1. - currentParameter = 1; - - // SELECT clause. - @memcpy(sqlBuf[0..selectClause.len+1], selectClause ++ " "); - var sqlBufCursor: usize = selectClause.len+1; - - // Copy SELECT clause content and replace parameters, if there are some. - try _sql.copyAndReplaceSqlParameters(¤tParameter, - if (self.queryConfig.select) |_select| _select.params.len else 0, - sqlBuf[sqlBufCursor..sqlBufCursor+selectSize], - if (self.queryConfig.select) |_select| _select.sql else defaultSelectSql, - ); - sqlBufCursor += selectSize; - - // FROM clause. - sqlBuf[sqlBufCursor] = ' '; sqlBufCursor += 1; - std.mem.copyForwards(u8, sqlBuf[sqlBufCursor..sqlBufCursor+fromClause.len], fromClause); sqlBufCursor += fromClause.len; - sqlBuf[sqlBufCursor] = ' '; sqlBufCursor += 1; - - // Table name. - std.mem.copyForwards(u8, sqlBuf[sqlBufCursor..sqlBufCursor+repositoryConfig.table.len], repositoryConfig.table); sqlBufCursor += repositoryConfig.table.len; - - // JOIN clause. - if (self.queryConfig.join) |_join| { - sqlBuf[sqlBufCursor] = ' '; - // Copy JOIN clause and replace parameters, if there are some. - try _sql.copyAndReplaceSqlParameters(¤tParameter, - _join.params.len, - sqlBuf[sqlBufCursor+1..sqlBufCursor+joinSize], _join.sql - ); - sqlBufCursor += joinSize; - } - - // WHERE clause. - if (self.queryConfig.where) |_where| { - @memcpy(sqlBuf[sqlBufCursor..sqlBufCursor+(1 + whereClause.len + 1)], " " ++ whereClause ++ " "); - // Copy WHERE clause content and replace parameters, if there are some. - try _sql.copyAndReplaceSqlParameters(¤tParameter, - _where.params.len, - sqlBuf[sqlBufCursor+(1+whereClause.len+1)..sqlBufCursor+whereSize], _where.sql - ); - sqlBufCursor += whereSize; - } - - // ";" to end the query. - sqlBuf[sqlBufCursor] = ';'; sqlBufCursor += 1; + // Build the full SQL query from all its parts. + const sqlQuery = _sql.RawQuery{ + .sql = try std.mem.join(self.arena.allocator(), "", &[_][]const u8{ + "SELECT ", if (self.queryConfig.select) |_select| _select.sql else defaultSelectSql, + if (inlineRelations) |_| ", " else "", + if (inlineRelations) |_inlineRelations| _inlineRelations.select else "", + fromClause, + if (self.queryConfig.join) |_| " " else "", + if (self.queryConfig.join) |_join| _join.sql else "", + if (inlineRelations) |_| " " else "", + if (inlineRelations) |_inlineRelations| _inlineRelations.join.sql else "", + if (self.queryConfig.where) |_| " WHERE " else "", + if (self.queryConfig.where) |_where| _where.sql else "", + ";", + }), + .params = try std.mem.concat(self.arena.allocator(), _sql.RawQueryParameter, &[_][]const _sql.RawQueryParameter{ + if (self.queryConfig.select) |_select| _select.params else &[0]_sql.RawQueryParameter{}, + if (self.queryConfig.join) |_join| _join.params else &[0]_sql.RawQueryParameter{}, + if (inlineRelations) |_inlineRelations| _inlineRelations.join.params else &[0]_sql.RawQueryParameter{}, + if (self.queryConfig.where) |_where| _where.params else &[0]_sql.RawQueryParameter{}, + }) + }; // Save built SQL query. - self.sql = sqlBuf; + self.query = sqlQuery; + self.sql = try sqlQuery.build(self.arena.allocator()); } /// Execute the built query. @@ -195,12 +309,8 @@ pub fn RepositoryQuery(comptime Model: type, comptime TableShape: type, comptime catch |err| return postgresql.handlePostgresqlError(err, self.connection, &statement); // Bind query parameters. - if (self.queryConfig.select) |_select| - try postgresql.bindQueryParameters(&statement, _select.params); - if (self.queryConfig.join) |_join| - try postgresql.bindQueryParameters(&statement, _join.params); - if (self.queryConfig.where) |_where| - try postgresql.bindQueryParameters(&statement, _where.params); + postgresql.bindQueryParameters(&statement, self.query.?.params) + catch |err| return postgresql.handlePostgresqlError(err, self.connection, &statement); // Execute the query and get its result. const result = statement.execute() @@ -241,3 +351,24 @@ pub fn RepositoryQuery(comptime Model: type, comptime TableShape: type, comptime } }; } + +/// Build conditions for given composite keys, with a model key structure. +pub fn buildCompositeKeysConditions(comptime TableShape: type, comptime keys: []const []const u8, conditionsBuilder: _conditions.Builder, modelKey: anytype) ![keys.len]_sql.RawQuery { + // Conditions list for all keys in the composite key. + var conditions: [keys.len]_sql.RawQuery = undefined; + + inline for (keys, &conditions) |keyName, *condition| { + const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; + + if (std.meta.fieldIndex(@TypeOf(modelKey), keyName)) |_| { + // The field exists in the key structure, create its condition. + condition.* = try conditionsBuilder.value(keyType, keyName, "=", @field(modelKey, keyName)); + } else { + // The field doesn't exist, compilation error. + @compileError("The key structure must include a field for " ++ keyName); + } + } + + // Return conditions for the current model key. + return conditions; +} diff --git a/src/relations.zig b/src/relations.zig new file mode 100644 index 0000000..521496d --- /dev/null +++ b/src/relations.zig @@ -0,0 +1,343 @@ +const std = @import("std"); +const pg = @import("pg"); +const _sql = @import("sql.zig"); +const repository = @import("repository.zig"); +const _query = @import("query.zig"); + +/// Configure a "one to many" or "many to many" relation. +pub const ManyConfiguration = union(enum) { + /// Direct one-to-many relation using a distant foreign key. + direct: struct { + /// The distant foreign key name pointing to the current model. + foreignKey: []const u8, + /// Current model key name. + /// Use the default key name of the current model. + modelKey: ?[]const u8 = null, + }, + + /// Used when performing a many-to-many relation through an association table. + through: struct { + /// Name of the join table. + table: []const u8, + /// The local foreign key name. + /// Use the default key name of the current model. + foreignKey: ?[]const u8 = null, + /// The foreign key name in the join table. + joinForeignKey: []const u8, + /// The model key name in the join table. + joinModelKey: []const u8, + /// Associated model key name. + /// Use the default key name of the associated model. + modelKey: ?[]const u8 = null, + }, +}; + +/// Make a "one to many" or "many to many" relation. +pub fn many(comptime fromRepo: anytype, comptime toRepo: anytype, comptime config: ManyConfiguration) type { + return typedMany( + fromRepo.ModelType, fromRepo.TableType, fromRepo.config, + toRepo.ModelType, toRepo.TableType, toRepo.config, + config, + ); +} + +/// Internal implementation of a new "one to many" or "many to many" relation. +pub fn typedMany( + comptime FromModel: type, comptime FromTable: type, + comptime fromRepositoryConfig: repository.RepositoryConfiguration(FromModel, FromTable), + comptime ToModel: type, comptime ToTable: type, + comptime toRepositoryConfig: repository.RepositoryConfiguration(ToModel, ToTable), + comptime config: ManyConfiguration) type { + + // Get foreign key from relation config or repository config. + const foreignKey = switch (config) { + .direct => |direct| direct.foreignKey, + .through => |through| if (through.foreignKey) |_foreignKey| _foreignKey else toRepositoryConfig.key[0], + }; + + // Get model key from relation config or repository config. + const modelKey = switch (config) { + .direct => |direct| if (direct.modelKey) |_modelKey| _modelKey else fromRepositoryConfig.key[0], + .through => |through| if (through.modelKey) |_modelKey| _modelKey else fromRepositoryConfig.key[0], + }; + + const FromKeyType = std.meta.fields(FromModel)[std.meta.fieldIndex(FromModel, fromRepositoryConfig.key[0]).?].type; + const QueryType = _query.RepositoryQuery(ToModel, ToTable, toRepositoryConfig); + const SelectBuilder = _sql.SelectBuilder(ToTable); + + return struct { + const Self = @This(); + + fn inlineMapping(_: *anyopaque) bool { + return false; + } + + fn genJoin(_: *anyopaque, _: std.mem.Allocator, _: []const u8) !_sql.RawQuery { + unreachable; // No possible join in a many relation. + } + + fn genSelect(_: *anyopaque, allocator: std.mem.Allocator, table: []const u8, prefix: []const u8) ![]const u8 { + return SelectBuilder.build(allocator, table, prefix); + } + + fn buildQuery(_: *anyopaque, opaqueModels: []const anyopaque, opaqueQuery: *anyopaque) !void { + var models: []const FromModel = undefined; + models.len = opaqueModels.len; + models.ptr = @ptrCast(@alignCast(opaqueModels.ptr)); + const query: *QueryType = @ptrCast(@alignCast(opaqueQuery)); + + // Prepare given models IDs. + const modelsIds = try query.arena.allocator().alloc(FromKeyType, models.len); + for (models, modelsIds) |model, *modelId| { + modelId.* = @field(model, fromRepositoryConfig.key[0]); + } + + switch (config) { + .direct => { + // Build WHERE condition. + try query.whereIn(FromKeyType, "\"" ++ toRepositoryConfig.table ++ "\".\"" ++ foreignKey ++ "\"", modelsIds); + }, + .through => |through| { + query.join(.{ + .sql = "INNER JOIN \"" ++ through.table ++ "\" ON " ++ + "\"" ++ toRepositoryConfig.table ++ "\"." ++ modelKey ++ " = " ++ "\"" ++ through.table ++ "\"." ++ through.joinModelKey, + .params = &[0]_sql.RawQueryParameter{}, + }); + + // Build WHERE condition. + try query.whereIn(FromKeyType, "\"" ++ through.table ++ "\".\"" ++ through.joinForeignKey ++ "\"", modelsIds); + }, + } + } + + pub fn relation(self: *Self) Relation { + return .{ + ._interface = .{ + .instance = self, + + .inlineMapping = inlineMapping, + .genJoin = genJoin, + .genSelect = genSelect, + .buildQuery = buildQuery, + }, + }; + } + }; +} + + +/// Configure a "one to one" relation. +pub const OneConfiguration = union(enum) { + /// Direct one-to-one relation using a local foreign key. + direct: struct { + /// The local foreign key name. + foreignKey: []const u8, + /// Associated model key name. + /// Use the default key name of the associated model. + modelKey: ?[]const u8 = null, + }, + + /// Reverse one-to-one relation using distant foreign key. + reverse: struct { + /// The distant foreign key name. + foreignKey: []const u8, + /// Current model key name. + /// Use the default key name of the current model. + modelKey: ?[]const u8 = null, + }, + + /// Used when performing a one-to-one relation through an association table. + through: struct { + /// Name of the join table. + table: []const u8, + /// The local foreign key name. + /// Use the default key name of the current model. + foreignKey: ?[]const u8 = null, + /// The foreign key name in the join table. + joinForeignKey: []const u8, + /// The model key name in the join table. + joinModelKey: []const u8, + /// Associated model key name. + /// Use the default key name of the associated model. + modelKey: ?[]const u8 = null, + }, +}; + +/// Make a "one to one" relation. +pub fn one(comptime fromRepo: anytype, comptime toRepo: anytype, comptime config: OneConfiguration) type { + return typedOne( + fromRepo.ModelType, fromRepo.TableType, fromRepo.config, + toRepo.ModelType, toRepo.TableType, toRepo.config, + config, + ); +} + +/// Internal implementation of a new "one to one" relation. +fn typedOne( + comptime FromModel: type, comptime FromTable: type, + comptime fromRepositoryConfig: repository.RepositoryConfiguration(FromModel, FromTable), + comptime ToModel: type, comptime ToTable: type, + comptime toRepositoryConfig: repository.RepositoryConfiguration(ToModel, ToTable), + comptime config: OneConfiguration) type { + + const FromKeyType = std.meta.fields(FromModel)[std.meta.fieldIndex(FromModel, fromRepositoryConfig.key[0]).?].type; + const QueryType = _query.RepositoryQuery(ToModel, ToTable, toRepositoryConfig); + const SelectBuilder = _sql.SelectBuilder(ToTable); + + // Get foreign key from relation config or repository config. + const foreignKey = switch (config) { + .direct => |direct| direct.foreignKey, + .reverse => |reverse| reverse.foreignKey, + .through => |through| if (through.foreignKey) |_foreignKey| _foreignKey else fromRepositoryConfig.key[0], + }; + + // Get model key from relation config or repository config. + const modelKey = switch (config) { + .direct => |direct| if (direct.modelKey) |_modelKey| _modelKey else toRepositoryConfig.key[0], + .reverse => |reverse| if (reverse.modelKey) |_modelKey| _modelKey else toRepositoryConfig.key[0], + .through => |through| if (through.modelKey) |_modelKey| _modelKey else toRepositoryConfig.key[0], + }; + + return struct { + const Self = @This(); + + fn inlineMapping(_: *anyopaque) bool { + return true; + } + + fn genJoin(_: *anyopaque, allocator: std.mem.Allocator, alias: []const u8) !_sql.RawQuery { + return switch (config) { + .direct => (.{ + .sql = try std.fmt.allocPrint(allocator, "LEFT JOIN \"" ++ toRepositoryConfig.table ++ "\" AS \"{s}\" ON " ++ + "\"" ++ fromRepositoryConfig.table ++ "\"." ++ foreignKey ++ " = \"{s}\"." ++ modelKey, .{alias, alias}), + .params = &[0]_sql.RawQueryParameter{}, + }), + + .reverse => (.{ + .sql = try std.fmt.allocPrint(allocator, "LEFT JOIN \"" ++ toRepositoryConfig.table ++ "\" AS \"{s}\" ON " ++ + "\"" ++ fromRepositoryConfig.table ++ "\"." ++ modelKey ++ " = \"{s}\"." ++ foreignKey, .{alias, alias}), + .params = &[0]_sql.RawQueryParameter{}, + }), + + .through => |through| (.{ + .sql = try std.fmt.allocPrint(allocator, "LEFT JOIN \"" ++ through.table ++ "\" AS \"{s}_pivot\" ON " ++ + "\"" ++ fromRepositoryConfig.table ++ "\"." ++ foreignKey ++ " = " ++ "\"{s}_pivot\"." ++ through.joinForeignKey ++ + "LEFT JOIN \"" ++ toRepositoryConfig.table ++ "\" AS \"{s}\" ON " ++ + "\"{s}_pivot\"." ++ through.joinModelKey ++ " = " ++ "\"{s}\"." ++ modelKey, .{alias, alias, alias, alias, alias}), + .params = &[0]_sql.RawQueryParameter{}, + }), + }; + } + + fn genSelect(_: *anyopaque, allocator: std.mem.Allocator, table: []const u8, prefix: []const u8) ![]const u8 { + return SelectBuilder.build(allocator, table, prefix); + } + + fn buildQuery(_: *anyopaque, opaqueModels: []const anyopaque, opaqueQuery: *anyopaque) !void { + var models: []const FromModel = undefined; + models.len = opaqueModels.len; + models.ptr = @ptrCast(@alignCast(opaqueModels.ptr)); + const query: *QueryType = @ptrCast(@alignCast(opaqueQuery)); + + // Prepare given models IDs. + const modelsIds = try query.arena.allocator().alloc(FromKeyType, models.len); + for (models, modelsIds) |model, *modelId| { + modelId.* = @field(model, fromRepositoryConfig.key[0]); + } + + switch (config) { + .direct => { + query.join((_sql.RawQuery{ + .sql = "INNER JOIN \"" ++ fromRepositoryConfig.table ++ "\" ON \"" ++ toRepositoryConfig.table ++ "\"." ++ modelKey ++ " = \"" ++ fromRepositoryConfig.table ++ "\"." ++ foreignKey, + .params = &[0]_sql.RawQueryParameter{}, + })); + + // Build WHERE condition. + try query.whereIn(FromKeyType, "\"" ++ fromRepositoryConfig.table ++ "\".\"" ++ fromRepositoryConfig.key[0] ++ "\"", modelsIds); + }, + .reverse => { + // Build WHERE condition. + try query.whereIn(FromKeyType, "\"" ++ toRepositoryConfig.table ++ "\".\"" ++ foreignKey ++ "\"", modelsIds); + }, + .through => |through| { + query.join(.{ + .sql = "INNER JOIN \"" ++ through.table ++ "\" ON " ++ + "\"" ++ toRepositoryConfig.table ++ "\"." ++ modelKey ++ " = " ++ "\"" ++ through.table ++ "\"." ++ through.joinModelKey, + .params = &[0]_sql.RawQueryParameter{}, + }); + + // Build WHERE condition. + try query.whereIn(FromKeyType, "\"" ++ through.table ++ "\".\"" ++ through.joinForeignKey ++ "\"", modelsIds); + }, + } + } + + pub fn relation(self: *Self) Relation { + return .{ + ._interface = .{ + .instance = self, + + .inlineMapping = inlineMapping, + .genJoin = genJoin, + .genSelect = genSelect, + .buildQuery = buildQuery, + }, + }; + } + }; +} + + +/// Generic model relation interface. +pub const Relation = struct { + const Self = @This(); + + _interface: struct { + instance: *anyopaque, + + inlineMapping: *const fn (self: *anyopaque) bool, + genJoin: *const fn (self: *anyopaque, allocator: std.mem.Allocator, alias: []const u8) anyerror!_sql.RawQuery, + genSelect: *const fn (self: *anyopaque, allocator: std.mem.Allocator, table: []const u8, prefix: []const u8) anyerror![]const u8, + buildQuery: *const fn (self: *anyopaque, models: []const anyopaque, query: *anyopaque) anyerror!void, + }, + + /// Relation mapping is done inline: this means that it's done at the same time the model is mapped, + /// and that the associated data will be retrieved in the main query. + pub fn inlineMapping(self: Self) bool { + return self._interface.inlineMapping(self._interface.instance); + } + + /// In case of inline mapping, generate a JOIN clause to retrieve the associated data. + pub fn genJoin(self: Self, allocator: std.mem.Allocator, alias: []const u8) !_sql.RawQuery { + return self._interface.genJoin(self._interface.instance, allocator, alias); + } + + /// Generate a SELECT clause to retrieve the associated data, with the given table and prefix. + pub fn genSelect(self: Self, allocator: std.mem.Allocator, table: []const u8, prefix: []const u8) ![]const u8 { + return self._interface.genSelect(self._interface.instance, allocator, table, prefix); + } + + /// Build the query to retrieve relation data. + /// Is always used when inline mapping is not possible, but also when loading relations lazily. + pub fn buildQuery(self: Self, models: []const anyopaque, query: *anyopaque) !void { + return self._interface.buildQuery(self._interface.instance, models, query); + } +}; + + +/// A model relation object. +pub const ModelRelation = struct { + relation: type, + field: []const u8, +}; + + +/// Structure of an eager loaded relation. +pub const Eager = struct { + /// Model field to fill for the relation. + field: []const u8, + /// The relation to eager load. + relation: Relation, + /// Subrelations to eager load. + with: []const Eager, +}; diff --git a/src/repository.zig b/src/repository.zig index f251a72..feb52c4 100644 --- a/src/repository.zig +++ b/src/repository.zig @@ -3,6 +3,7 @@ const zollections = @import("zollections"); const database = @import("database.zig"); const _sql = @import("sql.zig"); const _conditions = @import("conditions.zig"); +const _relations = @import("relations.zig"); const query = @import("query.zig"); const insert = @import("insert.zig"); const update = @import("update.zig"); @@ -73,17 +74,84 @@ pub fn ModelKeyType(comptime Model: type, comptime TableShape: type, comptime co } } +/// Model relations definition type. +pub fn RelationsDefinitionType(comptime rawDefinition: anytype) type { + const rawDefinitionType = @typeInfo(@TypeOf(rawDefinition)); + + // Build model relations fields. + var fields: [rawDefinitionType.Struct.fields.len]std.builtin.Type.StructField = undefined; + inline for (rawDefinitionType.Struct.fields, &fields) |originalField, *field| { + field.* = .{ + .name = originalField.name, + .type = _relations.ModelRelation, + .default_value = null, + .is_comptime = false, + .alignment = @alignOf(_relations.ModelRelation), + }; + } + + // Return built type. + return @Type(.{ + .Struct = std.builtin.Type.Struct{ + .layout = std.builtin.Type.ContainerLayout.auto, + .fields = &fields, + .decls = &[_]std.builtin.Type.Declaration{}, + .is_tuple = false, + }, + }); +} + /// Repository of structures of a certain type. -pub fn Repository(comptime Model: type, comptime TableShape: type, comptime config: RepositoryConfiguration(Model, TableShape)) type { +pub fn Repository(comptime Model: type, comptime TableShape: type, comptime repositoryConfig: RepositoryConfiguration(Model, TableShape)) type { return struct { const Self = @This(); + pub const ModelType = Model; + pub const TableType = TableShape; + pub const config = repositoryConfig; + pub const Query: type = query.RepositoryQuery(Model, TableShape, config); pub const Insert: type = insert.RepositoryInsert(Model, TableShape, config, config.insertShape); /// Type of one model key. pub const KeyType = ModelKeyType(Model, TableShape, config); + pub const relations = struct { + /// Make a "one to one" relation. + pub fn one(comptime toRepo: anytype, comptime oneConfig: _relations.OneConfiguration) type { + return _relations.one(Self, toRepo, oneConfig); + } + + /// Make a "one to many" or "many to many" relation. + pub fn many(comptime toRepo: anytype, comptime manyConfig: _relations.ManyConfiguration) type { + return _relations.many(Self, toRepo, manyConfig); + } + + /// Define a relations object for a repository. + pub fn define(rawDefinition: anytype) RelationsDefinitionType(rawDefinition) { + const rawDefinitionType = @TypeOf(rawDefinition); + + // Initialize final relations definition. + var definition: RelationsDefinitionType(rawDefinition) = undefined; + + // Check that the definition structure only include known fields. + inline for (std.meta.fieldNames(rawDefinitionType)) |fieldName| { + if (!@hasField(Model, fieldName)) { + @compileError("No corresponding field for relation " ++ fieldName); + } + + // Alter definition structure to add the field name. + @field(definition, fieldName) = .{ + .relation = @field(rawDefinition, fieldName), + .field = fieldName, + }; + } + + // Return altered definition structure. + return definition; + } + }; + pub fn InsertCustom(comptime InsertShape: type) type { return insert.RepositoryInsert(Model, TableShape, config, InsertShape); } @@ -102,76 +170,7 @@ pub fn Repository(comptime Model: type, comptime TableShape: type, comptime conf var modelQuery = Self.Query.init(allocator, connector, .{}); defer modelQuery.deinit(); - if (config.key.len == 1) { - // Find key name and its type. - const keyName = config.key[0]; - const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; - - // Accept arrays / slices of keys, and simple keys. - switch (@typeInfo(@TypeOf(modelKey))) { - .Pointer => |ptr| { - switch (ptr.size) { - .One => { - switch (@typeInfo(ptr.child)) { - // Add a whereIn with the array. - .Array => { - if (ptr.child == u8) - // If the child is a string, use it as a simple value. - try modelQuery.whereValue(KeyType, keyName, "=", modelKey) - else - // Otherwise, use it as an array. - try modelQuery.whereIn(keyType, keyName, modelKey); - }, - // Add a simple condition with the pointed value. - else => try modelQuery.whereValue(keyType, keyName, "=", modelKey.*), - } - }, - // Add a whereIn with the slice. - else => { - if (ptr.child == u8) - // If the child is a string, use it as a simple value. - try modelQuery.whereValue(KeyType, keyName, "=", modelKey) - else - // Otherwise, use it as an array. - try modelQuery.whereIn(keyType, keyName, modelKey); - }, - } - }, - // Add a simple condition with the given value. - else => try modelQuery.whereValue(keyType, keyName, "=", modelKey), - } - } else { - // Accept arrays / slices of keys, and simple keys. - // Uniformize modelKey parameter to a slice. - const modelKeysList: []const Self.KeyType = switch (@typeInfo(@TypeOf(modelKey))) { - .Pointer => |ptr| switch (ptr.size) { - .One => switch (@typeInfo(ptr.child)) { - // Already an array. - .Array => @as([]const Self.KeyType, modelKey), - // Convert the pointer to an array. - else => &[1]Self.KeyType{@as(Self.KeyType, modelKey.*)}, - }, - // Already a slice. - else => @as([]const Self.KeyType, modelKey), - }, - // Convert the value to an array. - else => &[1]Self.KeyType{@as(Self.KeyType, modelKey)}, - }; - - // Initialize keys conditions list. - const conditions: []_sql.SqlParams = try allocator.alloc(_sql.SqlParams, modelKeysList.len); - defer allocator.free(conditions); - - // For each model key, add its conditions. - for (modelKeysList, conditions) |_modelKey, *condition| { - condition.* = try modelQuery.newCondition().@"and"( - &try buildCompositeKeysConditions(TableShape, config.key, modelQuery.newCondition(), _modelKey) - ); - } - - // Set WHERE conditions in the query with all keys conditions. - modelQuery.where(try modelQuery.newCondition().@"or"(conditions)); - } + try modelQuery.whereKey(modelKey); // Execute query and return its result. return try modelQuery.get(allocator); @@ -210,7 +209,7 @@ pub fn Repository(comptime Model: type, comptime TableShape: type, comptime conf updateQuery.returningAll(); // Initialize conditions array. - var conditions: [config.key.len]_sql.SqlParams = undefined; + var conditions: [config.key.len]_sql.RawQuery = undefined; inline for (config.key, &conditions) |keyName, *condition| { // Add a where condition for each key. condition.* = try updateQuery.newCondition().value(@TypeOf(@field(modelSql, keyName)), keyName, "=", @field(modelSql, keyName)); @@ -271,24 +270,3 @@ pub fn RepositoryResult(comptime Model: type) type { } }; } - -/// Build conditions for given composite keys, with a model key structure. -pub fn buildCompositeKeysConditions(comptime TableShape: type, comptime keys: []const []const u8, conditionsBuilder: _conditions.Builder, modelKey: anytype) ![keys.len]_sql.SqlParams { - // Conditions list for all keys in the composite key. - var conditions: [keys.len]_sql.SqlParams = undefined; - - inline for (keys, &conditions) |keyName, *condition| { - const keyType = std.meta.fields(TableShape)[std.meta.fieldIndex(TableShape, keyName).?].type; - - if (std.meta.fieldIndex(@TypeOf(modelKey), keyName)) |_| { - // The field exists in the key structure, create its condition. - condition.* = try conditionsBuilder.value(keyType, keyName, "=", @field(modelKey, keyName)); - } else { - // The field doesn't exist, compilation error. - @compileError("The key structure must include a field for " ++ keyName); - } - } - - // Return conditions for the current model key. - return conditions; -} diff --git a/src/root.zig b/src/root.zig index f26ccae..11fd096 100644 --- a/src/root.zig +++ b/src/root.zig @@ -12,8 +12,10 @@ pub const RepositoryResult = repository.RepositoryResult; pub const Insertable = insert.Insertable; -pub const QueryParameter = _sql.QueryParameter; -pub const SqlParams = _sql.SqlParams; +pub const relations = @import("relations.zig"); + +pub const RawQueryParameter = _sql.RawQueryParameter; +pub const RawQuery = _sql.RawQuery; pub const database = @import("database.zig"); pub const Session = session.Session; diff --git a/src/sql.zig b/src/sql.zig index 9a4d4f1..3d50bae 100644 --- a/src/sql.zig +++ b/src/sql.zig @@ -2,9 +2,67 @@ const std = @import("std"); const errors = @import("errors.zig"); /// A structure with SQL and its parameters. -pub const SqlParams = struct { +pub const RawQuery = struct { + const Self = @This(); + sql: []const u8, - params: []const QueryParameter, + params: []const RawQueryParameter, + + /// Build an SQL query with all the given query parts, separated by a space. + pub fn fromConcat(allocator: std.mem.Allocator, queries: []const RawQuery) !Self { + // Allocate an array with all SQL queries. + const queriesSql = try allocator.alloc([]const u8, queries.len); + defer allocator.free(queriesSql); + + // Allocate an array with all parameters arrays. + const queriesParams = try allocator.alloc([]const RawQueryParameter, queries.len); + defer allocator.free(queriesSql); + + // Fill SQL queries and parameters arrays. + for (queries, queriesSql, queriesParams) |_query, *_querySql, *_queryParam| { + _querySql.* = _query.sql; + _queryParam.* = _query.params; + } + + // Build final query with its parameters. + return Self{ + .sql = try std.mem.join(allocator, " ", queriesSql), + .params = try std.mem.concat(allocator, RawQueryParameter, queriesParams), + }; + } + + /// Build a full SQL query with numbered parameters. + pub fn build(self: Self, allocator: std.mem.Allocator) ![]u8 { + if (self.params.len <= 0) { + // No parameters, just copy SQL. + return allocator.dupe(u8, self.sql); + } else { + // Copy SQL and replace '?' by numbered parameters. + const sqlSize = self.sql.len + computeRequiredSpaceForNumbers(self.params.len); + var sqlBuf = try std.ArrayList(u8).initCapacity(allocator, sqlSize); + defer sqlBuf.deinit(); + + // Parameter counter. + var currentParameter: usize = 1; + + for (self.sql) |char| { + // Copy each character but '?', replaced by the current parameter string. + + if (char == '?') { + // Copy the parameter string in place of '?'. + try sqlBuf.writer().print("${d}", .{currentParameter}); + // Increment parameter count. + currentParameter += 1; + } else { + // Simply pass the current character. + try sqlBuf.append(char); + } + } + + // Return the built SQL query. + return sqlBuf.toOwnedSlice(); + } + } }; /// Generate parameters SQL in the form of "?,?,?,?" @@ -21,6 +79,143 @@ pub fn generateParametersSql(allocator: std.mem.Allocator, parametersCount: u64) return sql; } +/// Compute required string size of numbers for the given parameters count. +pub fn computeRequiredSpaceForNumbers(parametersCount: usize) usize { + var numbersSize: usize = 0; // Initialize the required size. + var remaining = parametersCount; // Initialize the remaining parameters to count. + var currentSliceSize: usize = 9; // Initialize the first slice size of numbers. + var i: usize = 1; // Initialize the current slice count. + + while (remaining > 0) { + // Compute the count of numbers in the current slice. + const numbersCount = @min(remaining, currentSliceSize); + // Add the required string size of all numbers in this slice. + numbersSize += i * numbersCount; + // Subtract the counted numbers in this current slice. + remaining -= numbersCount; + // Move to the next slice. + i += 1; + currentSliceSize *= 10; + } + + // Return the computed numbers size. + return numbersSize; +} + +/// Compute required string size for the given parameter number. +pub fn computeRequiredSpaceForParameter(parameterNumber: usize) !usize { + var i: usize = 1; + while (parameterNumber >= try std.math.powi(usize, 10, i)) { + i += 1; + } + return i; +} + +/// A query parameter. +pub const RawQueryParameter = union(enum) { + string: []const u8, + integer: i64, + number: f64, + bool: bool, + null: void, + + /// Convert any value to a query parameter. + pub fn fromValue(value: anytype) errors.ZrmError!RawQueryParameter { + // Get given value type. + const valueType = @typeInfo(@TypeOf(value)); + + return switch (valueType) { + .Int, .ComptimeInt => return .{ .integer = @intCast(value), }, + .Float, .ComptimeFloat => return .{ .number = @floatCast(value), }, + .Bool => return .{ .bool = value, }, + .Null => return .{ .null = true, }, + .Pointer => |pointer| { + if (pointer.size == .One) { + // Get pointed value. + return RawQueryParameter.fromValue(value.*); + } else { + // Can only take an array of u8 (= string). + if (pointer.child == u8) { + return .{ .string = value }; + } else { + return errors.ZrmError.UnsupportedTableType; + } + } + }, + .Enum, .EnumLiteral => { + return .{ .string = @tagName(value) }; + }, + .Optional => { + if (value) |val| { + // The optional value is defined, use it as a query parameter. + return RawQueryParameter.fromValue(val); + } else { + // If an optional value is not defined, set it to NULL. + return .{ .null = true }; + } + }, + else => return errors.ZrmError.UnsupportedTableType + }; + } +}; + + +/// SELECT query part builder for a given table. +pub fn SelectBuilder(comptime TableShape: type) type { + // Get fields count in the table shape. + const columnsCount = @typeInfo(TableShape).Struct.fields.len; + + // Sum of lengths of all selected columns formats. + var _selectColumnsLength = 0; + + const selectColumns = comptime select: { + // Initialize the select columns array. + var _select: [columnsCount][]const u8 = undefined; + + // For each field, generate a format string. + for (@typeInfo(TableShape).Struct.fields, &_select) |field, *columnSelect| { + // Select the current field column. + columnSelect.* = "\"{s}\".\"" ++ field.name ++ "\" AS \"{s}" ++ field.name ++ "\""; + _selectColumnsLength = _selectColumnsLength + columnSelect.len; + } + + break :select _select; + }; + + // Export computed select columns length. + const selectColumnsLength = _selectColumnsLength; + + return struct { + /// Build a SELECT query part for a given table, renaming columns with the given prefix. + pub fn build(allocator: std.mem.Allocator, table: []const u8, prefix: []const u8) ![]const u8 { + // Initialize full select string with precomputed size. + var fullSelect = try std.ArrayList(u8).initCapacity(allocator, + selectColumnsLength // static SQL size. + + columnsCount*(table.len - 2 + prefix.len - 2) // replacing %s and %s by table and prefix. + + (columnsCount - 1) * 2 // ", " + ); + defer fullSelect.deinit(); + + var first = true; + inline for (selectColumns) |columnSelect| { + // Add ", " between all selected columns. + if (first) { + first = false; + } else { + try fullSelect.appendSlice(", "); + } + + try fullSelect.writer().print(columnSelect, .{table, prefix}); + } + + return fullSelect.toOwnedSlice(); // Return built full select. + } + }; +} + + + + /// Compute required string size of numbers for the given parameters count, with taking in account the already used parameters numbers. pub fn computeRequiredSpaceForParametersNumbers(parametersCount: usize, alreadyUsedParameters: usize) usize { var remainingUsedParameters = alreadyUsedParameters; // Initialize the count of used parameters to mark as taken. @@ -55,149 +250,25 @@ pub fn computeRequiredSpaceForParametersNumbers(parametersCount: usize, alreadyU return numbersSize; } -/// Compute required string size of numbers for the given parameters count. -pub fn computeRequiredSpaceForNumbers(parametersCount: usize) usize { - var numbersSize: usize = 0; // Initialize the required size. - var remaining = parametersCount; // Initialize the remaining parameters to count. - var currentSliceSize: usize = 9; // Initialize the first slice size of numbers. - var i: usize = 1; // Initialize the current slice count. - - while (remaining > 0) { - // Compute the count of numbers in the current slice. - const numbersCount = @min(remaining, currentSliceSize); - // Add the required string size of all numbers in this slice. - numbersSize += i * numbersCount; - // Subtract the counted numbers in this current slice. - remaining -= numbersCount; - // Move to the next slice. - i += 1; - currentSliceSize *= 10; - } - - // Return the computed numbers size. - return numbersSize; -} - -/// Compute required string size for the given parameter number. -pub fn computeRequiredSpaceForParameter(parameterNumber: usize) !usize { - var i: usize = 1; - while (parameterNumber >= try std.math.powi(usize, 10, i)) { - i += 1; - } - return i; -} - -pub fn copyAndReplaceSqlParameters(currentParameter: *usize, parametersCount: usize, dest: []u8, source: []const u8) !void { +/// Copy the given source query and replace '?' parameters by numbered parameters. +pub fn copyAndReplaceSqlParameters(currentParameter: *usize, parametersCount: usize, writer: std.ArrayList(u8).Writer, source: []const u8) !void { // If there are no parameters, just copy source SQL. if (parametersCount <= 0) { - std.mem.copyForwards(u8, dest, source); + try writer.writeAll(source); + return; } - // Current dest cursor. - var destCursor: usize = 0; - for (source) |char| { // Copy each character but '?', replaced by the current parameter string. if (char == '?') { - // Create the parameter string. - const paramSize = 1 + try computeRequiredSpaceForParameter(currentParameter.*); // Copy the parameter string in place of '?'. - _ = try std.fmt.bufPrint(dest[destCursor..destCursor+paramSize], "${d}", .{currentParameter.*}); - // Add parameter string length to the current query cursor. - destCursor += paramSize; + try writer.print("${d}", .{currentParameter.*}); // Increment parameter count. currentParameter.* += 1; } else { - // Simply pass the current character. - dest[destCursor] = char; - destCursor += 1; + // Simply write the current character. + try writer.writeByte(char); } } } - -pub fn numberSqlParameters(sql: []const u8, comptime parametersCount: usize) [sql.len + computeRequiredSpaceForNumbers(parametersCount)]u8 { - // If there are no parameters, just return built SQL. - if (parametersCount <= 0) { - return @as([sql.len]u8, sql[0..sql.len].*); - } - - // New query buffer. - var query: [sql.len + computeRequiredSpaceForNumbers(parametersCount)]u8 = undefined; - - // Current query cursor. - var queryCursor: usize = 0; - // Current parameter count. - var currentParameter: usize = 1; - - for (sql) |char| { - // Copy each character but '?', replaced by the current parameter string. - - if (char == '?') { - var buffer: [computeRequiredSpaceForParameter(currentParameter)]u8 = undefined; - // Create the parameter string. - const paramStr = try std.fmt.bufPrint(&buffer, "${d}", .{currentParameter}); - // Copy the parameter string in place of '?'. - @memcpy(query[queryCursor..(queryCursor + paramStr.len)], paramStr); - // Add parameter string length to the current query cursor. - queryCursor += paramStr.len; - // Increment parameter count. - currentParameter += 1; - } else { - // Simply pass the current character. - query[queryCursor] = char; - queryCursor += 1; - } - } - - // Return built query. - return query; -} - -/// A query parameter. -pub const QueryParameter = union(enum) { - string: []const u8, - integer: i64, - number: f64, - bool: bool, - null: void, - - /// Convert any value to a query parameter. - pub fn fromValue(value: anytype) errors.ZrmError!QueryParameter { - // Get given value type. - const valueType = @typeInfo(@TypeOf(value)); - - return switch (valueType) { - .Int, .ComptimeInt => return .{ .integer = @intCast(value), }, - .Float, .ComptimeFloat => return .{ .number = @floatCast(value), }, - .Bool => return .{ .bool = value, }, - .Null => return .{ .null = true, }, - .Pointer => |pointer| { - if (pointer.size == .One) { - // Get pointed value. - return QueryParameter.fromValue(value.*); - } else { - // Can only take an array of u8 (= string). - if (pointer.child == u8) { - return .{ .string = value }; - } else { - return errors.ZrmError.UnsupportedTableType; - } - } - }, - .Enum, .EnumLiteral => { - return .{ .string = @tagName(value) }; - }, - .Optional => { - if (value) |val| { - // The optional value is defined, use it as a query parameter. - return QueryParameter.fromValue(val); - } else { - // If an optional value is not defined, set it to NULL. - return .{ .null = true }; - } - }, - else => return errors.ZrmError.UnsupportedTableType - }; - } -}; diff --git a/src/update.zig b/src/update.zig index a24a2c9..920df26 100644 --- a/src/update.zig +++ b/src/update.zig @@ -12,8 +12,8 @@ const repository = @import("repository.zig"); pub fn RepositoryUpdateConfiguration(comptime UpdateShape: type) type { return struct { value: ?UpdateShape = null, - where: ?_sql.SqlParams = null, - returning: ?_sql.SqlParams = null, + where: ?_sql.RawQuery = null, + returning: ?_sql.RawQuery = null, }; } @@ -113,7 +113,7 @@ pub fn RepositoryUpdate(comptime Model: type, comptime TableShape: type, comptim } /// Set WHERE conditions. - pub fn where(self: *Self, _where: _sql.SqlParams) void { + pub fn where(self: *Self, _where: _sql.RawQuery) void { self.updateConfig.where = _where; } @@ -144,7 +144,7 @@ pub fn RepositoryUpdate(comptime Model: type, comptime TableShape: type, comptim } /// Set selected columns for RETURNING clause. - pub fn returning(self: *Self, _select: _sql.SqlParams) void { + pub fn returning(self: *Self, _select: _sql.RawQuery) void { self.updateConfig.returning = _select; } @@ -157,7 +157,7 @@ pub fn RepositoryUpdate(comptime Model: type, comptime TableShape: type, comptim self.returning(.{ // Join selected columns. .sql = std.mem.join(self.arena.allocator(), ", ", _select), - .params = &[_]_sql.QueryParameter{}, // No parameters. + .params = &[_]_sql.RawQueryParameter{}, // No parameters. }); } @@ -165,7 +165,7 @@ pub fn RepositoryUpdate(comptime Model: type, comptime TableShape: type, comptim pub fn returningAll(self: *Self) void { self.returning(.{ .sql = "*", - .params = &[_]_sql.QueryParameter{}, // No parameters. + .params = &[_]_sql.RawQueryParameter{}, // No parameters. }); } @@ -203,61 +203,53 @@ pub fn RepositoryUpdate(comptime Model: type, comptime TableShape: type, comptim } // Allocate SQL buffer from computed size. - const sqlBuf = try self.arena.allocator().alloc(u8, fixedSqlSize + var sqlBuf = try std.ArrayList(u8).initCapacity(self.arena.allocator(), fixedSqlSize + (setSize) + (whereSize) + (returningSize) ); - - // Fill SQL buffer. + defer sqlBuf.deinit(); // Restart parameter counter at 1. currentParameter = 1; // SQL query initialisation. - @memcpy(sqlBuf[0..sqlBase.len], sqlBase); - var sqlBufCursor: usize = sqlBase.len; + try sqlBuf.appendSlice(sqlBase); // Add SET columns values. inline for (columns) |column| { // Create the SET string and append it to the SQL buffer. - const setColumnSize = column.len + 1 + 1 + try _sql.computeRequiredSpaceForParameter(currentParameter) + 1; - _ = try std.fmt.bufPrint(sqlBuf[sqlBufCursor..sqlBufCursor+setColumnSize], "{s}=${d},", .{column, currentParameter}); - sqlBufCursor += setColumnSize; + try sqlBuf.writer().print("{s}=${d},", .{column, currentParameter}); // Increment parameter count. currentParameter += 1; } // Overwrite the last ','. - sqlBufCursor -= 1; + _ = sqlBuf.pop(); // WHERE clause. if (self.updateConfig.where) |_where| { - @memcpy(sqlBuf[sqlBufCursor..sqlBufCursor+(1 + whereClause.len + 1)], " " ++ whereClause ++ " "); + try sqlBuf.appendSlice(" " ++ whereClause ++ " "); // Copy WHERE clause content and replace parameters, if there are some. try _sql.copyAndReplaceSqlParameters(¤tParameter, - _where.params.len, - sqlBuf[sqlBufCursor+(1+whereClause.len+1)..sqlBufCursor+whereSize], _where.sql + _where.params.len, sqlBuf.writer(), _where.sql ); - sqlBufCursor += whereSize; } // Append RETURNING clause, if there is one defined. if (self.updateConfig.returning) |_returning| { - @memcpy(sqlBuf[sqlBufCursor..sqlBufCursor+(1 + returningClause.len + 1)], " " ++ returningClause ++ " "); + try sqlBuf.appendSlice(" " ++ returningClause ++ " "); // Copy RETURNING clause content and replace parameters, if there are some. try _sql.copyAndReplaceSqlParameters(¤tParameter, - _returning.params.len, - sqlBuf[sqlBufCursor+(1+returningClause.len+1)..sqlBufCursor+returningSize], _returning.sql + _returning.params.len, sqlBuf.writer(), _returning.sql ); - sqlBufCursor += returningSize; } // ";" to end the query. - sqlBuf[sqlBufCursor] = ';'; sqlBufCursor += 1; + try sqlBuf.append(';'); // Save built SQL query. - self.sql = sqlBuf; + self.sql = try sqlBuf.toOwnedSlice(); } /// Execute the update query. diff --git a/tests/initdb.sql b/tests/initdb.sql index 844873d..b039b77 100644 --- a/tests/initdb.sql +++ b/tests/initdb.sql @@ -2,18 +2,29 @@ DROP SCHEMA public CASCADE; CREATE SCHEMA public; --- Create default models table. +-- Create models table. CREATE TABLE models ( id SERIAL PRIMARY KEY, name VARCHAR NOT NULL, amount NUMERIC(12, 2) NOT NULL ); +-- Create submodels table. +CREATE TABLE submodels ( + uuid UUID PRIMARY KEY, + label VARCHAR NOT NULL, + parent_id INT NULL, + FOREIGN KEY (parent_id) REFERENCES models ON DELETE RESTRICT ON UPDATE CASCADE +); +CREATE INDEX submodels_parent_id_index ON submodels(parent_id); + -- Insert default data. INSERT INTO models(name, amount) VALUES ('test', 50); INSERT INTO models(name, amount) VALUES ('updatable', 33.12); +INSERT INTO submodels(uuid, label, parent_id) VALUES ('f6868a5b-2efc-455f-b76e-872df514404f', 'test', 1); +INSERT INTO submodels(uuid, label, parent_id) VALUES ('013ef171-9781-40e9-b843-f6bc11890070', 'another', 1); --- Create default composite models table. +-- Create composite models table. CREATE TABLE composite_models ( firstcol SERIAL NOT NULL, secondcol VARCHAR NOT NULL, diff --git a/tests/query.zig b/tests/query.zig index f7e9ef3..5e27e3d 100644 --- a/tests/query.zig +++ b/tests/query.zig @@ -39,9 +39,9 @@ test "zrm.conditions combined" { var arena = std.heap.ArenaAllocator.init(std.testing.allocator); defer arena.deinit(); - const condition = try zrm.conditions.@"and"(arena.allocator(), &[_]zrm.SqlParams{ + const condition = try zrm.conditions.@"and"(arena.allocator(), &[_]zrm.RawQuery{ try zrm.conditions.value(usize, arena.allocator(), "test", "=", 5), - try zrm.conditions.@"or"(arena.allocator(), &[_]zrm.SqlParams{ + try zrm.conditions.@"or"(arena.allocator(), &[_]zrm.RawQuery{ try zrm.conditions.in(usize, arena.allocator(), "intest", &[_]usize{2, 3, 8}), try zrm.conditions.column(arena.allocator(), "firstcol", "<>", "secondcol"), }), diff --git a/tests/relations.zig b/tests/relations.zig new file mode 100644 index 0000000..5d40707 --- /dev/null +++ b/tests/relations.zig @@ -0,0 +1,54 @@ +const std = @import("std"); +const pg = @import("pg"); +const zrm = @import("zrm"); +const repository = @import("repository.zig"); + +/// PostgreSQL database connection. +var database: *pg.Pool = undefined; + +/// Initialize database connection. +fn initDatabase(allocator: std.mem.Allocator) !void { + database = try pg.Pool.init(allocator, .{ + .connect = .{ + .host = "localhost", + .port = 5432, + }, + .auth = .{ + .username = "zrm", + .password = "zrm", + .database = "zrm", + }, + .size = 1, + }); +} + +test "belongsTo" { + zrm.setDebug(true); + + try initDatabase(std.testing.allocator); + defer database.deinit(); + var poolConnector = zrm.database.PoolConnector{ + .pool = database, + }; + + // Build a query of submodels. + var myQuery = repository.MySubmodelRepository.Query.init(std.testing.allocator, poolConnector.connector(), .{}); + defer myQuery.deinit(); + // Retrieve parents of submodels from relation. + try myQuery.with(repository.MySubmodelRelations.parent); + + try myQuery.buildSql(); + + // Get query result. + var result = try myQuery.get(std.testing.allocator); + defer result.deinit(); + + // Checking result. + try std.testing.expectEqual(2, result.models.len); + try std.testing.expectEqual(1, result.models[0].parent_id); + try std.testing.expectEqual(1, result.models[1].parent_id); + try std.testing.expectEqual(repository.MyModel, @TypeOf(result.models[0].parent.?)); + try std.testing.expectEqual(repository.MyModel, @TypeOf(result.models[1].parent.?)); + try std.testing.expectEqual(1, result.models[0].parent.?.id); + try std.testing.expectEqual(1, result.models[1].parent.?.id); +} diff --git a/tests/repository.zig b/tests/repository.zig index 2e90be9..4735d95 100644 --- a/tests/repository.zig +++ b/tests/repository.zig @@ -26,9 +26,56 @@ const MySubmodel = struct { uuid: []const u8, label: []const u8, + parent_id: ?i32 = null, parent: ?MyModel = null, }; +const MySubmodelTable = struct { + uuid: []const u8, + label: []const u8, + + parent_id: ?i32 = null, +}; + +// Convert an SQL row to a model. +fn submodelFromSql(raw: MySubmodelTable) !MySubmodel { + return .{ + .uuid = raw.uuid, + .label = raw.label, + .parent_id = raw.parent_id, + }; +} + +/// Convert a model to an SQL row. +fn submodelToSql(model: MySubmodel) !MySubmodelTable { + return .{ + .uuid = model.uuid, + .label = model.label, + .parent_id = model.parent_id, + }; +} + +/// Declare a model repository. +pub const MySubmodelRepository = zrm.Repository(MySubmodel, MySubmodelTable, .{ + .table = "submodels", + + // Insert shape used by default for inserts in the repository. + .insertShape = MySubmodelTable, + + .key = &[_][]const u8{"uuid"}, + + .fromSql = &submodelFromSql, + .toSql = &submodelToSql, +}); + +pub const MySubmodelRelations = MySubmodelRepository.relations.define(.{ + .parent = MySubmodelRepository.relations.one(MyModelRepository, .{ + .direct = .{ + .foreignKey = "parent_id", + }, + }), +}); + /// An example model. pub const MyModel = struct { id: i32, @@ -79,6 +126,14 @@ pub const MyModelRepository = zrm.Repository(MyModel, MyModelTable, .{ .toSql = &modelToSql, }); +pub const MyModelRelations = MyModelRepository.relations.define(.{ + .submodels = MyModelRepository.relations.many(MySubmodelRepository, .{ + .direct = .{ + .foreignKey = "parent_id", + } + }), +}); + test "model structures" { // Initialize a test model. @@ -115,7 +170,7 @@ test "repository query SQL builder" { try query.whereIn(usize, "id", &[_]usize{1, 2}); try query.buildSql(); - const expectedSql = "SELECT * FROM models WHERE id IN ($1,$2);"; + const expectedSql = "SELECT \"models\".* FROM \"models\" WHERE id IN ($1,$2);"; try std.testing.expectEqual(expectedSql.len, query.sql.?.len); try std.testing.expectEqualStrings(expectedSql, query.sql.?); } @@ -138,7 +193,7 @@ test "repository element retrieval" { try query.buildSql(); // Check built SQL. - const expectedSql = "SELECT * FROM models WHERE id = $1;"; + const expectedSql = "SELECT \"models\".* FROM \"models\" WHERE id = $1;"; try std.testing.expectEqual(expectedSql.len, query.sql.?.len); try std.testing.expectEqualStrings(expectedSql, query.sql.?); @@ -166,11 +221,11 @@ test "repository complex SQL query" { var query = MyModelRepository.Query.init(std.testing.allocator, poolConnector.connector(), .{}); defer query.deinit(); query.where( - try query.newCondition().@"or"(&[_]zrm.SqlParams{ + try query.newCondition().@"or"(&[_]zrm.RawQuery{ try query.newCondition().value(usize, "id", "=", 1), - try query.newCondition().@"and"(&[_]zrm.SqlParams{ + try query.newCondition().@"and"(&[_]zrm.RawQuery{ try query.newCondition().in(usize, "id", &[_]usize{100000, 200000, 300000}), - try query.newCondition().@"or"(&[_]zrm.SqlParams{ + try query.newCondition().@"or"(&[_]zrm.RawQuery{ try query.newCondition().value(f64, "amount", ">", 12.13), try query.newCondition().value([]const u8, "name", "=", "test"), }) @@ -179,7 +234,7 @@ test "repository complex SQL query" { ); try query.buildSql(); - const expectedSql = "SELECT * FROM models WHERE (id = $1 OR (id IN ($2,$3,$4) AND (amount > $5 OR name = $6)));"; + const expectedSql = "SELECT \"models\".* FROM \"models\" WHERE (id = $1 OR (id IN ($2,$3,$4) AND (amount > $5 OR name = $6)));"; try std.testing.expectEqual(expectedSql.len, query.sql.?.len); try std.testing.expectEqualStrings(expectedSql, query.sql.?); diff --git a/tests/root.zig b/tests/root.zig index 0fbd1c7..343be44 100644 --- a/tests/root.zig +++ b/tests/root.zig @@ -5,4 +5,5 @@ comptime { _ = @import("repository.zig"); _ = @import("composite.zig"); _ = @import("sessions.zig"); + _ = @import("relations.zig"); }