diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2362424..79a7aee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,4 +36,4 @@ jobs: sozo test - name: Check formatting run: | - scarb fmt --check + scarb fmt --check \ No newline at end of file diff --git a/src/model/game_model.cairo b/src/model/game_model.cairo index b079140..3074000 100644 --- a/src/model/game_model.cairo +++ b/src/model/game_model.cairo @@ -1,6 +1,26 @@ use starknet::{ContractAddress, contract_address_const}; // Keeps track of the state of the game +// New PlayerRating model to store player ratings +#[derive(Copy, Drop, Serde, Introspect)] +#[dojo::model] +pub struct PlayerRating { + #[key] + pub player: ContractAddress, + pub rating: u32 // Elo rating +} + +// Trait for PlayerRating operations +pub trait PlayerRatingTrait { + fn new(player: ContractAddress, rating: u32) -> PlayerRating; +} + +impl PlayerRatingImpl of PlayerRatingTrait { + fn new(player: ContractAddress, rating: u32) -> PlayerRating { + PlayerRating { player, rating } + } +} + #[derive(Serde, Copy, Drop, Introspect, PartialEq)] #[dojo::model] pub struct GameCounter { diff --git a/src/systems/Snooknet.cairo b/src/systems/Snooknet.cairo index d32080c..72e55fd 100644 --- a/src/systems/Snooknet.cairo +++ b/src/systems/Snooknet.cairo @@ -5,7 +5,7 @@ use dojo::event::EventStorage; use dojo_starter::interfaces::ISnooknet::ISnooknet; -use dojo_starter::model::game_model::{Game, GameTrait, GameState, GameCounter}; +use dojo_starter::model::game_model::{Game, GameTrait, GameState, GameCounter, PlayerRating}; use dojo_starter::model::tournament_model::{ Tournament as TournamentModel, TournamentTrait, TournamentStatus, TournamentReward, TournamentCounter, @@ -16,6 +16,12 @@ use dojo_starter::model::player_model::{Player, PlayerTrait}; // dojo decorator #[dojo::contract] pub mod Snooknet { + use super::{ISnooknet, Game, GameTrait, GameCounter, GameState, PlayerRating}; + use starknet::{ + ContractAddress, get_caller_address, get_block_timestamp, contract_address_const, + }; + use dojo::model::{ModelStorage}; + use dojo::event::EventStorage; use super::*; #[derive(Copy, Drop, Serde)] @@ -26,6 +32,15 @@ pub mod Snooknet { pub timestamp: u64, } + #[derive(Copy, Drop, Serde)] + #[dojo::event] + pub struct RatingUpdated { + #[key] + pub player: ContractAddress, + pub new_rating: u32, + } + + #[derive(Copy, Drop, Serde)] #[dojo::event] pub struct GameCreated { @@ -110,6 +125,10 @@ pub mod Snooknet { let player_1 = get_caller_address(); let player_2 = opponent; + // Initialize ratings for players if they don't exist + self.ensure_player_rating(player_1); + self.ensure_player_rating(player_2); + // Create a new game let mut new_game: Game = GameTrait::new( game_id, @@ -137,11 +156,49 @@ pub mod Snooknet { let caller = get_caller_address(); let timestamp = get_block_timestamp(); + assert((caller == game.player1) || (caller == game.player2), 'Not a Player'); + + assert( + (winner == game.player1) + || (winner == game.player2) + || (winner == contract_address_const::<0x0>()), + 'Invalid winner', + ); + + // Ensure game is not already finished + assert(game.state != GameState::Finished, 'Game already ended'); + game.winner = winner; game.state = GameState::Finished; game.updated_at = get_block_timestamp(); + // Update player ratings using Elo algorithm + if winner != contract_address_const::<0x0>() { + // Not a draw + let (new_rating1, new_rating2) = self + .elo_function(game.player1, game.player2, winner); + let mut rating1: PlayerRating = world.read_model(game.player1); + let mut rating2: PlayerRating = world.read_model(game.player2); + rating1.rating = new_rating1; + rating2.rating = new_rating2; + world.write_model(@rating1); + world.write_model(@rating2); + world.emit_event(@RatingUpdated { player: game.player1, new_rating: new_rating1 }); + world.emit_event(@RatingUpdated { player: game.player2, new_rating: new_rating2 }); + } else { + // Draw: both players get 0.5 score + let (new_rating1, new_rating2) = self.elo_function_draw(game.player1, game.player2); + let mut rating1: PlayerRating = world.read_model(game.player1); + let mut rating2: PlayerRating = world.read_model(game.player2); + rating1.rating = new_rating1; + rating2.rating = new_rating2; + world.write_model(@rating1); + world.write_model(@rating2); + world.emit_event(@RatingUpdated { player: game.player1, new_rating: new_rating1 }); + world.emit_event(@RatingUpdated { player: game.player2, new_rating: new_rating2 }); + } + world.write_model(@game); world.emit_event(@Winner { game_id, winner }); world.emit_event(@GameEnded { game_id, timestamp }); @@ -242,6 +299,145 @@ pub mod Snooknet { self.world(@"Snooknet") } + fn ensure_player_rating(ref self: ContractState, player: ContractAddress) { + let mut world = self.world_default(); + let mut rating: PlayerRating = world.read_model(player); + if rating.rating == 0 { + rating.rating = 1500; + world.write_model(@rating); + } + } + + // Elo function for win/loss + fn elo_function( + ref self: ContractState, + player1: ContractAddress, + player2: ContractAddress, + winner: ContractAddress, + ) -> (u32, u32) { + let mut world = self.world_default(); + let rating1: PlayerRating = world.read_model(player1); + let rating2: PlayerRating = world.read_model(player2); + let r1 = rating1.rating; // u32 + let r2 = rating2.rating; // u32 + let k = 32_u32; // Elo K-factor + + // Calculate expected scores (scaled 0 to 1000) + let expected1 = self.calculate_expected(r1, r2); // u32 + let expected2 = self.calculate_expected(r2, r1); // u32 + + // Assign scores based on winner + let (score1, score2) = if winner == player1 { + (1000_u32, 0_u32) + } else { + (0_u32, 1000_u32) + }; + + // Calculate rating changes safely + let new_rating1 = if score1 >= expected1 { + // Positive or zero change (e.g., winner) + let delta = (score1 - expected1) * k / 1000; + r1 + delta + } else { + // Negative change (e.g., loser) + let delta = self.safe_subtract(expected1, score1) * k / 1000; + let min_rating = 100_u32; // Minimum rating to prevent too-low values + if r1 <= delta { + min_rating + } else { + r1 - delta + } + }; + + let new_rating2 = if score2 >= expected2 { + // Positive or zero change + let delta = (score2 - expected2) * k / 1000; + r2 + delta + } else { + // Negative change + let delta = self.safe_subtract(expected2, score2) * k / 1000; + let min_rating = 100_u32; + if r2 <= delta { + min_rating + } else { + r2 - delta + } + }; + + (new_rating1, new_rating2) + } + + // Elo function for draw + fn elo_function_draw( + ref self: ContractState, player1: ContractAddress, player2: ContractAddress, + ) -> (u32, u32) { + let mut world = self.world_default(); + let rating1: PlayerRating = world.read_model(player1); + let rating2: PlayerRating = world.read_model(player2); + let r1 = rating1.rating; + let r2 = rating2.rating; + let k = 32_u32; + + let expected1 = self.calculate_expected(r1, r2); + let expected2 = self.calculate_expected(r2, r1); + let score = 500_u32; // Draw score (scaled) + + let new_rating1 = if score >= expected1 { + let delta = (score - expected1) * k / 1000; + r1 + delta + } else { + let delta = self.safe_subtract(expected1, score) * k / 1000; + let min_rating = 100_u32; + if r1 <= delta { + min_rating + } else { + r1 - delta + } + }; + + let new_rating2 = if score >= expected2 { + let delta = (score - expected2) * k / 1000; + r2 + delta + } else { + let delta = self.safe_subtract(expected2, score) * k / 1000; + let min_rating = 100_u32; + if r2 <= delta { + min_rating + } else { + r2 - delta + } + }; + + (new_rating1, new_rating2) + } + + // Helper function to calculate expected score + fn calculate_expected(self: @ContractState, r1: u32, r2: u32) -> u32 { + if r1 > r2 { + let diff = r1 - r2; + if diff > 400 { + 1000_u32 + } else { + 500_u32 + (diff * 5) / 4 + } + } else { + let diff = r2 - r1; + if diff > 400 { + 0_u32 + } else { + 500_u32 - (diff * 5) / 4 + } + } + } + + fn safe_subtract(self: @ContractState, a: u32, b: u32) -> u32 { + if a >= b { + a - b + } else { + 0_u32 + } + } + fn create_new_tournament_id(ref self: ContractState) -> u256 { let mut world = self.world_default(); let mut tournament_counter: TournamentCounter = world.read_model('v0'); @@ -253,3 +449,4 @@ pub mod Snooknet { } } + diff --git a/src/tests/test_world.cairo b/src/tests/test_world.cairo index 74db6f7..03625bf 100644 --- a/src/tests/test_world.cairo +++ b/src/tests/test_world.cairo @@ -14,7 +14,10 @@ mod tests { use dojo_starter::interfaces::ISnooknet::{ISnooknetDispatcher, ISnooknetDispatcherTrait}; use dojo_starter::model::tournament_model::{TournamentStatus, TournamentReward, Tournament}; - use dojo_starter::model::game_model::{m_Game, GameState, m_GameCounter}; + use dojo_starter::model::game_model::{ + Game, m_Game, GameState, GameCounter, m_GameCounter, PlayerRating, m_PlayerRating, + }; + use starknet::{testing, get_caller_address, contract_address_const}; use dojo_starter::model::tournament_model::{m_Tournament, m_TournamentCounter}; use dojo_starter::model::player_model::{m_Player}; @@ -24,6 +27,8 @@ mod tests { resources: [ TestResource::Model(m_GameCounter::TEST_CLASS_HASH), TestResource::Model(m_Game::TEST_CLASS_HASH), + TestResource::Model(m_PlayerRating::TEST_CLASS_HASH), + TestResource::Event(Snooknet::e_RatingUpdated::TEST_CLASS_HASH), TestResource::Event(Snooknet::e_GameCreated::TEST_CLASS_HASH), TestResource::Event(Snooknet::e_Winner::TEST_CLASS_HASH), TestResource::Event(Snooknet::e_GameEnded::TEST_CLASS_HASH), @@ -145,6 +150,13 @@ mod tests { assert(game.state == GameState::Finished, 'Game not ended'); } + // New test: Verify player rating initialization + #[test] + fn test_player_rating_initialization() { + let caller_1 = contract_address_const::<'aji'>(); + let player_1 = contract_address_const::<'player'>(); + } + #[test] fn test_create_tournament() { let ndef = namespace_def(); @@ -154,6 +166,21 @@ mod tests { let (contract_address, _) = world.dns(@"Snooknet").unwrap(); let actions_system = ISnooknetDispatcher { contract_address }; + testing::set_contract_address(caller_1); + let game_id = actions_system.create_match(player_1, 400); + + let rating1: PlayerRating = world.read_model(caller_1); + let rating2: PlayerRating = world.read_model(player_1); + assert(rating1.rating == 1500, 'Player 1 rating not initialized'); + assert(rating2.rating == 1500, 'Player 2 rating not initialized'); + } + + // New test: Verify rating update after a win + #[test] + fn test_rating_update_after_win() { + let caller_1 = contract_address_const::<'aji'>(); + let player_1 = contract_address_const::<'player'>(); + let organizer = contract_address_const::<'organizer'>(); set_contract_address(organizer); @@ -189,6 +216,24 @@ mod tests { let (contract_address, _) = world.dns(@"Snooknet").unwrap(); let actions_system = ISnooknetDispatcher { contract_address }; + testing::set_contract_address(caller_1); + let game_id = actions_system.create_match(player_1, 400); + + testing::set_contract_address(caller_1); + actions_system.end_match(game_id, caller_1); + + let rating1: PlayerRating = world.read_model(caller_1); + let rating2: PlayerRating = world.read_model(player_1); + assert(rating1.rating > 1500, 'Winner rating not increased'); + assert(rating2.rating < 1500, 'Loser rating not decreased'); + } + + // New test: Verify rating update after a draw + #[test] + fn test_rating_update_after_draw() { + let caller_1 = contract_address_const::<'aji'>(); + let player_1 = contract_address_const::<'player'>(); + let organizer = contract_address_const::<'organizer'>(); set_contract_address(organizer); @@ -226,6 +271,63 @@ mod tests { let (contract_address, _) = world.dns(@"Snooknet").unwrap(); let actions_system = ISnooknetDispatcher { contract_address }; + testing::set_contract_address(caller_1); + let game_id = actions_system.create_match(player_1, 400); + + testing::set_contract_address(caller_1); + actions_system + .end_match(game_id, contract_address_const::<0x0>()); // Zero address indicates draw + + let rating1: PlayerRating = world.read_model(caller_1); + let rating2: PlayerRating = world.read_model(player_1); + assert(rating1.rating == 1500, 'Player 1 rating changed in draw'); + assert(rating2.rating == 1500, 'Player 2 rating changed in draw'); + } + + // New test: Ending an already ended game should panic + #[test] + #[should_panic] + fn test_end_game_already_ended() { + let caller_1 = contract_address_const::<'aji'>(); + let player_1 = contract_address_const::<'player'>(); + + let ndef = namespace_def(); + let mut world = spawn_test_world([ndef].span()); + world.sync_perms_and_inits(contract_defs()); + + let (contract_address, _) = world.dns(@"Snooknet").unwrap(); + let actions_system = ISnooknetDispatcher { contract_address }; + + testing::set_contract_address(caller_1); + let game_id = actions_system.create_match(player_1, 400); + + testing::set_contract_address(caller_1); + actions_system.end_match(game_id, caller_1); + + // Attempt to end the game again + actions_system.end_match(game_id, caller_1); + } + + // New test: Ending a game with an invalid winner should panic + #[test] + #[should_panic] + fn test_invalid_winner() { + let caller_1 = contract_address_const::<'aji'>(); + let player_1 = contract_address_const::<'player'>(); + let invalid_winner = contract_address_const::<'invalid'>(); + + let ndef = namespace_def(); + let mut world = spawn_test_world([ndef].span()); + world.sync_perms_and_inits(contract_defs()); + + let (contract_address, _) = world.dns(@"Snooknet").unwrap(); + let actions_system = ISnooknetDispatcher { contract_address }; + + testing::set_contract_address(caller_1); + let game_id = actions_system.create_match(player_1, 400); + + testing::set_contract_address(caller_1); + actions_system.end_match(game_id, invalid_winner); let organizer = contract_address_const::<'organizer'>(); set_contract_address(organizer);