Add sessions (as a specific database connector) and transactions.

Closes #2
This commit is contained in:
Madeorsk 2024-10-22 16:28:06 +02:00
parent fee00d5bc9
commit cc4c808937
Signed by: Madeorsk
GPG key ID: 677E51CA765BB79F
6 changed files with 268 additions and 4 deletions

View file

@ -37,7 +37,12 @@ pub fn handlePostgresqlError(err: anyerror, connection: *database.Connection, st
defer statement.deinit(); defer statement.deinit();
defer connection.release(); defer connection.release();
if (connection.connection.err) |sqlErr| { return handleRawPostgresqlError(err, connection.connection);
}
/// PostgreSQL raw error handling by ZRM.
pub fn handleRawPostgresqlError(err: anyerror, connection: *pg.Conn) anyerror {
if (connection.err) |sqlErr| {
if (global.debugMode) { if (global.debugMode) {
// If debug mode is enabled, show the PostgreSQL error. // If debug mode is enabled, show the PostgreSQL error.
std.debug.print("PostgreSQL error\n{s}: {s}\n", .{sqlErr.code, sqlErr.message}); std.debug.print("PostgreSQL error\n{s}: {s}\n", .{sqlErr.code, sqlErr.message});

View file

@ -1,4 +1,5 @@
const global = @import("global.zig"); const global = @import("global.zig");
const session = @import("session.zig");
const repository = @import("repository.zig"); const repository = @import("repository.zig");
const insert = @import("insert.zig"); const insert = @import("insert.zig");
const _sql = @import("sql.zig"); const _sql = @import("sql.zig");
@ -15,6 +16,7 @@ pub const QueryParameter = _sql.QueryParameter;
pub const SqlParams = _sql.SqlParams; pub const SqlParams = _sql.SqlParams;
pub const database = @import("database.zig"); pub const database = @import("database.zig");
pub const Session = session.Session;
pub const conditions = @import("conditions.zig"); pub const conditions = @import("conditions.zig");

122
src/session.zig Normal file
View file

@ -0,0 +1,122 @@
const std = @import("std");
const pg = @import("pg");
const postgresql = @import("postgresql.zig");
const database = @import("database.zig");
/// Session for multiple repository operations.
pub const Session = struct {
const Self = @This();
_database: *pg.Pool,
/// The active connection for the session.
connection: *pg.Conn,
/// Execute a comptime-known SQL command for the current session.
fn exec(self: Self, comptime sql: []const u8) !void {
_ = self.connection.exec(sql, .{}) catch |err| {
return postgresql.handleRawPostgresqlError(err, self.connection);
};
}
/// Begin a new transaction.
pub fn beginTransaction(self: Self) !void {
try self.exec("BEGIN;");
}
/// Rollback the current transaction.
pub fn rollbackTransaction(self: Self) !void {
try self.exec("ROLLBACK;");
}
/// Commit the current transaction.
pub fn commitTransaction(self: Self) !void {
try self.exec("COMMIT;");
}
/// Create a new savepoint with the given name.
pub fn savepoint(self: Self, comptime _savepoint: []const u8) !void {
try self.exec("SAVEPOINT " ++ _savepoint ++ ";");
}
/// Rollback to the savepoint with the given name.
pub fn rollbackTo(self: Self, comptime _savepoint: []const u8) !void {
try self.exec("ROLLBACK TO " ++ _savepoint ++ ";");
}
/// Initialize a new session.
pub fn init(_database: *pg.Pool) !Session {
return .{
._database = _database,
.connection = try _database.acquire(),
};
}
/// Deinitialize the session.
pub fn deinit(self: *Self) void {
self.connection.release();
}
/// Get a database connector instance for the current session.
pub fn connector(self: *Self) database.Connector {
return database.Connector{
._interface = .{
.instance = self,
.getConnection = getConnection,
},
};
}
// Connector implementation.
/// Get the current connection.
fn getConnection(opaqueSelf: *anyopaque) !*database.Connection {
const self: *Self = @ptrCast(@alignCast(opaqueSelf));
// Initialize a new connection.
const sessionConnection = try self._database._allocator.create(SessionConnection);
sessionConnection.* = .{
.session = self,
};
return try sessionConnection.connection();
}
};
fn noRelease(_: *anyopaque) void {}
/// A session connection.
const SessionConnection = struct {
const Self = @This();
/// Session of the connection.
session: *Session,
/// Connection instance, to only keep one at a time.
_connection: ?database.Connection = null,
/// Get a database connection.
pub fn connection(self: *Self) !*database.Connection {
if (self._connection == null) {
// A new connection needs to be initialized.
self._connection = .{
.connection = self.session.connection,
._interface = .{
.instance = self,
.release = releaseConnection,
},
};
}
return &(self._connection.?);
}
// Implementation.
/// Free the current connection (doesn't actually release the connection, as it is required to stay the same all along the session).
fn releaseConnection(self: *database.Connection) void {
// Free allocated connection.
const sessionConnection: *SessionConnection = @ptrCast(@alignCast(self._interface.instance));
sessionConnection.session._database._allocator.destroy(sessionConnection);
}
};

View file

@ -30,7 +30,7 @@ const MySubmodel = struct {
}; };
/// An example model. /// An example model.
const MyModel = struct { pub const MyModel = struct {
id: i32, id: i32,
name: []const u8, name: []const u8,
amount: f64, amount: f64,
@ -39,7 +39,7 @@ const MyModel = struct {
}; };
/// SQL table shape of the example model. /// SQL table shape of the example model.
const MyModelTable = struct { pub const MyModelTable = struct {
id: i32, id: i32,
name: []const u8, name: []const u8,
amount: f64, amount: f64,
@ -64,7 +64,7 @@ fn modelToSql(model: MyModel) !MyModelTable {
} }
/// Declare a model repository. /// Declare a model repository.
const MyModelRepository = zrm.Repository(MyModel, MyModelTable, .{ pub const MyModelRepository = zrm.Repository(MyModel, MyModelTable, .{
.table = "models", .table = "models",
// Insert shape used by default for inserts in the repository. // Insert shape used by default for inserts in the repository.

View file

@ -4,4 +4,5 @@ comptime {
_ = @import("query.zig"); _ = @import("query.zig");
_ = @import("repository.zig"); _ = @import("repository.zig");
_ = @import("composite.zig"); _ = @import("composite.zig");
_ = @import("sessions.zig");
} }

134
tests/sessions.zig Normal file
View file

@ -0,0 +1,134 @@
const std = @import("std");
const pg = @import("pg");
const zrm = @import("zrm");
const repository = @import("repository.zig");
/// PostgreSQL database connection.
var database: *pg.Pool = undefined;
/// Initialize database connection.
fn initDatabase(allocator: std.mem.Allocator) !void {
database = try pg.Pool.init(allocator, .{
.connect = .{
.host = "localhost",
.port = 5432,
},
.auth = .{
.username = "zrm",
.password = "zrm",
.database = "zrm",
},
.size = 1,
});
}
test "session with rolled back transaction and savepoint" {
zrm.setDebug(true);
// Initialize database.
try initDatabase(std.testing.allocator);
defer database.deinit();
// Start a new session and perform operations in a transaction.
var session = try zrm.Session.init(database);
defer session.deinit();
try session.beginTransaction();
// First UPDATE in the transaction.
{
var firstUpdate = repository.MyModelRepository.Update(struct {
name: []const u8,
}).init(std.testing.allocator, session.connector());
defer firstUpdate.deinit();
try firstUpdate.set(.{
.name = "tempname",
});
try firstUpdate.whereValue(usize, "id", "=", 1);
var firstUpdateResult = try firstUpdate.update(std.testing.allocator);
firstUpdateResult.deinit();
}
// Set a savepoint.
try session.savepoint("my_savepoint");
// Second UPDATE in the transaction.
{
var secondUpdate = repository.MyModelRepository.Update(struct {
amount: f64,
}).init(std.testing.allocator, session.connector());
defer secondUpdate.deinit();
try secondUpdate.set(.{
.amount = 52.25,
});
try secondUpdate.whereValue(usize, "id", "=", 1);
var secondUpdateResult = try secondUpdate.update(std.testing.allocator);
secondUpdateResult.deinit();
}
// SELECT before rollback to savepoint in the transaction.
{
var queryBeforeRollbackToSavepoint = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{});
try queryBeforeRollbackToSavepoint.whereValue(usize, "id", "=", 1);
defer queryBeforeRollbackToSavepoint.deinit();
// Get models.
var resultBeforeRollbackToSavepoint = try queryBeforeRollbackToSavepoint.get(std.testing.allocator);
defer resultBeforeRollbackToSavepoint.deinit();
// Check that one model has been retrieved, then check its type and values.
try std.testing.expectEqual(1, resultBeforeRollbackToSavepoint.models.len);
try std.testing.expectEqual(repository.MyModel, @TypeOf(resultBeforeRollbackToSavepoint.models[0].*));
try std.testing.expectEqual(1, resultBeforeRollbackToSavepoint.models[0].id);
try std.testing.expectEqualStrings("tempname", resultBeforeRollbackToSavepoint.models[0].name);
try std.testing.expectEqual(52.25, resultBeforeRollbackToSavepoint.models[0].amount);
}
try session.rollbackTo("my_savepoint");
// SELECT after rollback to savepoint in the transaction.
{
var queryAfterRollbackToSavepoint = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{});
try queryAfterRollbackToSavepoint.whereValue(usize, "id", "=", 1);
defer queryAfterRollbackToSavepoint.deinit();
// Get models.
var resultAfterRollbackToSavepoint = try queryAfterRollbackToSavepoint.get(std.testing.allocator);
defer resultAfterRollbackToSavepoint.deinit();
// Check that one model has been retrieved, then check its type and values.
try std.testing.expectEqual(1, resultAfterRollbackToSavepoint.models.len);
try std.testing.expectEqual(repository.MyModel, @TypeOf(resultAfterRollbackToSavepoint.models[0].*));
try std.testing.expectEqual(1, resultAfterRollbackToSavepoint.models[0].id);
try std.testing.expectEqualStrings("tempname", resultAfterRollbackToSavepoint.models[0].name);
try std.testing.expectEqual(50.00, resultAfterRollbackToSavepoint.models[0].amount);
}
try session.rollbackTransaction();
// SELECT outside of the rolled back transaction.
{
var queryOutside = repository.MyModelRepository.Query.init(std.testing.allocator, session.connector(), .{});
try queryOutside.whereValue(usize, "id", "=", 1);
defer queryOutside.deinit();
// Get models.
var resultOutside = try queryOutside.get(std.testing.allocator);
defer resultOutside.deinit();
// Check that one model has been retrieved, then check its type and values.
try std.testing.expectEqual(1, resultOutside.models.len);
try std.testing.expectEqual(repository.MyModel, @TypeOf(resultOutside.models[0].*));
try std.testing.expectEqual(1, resultOutside.models[0].id);
try std.testing.expectEqualStrings("test", resultOutside.models[0].name);
try std.testing.expectEqual(50.00, resultOutside.models[0].amount);
}
}