diff --git a/src/postgresql.zig b/src/postgresql.zig index afbde25..bd3a896 100644 --- a/src/postgresql.zig +++ b/src/postgresql.zig @@ -37,7 +37,12 @@ pub fn handlePostgresqlError(err: anyerror, connection: *database.Connection, st defer statement.deinit(); defer connection.release(); - if (connection.connection.err) |sqlErr| { + return handleRawPostgresqlError(err, connection.connection); +} + +/// PostgreSQL raw error handling by ZRM. +pub fn handleRawPostgresqlError(err: anyerror, connection: *pg.Conn) anyerror { + if (connection.err) |sqlErr| { if (global.debugMode) { // If debug mode is enabled, show the PostgreSQL error. std.debug.print("PostgreSQL error\n{s}: {s}\n", .{sqlErr.code, sqlErr.message}); diff --git a/src/root.zig b/src/root.zig index c673b7c..f26ccae 100644 --- a/src/root.zig +++ b/src/root.zig @@ -1,4 +1,5 @@ const global = @import("global.zig"); +const session = @import("session.zig"); const repository = @import("repository.zig"); const insert = @import("insert.zig"); const _sql = @import("sql.zig"); @@ -15,6 +16,7 @@ pub const QueryParameter = _sql.QueryParameter; pub const SqlParams = _sql.SqlParams; pub const database = @import("database.zig"); +pub const Session = session.Session; pub const conditions = @import("conditions.zig"); diff --git a/src/session.zig b/src/session.zig new file mode 100644 index 0000000..043cd97 --- /dev/null +++ b/src/session.zig @@ -0,0 +1,122 @@ +const std = @import("std"); +const pg = @import("pg"); +const postgresql = @import("postgresql.zig"); +const database = @import("database.zig"); + +/// Session for multiple repository operations. +pub const Session = struct { + const Self = @This(); + + _database: *pg.Pool, + + /// The active connection for the session. + connection: *pg.Conn, + + /// Execute a comptime-known SQL command for the current session. + fn exec(self: Self, comptime sql: []const u8) !void { + _ = self.connection.exec(sql, .{}) catch |err| { + return postgresql.handleRawPostgresqlError(err, self.connection); + }; + } + + /// Begin a new transaction. + pub fn beginTransaction(self: Self) !void { + try self.exec("BEGIN;"); + } + + /// Rollback the current transaction. + pub fn rollbackTransaction(self: Self) !void { + try self.exec("ROLLBACK;"); + } + + /// Commit the current transaction. + pub fn commitTransaction(self: Self) !void { + try self.exec("COMMIT;"); + } + + /// Create a new savepoint with the given name. + pub fn savepoint(self: Self, comptime _savepoint: []const u8) !void { + try self.exec("SAVEPOINT " ++ _savepoint ++ ";"); + } + + /// Rollback to the savepoint with the given name. + pub fn rollbackTo(self: Self, comptime _savepoint: []const u8) !void { + try self.exec("ROLLBACK TO " ++ _savepoint ++ ";"); + } + + /// Initialize a new session. + pub fn init(_database: *pg.Pool) !Session { + return .{ + ._database = _database, + .connection = try _database.acquire(), + }; + } + + /// Deinitialize the session. + pub fn deinit(self: *Self) void { + self.connection.release(); + } + + /// Get a database connector instance for the current session. + pub fn connector(self: *Self) database.Connector { + return database.Connector{ + ._interface = .{ + .instance = self, + .getConnection = getConnection, + }, + }; + } + + // Connector implementation. + + /// Get the current connection. + fn getConnection(opaqueSelf: *anyopaque) !*database.Connection { + const self: *Self = @ptrCast(@alignCast(opaqueSelf)); + + // Initialize a new connection. + const sessionConnection = try self._database._allocator.create(SessionConnection); + sessionConnection.* = .{ + .session = self, + }; + + return try sessionConnection.connection(); + } +}; + +fn noRelease(_: *anyopaque) void {} + +/// A session connection. +const SessionConnection = struct { + const Self = @This(); + + /// Session of the connection. + session: *Session, + + /// Connection instance, to only keep one at a time. + _connection: ?database.Connection = null, + + /// Get a database connection. + pub fn connection(self: *Self) !*database.Connection { + if (self._connection == null) { + // A new connection needs to be initialized. + self._connection = .{ + .connection = self.session.connection, + ._interface = .{ + .instance = self, + .release = releaseConnection, + }, + }; + } + + return &(self._connection.?); + } + + // Implementation. + + /// Free the current connection (doesn't actually release the connection, as it is required to stay the same all along the session). + fn releaseConnection(self: *database.Connection) void { + // Free allocated connection. + const sessionConnection: *SessionConnection = @ptrCast(@alignCast(self._interface.instance)); + sessionConnection.session._database._allocator.destroy(sessionConnection); + } +}; diff --git a/tests/repository.zig b/tests/repository.zig index 199649b..2e90be9 100644 --- a/tests/repository.zig +++ b/tests/repository.zig @@ -30,7 +30,7 @@ const MySubmodel = struct { }; /// An example model. -const MyModel = struct { +pub const MyModel = struct { id: i32, name: []const u8, amount: f64, @@ -39,7 +39,7 @@ const MyModel = struct { }; /// SQL table shape of the example model. -const MyModelTable = struct { +pub const MyModelTable = struct { id: i32, name: []const u8, amount: f64, @@ -64,7 +64,7 @@ fn modelToSql(model: MyModel) !MyModelTable { } /// Declare a model repository. -const MyModelRepository = zrm.Repository(MyModel, MyModelTable, .{ +pub const MyModelRepository = zrm.Repository(MyModel, MyModelTable, .{ .table = "models", // Insert shape used by default for inserts in the repository. diff --git a/tests/root.zig b/tests/root.zig index bed1d89..0fbd1c7 100644 --- a/tests/root.zig +++ b/tests/root.zig @@ -4,4 +4,5 @@ comptime { _ = @import("query.zig"); _ = @import("repository.zig"); _ = @import("composite.zig"); + _ = @import("sessions.zig"); } diff --git a/tests/sessions.zig b/tests/sessions.zig new file mode 100644 index 0000000..c561be8 --- /dev/null +++ b/tests/sessions.zig @@ -0,0 +1,134 @@ +const std = @import("std"); +const pg = @import("pg"); +const zrm = @import("zrm"); +const repository = @import("repository.zig"); + +/// PostgreSQL database connection. +var database: *pg.Pool = undefined; + +/// Initialize database connection. +fn initDatabase(allocator: std.mem.Allocator) !void { + database = try pg.Pool.init(allocator, .{ + .connect = .{ + .host = "localhost", + .port = 5432, + }, + .auth = .{ + .username = "zrm", + .password = "zrm", + .database = "zrm", + }, + .size = 1, + }); +} + +test "session with rolled back transaction and savepoint" { + zrm.setDebug(true); + + // Initialize database. + try initDatabase(std.testing.allocator); + defer database.deinit(); + + // Start a new session and perform operations in a transaction. + var session = try zrm.Session.init(database); + defer session.deinit(); + + + try session.beginTransaction(); + + + // First UPDATE in the transaction. + { + var firstUpdate = repository.MyModelRepository.Update(struct { + name: []const u8, + }).init(std.testing.allocator, session.connector()); + defer firstUpdate.deinit(); + try firstUpdate.set(.{ + .name = "tempname", + }); + try firstUpdate.whereValue(usize, "id", "=", 1); + var firstUpdateResult = try firstUpdate.update(std.testing.allocator); + firstUpdateResult.deinit(); + } + + + // Set a savepoint. + try session.savepoint("my_savepoint"); + + + // Second UPDATE in the transaction. + { + var secondUpdate = repository.MyModelRepository.Update(struct { + amount: f64, + }).init(std.testing.allocator, session.connector()); + defer secondUpdate.deinit(); + try secondUpdate.set(.{ + .amount = 52.25, + }); + try secondUpdate.whereValue(usize, "id", "=", 1); + var secondUpdateResult = try secondUpdate.update(std.testing.allocator); + secondUpdateResult.deinit(); + } + + // SELECT before rollback to savepoint in the transaction. + { + var queryBeforeRollbackToSavepoint = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{}); + try queryBeforeRollbackToSavepoint.whereValue(usize, "id", "=", 1); + defer queryBeforeRollbackToSavepoint.deinit(); + + // Get models. + var resultBeforeRollbackToSavepoint = try queryBeforeRollbackToSavepoint.get(std.testing.allocator); + defer resultBeforeRollbackToSavepoint.deinit(); + + // Check that one model has been retrieved, then check its type and values. + try std.testing.expectEqual(1, resultBeforeRollbackToSavepoint.models.len); + try std.testing.expectEqual(repository.MyModel, @TypeOf(resultBeforeRollbackToSavepoint.models[0].*)); + try std.testing.expectEqual(1, resultBeforeRollbackToSavepoint.models[0].id); + try std.testing.expectEqualStrings("tempname", resultBeforeRollbackToSavepoint.models[0].name); + try std.testing.expectEqual(52.25, resultBeforeRollbackToSavepoint.models[0].amount); + } + + + try session.rollbackTo("my_savepoint"); + + + // SELECT after rollback to savepoint in the transaction. + { + var queryAfterRollbackToSavepoint = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{}); + try queryAfterRollbackToSavepoint.whereValue(usize, "id", "=", 1); + defer queryAfterRollbackToSavepoint.deinit(); + + // Get models. + var resultAfterRollbackToSavepoint = try queryAfterRollbackToSavepoint.get(std.testing.allocator); + defer resultAfterRollbackToSavepoint.deinit(); + + // Check that one model has been retrieved, then check its type and values. + try std.testing.expectEqual(1, resultAfterRollbackToSavepoint.models.len); + try std.testing.expectEqual(repository.MyModel, @TypeOf(resultAfterRollbackToSavepoint.models[0].*)); + try std.testing.expectEqual(1, resultAfterRollbackToSavepoint.models[0].id); + try std.testing.expectEqualStrings("tempname", resultAfterRollbackToSavepoint.models[0].name); + try std.testing.expectEqual(50.00, resultAfterRollbackToSavepoint.models[0].amount); + } + + + try session.rollbackTransaction(); + + + // SELECT outside of the rolled back transaction. + { + var queryOutside = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{}); + try queryOutside.whereValue(usize, "id", "=", 1); + defer queryOutside.deinit(); + + // Get models. + var resultOutside = try queryOutside.get(std.testing.allocator); + defer resultOutside.deinit(); + + // Check that one model has been retrieved, then check its type and values. + try std.testing.expectEqual(1, resultOutside.models.len); + try std.testing.expectEqual(repository.MyModel, @TypeOf(resultOutside.models[0].*)); + try std.testing.expectEqual(1, resultOutside.models[0].id); + try std.testing.expectEqualStrings("test", resultOutside.models[0].name); + try std.testing.expectEqual(50.00, resultOutside.models[0].amount); + } +}