diff --git a/db/migrations/20241212075345_validator_db.sql b/db/migrations/20241212075345_validator_db.sql index 58e6e50..a162ddb 100644 --- a/db/migrations/20241212075345_validator_db.sql +++ b/db/migrations/20241212075345_validator_db.sql @@ -35,7 +35,6 @@ CREATE TABLE piece ( chunk_idx INTEGER, -- Index of the chunk in the file piece_idx INTEGER, -- Index of the piece in the chunk piece_type INTEGER CHECK (piece_type IN (0, 1)), -- Type of the piece (0: data, 1: parity) - tag TEXT, -- APDP Tag of the piece signature TEXT -- Signature of the DHT entry by the miner storing the piece ); diff --git a/settings.toml.example b/settings.toml.example index 287e48d..418a461 100644 --- a/settings.toml.example +++ b/settings.toml.example @@ -39,6 +39,9 @@ store_dir = "object_store" [validator] synthetic = false +top_miner_ratio = 0.7 +ssl_certfile = "cert.pem" +ssl_keyfile = "key.pem" [validator.neuron] num_concurrent_forwards = 1 diff --git a/storb/config.py b/storb/config.py index 745839e..62005b8 100644 --- a/storb/config.py +++ b/storb/config.py @@ -274,3 +274,10 @@ def add_validator_args(self): help="Query timeout", default=self.settings.validator.query.timeout, ) + + self._parser.add_argument( + "--top_miner_ratio", + type=float, + help="Top miner ratio", + default=self.settings.validator.top_miner_ratio, + ) diff --git a/storb/util/uids.py b/storb/util/uids.py index a8f8064..e9e9fbb 100644 --- a/storb/util/uids.py +++ b/storb/util/uids.py @@ -1,8 +1,12 @@ import random +import numpy as np from fiber.chain.metagraph import Metagraph from storb.neuron import Neuron +from storb.util.logging import get_logger + +logger = get_logger(__name__) def check_hotkey_availability(metagraph: Metagraph, hotkey: str) -> bool: @@ -80,3 +84,151 @@ def get_random_hotkeys(self: Neuron, k: int, exclude: list[int] = None) -> list[ ) hotkeys = random.sample(available_hotkeys, k) return hotkeys + + +def get_ranked_hotkeys( + self: Neuron, + k: int, + exclude: list[int] = None, + 
def get_ranked_hotkeys(
    self: Neuron,
    k: int,
    exclude: list[int] = None,
    top_fraction: float = 0.7,
    low_fraction: float = 0.3,
) -> list[str]:
    """Return up to ``k`` miner hotkeys mixing top-ranked, low-ranked and
    score-weighted random picks, based on the EMA scores in ``self.scores``.

    Parameters
    ----------
    k : int
        Total number of hotkeys to return.
    exclude : list[int], optional
        Hotkeys to exclude from selection.
    top_fraction : float, optional
        Fraction of ``k`` filled with top-ranked miners. Default 0.7.
    low_fraction : float, optional
        Fraction of ``k`` filled with low-ranked miners. Default 0.3.

    Returns
    -------
    list[str]
        Hotkeys ordered as [top + low + random]. Length is <= k (shorter if
        there are not enough available, non-excluded hotkeys).

    Raises
    ------
    ValueError
        If the fractions are outside [0, 1], sum to more than 1, or k <= 0.

    Notes
    -----
    - ``self.scores`` is a numpy array of shape [num_nodes]; a node's index
      into it is ``self.metagraph.nodes[hotkey].node_id``.
    """
    # Validate inputs
    if not 0 <= top_fraction <= 1:
        raise ValueError("`top_fraction` must be between 0 and 1.")
    if not 0 <= low_fraction <= 1:
        raise ValueError("`low_fraction` must be between 0 and 1.")
    if top_fraction + low_fraction > 1:
        raise ValueError("top_fraction + low_fraction must not exceed 1.")
    if k <= 0:
        raise ValueError("`k` must be a positive integer.")

    exclude_set = set(exclude) if exclude else set()

    # Gather available, non-excluded hotkeys.
    candidate_hotkeys = [
        hotkey
        for hotkey in self.metagraph.nodes
        if check_hotkey_availability(self.metagraph, hotkey)
        and hotkey not in exclude_set
    ]

    # FIX: cap k by the *candidate* pool. The original capped by all
    # available hotkeys and later back-filled from that wider pool, which
    # could return hotkeys the caller explicitly excluded.
    k = min(k, len(candidate_hotkeys))
    if k == 0:
        return []

    # Sort by EMA score, descending (higher score = higher rank).
    sorted_hotkeys_desc = sorted(
        candidate_hotkeys,
        key=lambda hk: self.scores[self.metagraph.nodes[hk].node_id],
        reverse=True,
    )

    num_top = int(k * top_fraction)
    num_low = int(k * low_fraction)

    top_hotkeys = sorted_hotkeys_desc[:num_top]
    top_set = set(top_hotkeys)

    # FIX: guard num_low == 0 — `lst[-0:]` slices the WHOLE list, so the
    # original treated every non-top candidate as "low" whenever
    # int(k * low_fraction) rounded down to zero.
    low_candidates = sorted_hotkeys_desc[-num_low:] if num_low > 0 else []
    low_hotkeys = [hk for hk in low_candidates if hk not in top_set]
    # Recompute in case the top and low slices overlapped.
    num_low = len(low_hotkeys)

    remaining = k - (num_top + num_low)
    random_selected: list[str] = []

    if remaining > 0:
        # The middle set excludes both top and low picks.
        chosen_so_far = top_set | set(low_hotkeys)
        middle_hotkeys = [
            hk for hk in candidate_hotkeys if hk not in chosen_so_far
        ]

        if middle_hotkeys:
            # Weight draws by score; normalizing by the sum directly is
            # equivalent to the original divide-by-max-then-renormalize.
            scores_middle = [
                float(self.scores[self.metagraph.nodes[hk].node_id])
                for hk in middle_hotkeys
            ]
            total_score = sum(scores_middle)
            if total_score > 0:
                weights = [s / total_score for s in scores_middle]
            else:
                # All zero: fall back to uniform.
                weights = [1 / len(middle_hotkeys)] * len(middle_hotkeys)

            actual_num_random = min(remaining, len(middle_hotkeys))
            try:
                # FIX: draw *indices*, not the strings themselves —
                # np.random.choice over a str list yields np.str_ elements,
                # polluting the returned list with numpy scalar types.
                drawn = np.random.choice(
                    len(middle_hotkeys),
                    size=actual_num_random,
                    replace=False,
                    p=weights,
                )
                random_selected = [middle_hotkeys[i] for i in drawn]
            except ValueError as e:
                logger.error(f"Random choice failed: {e}")
                random_selected = []

    mixed_hotkeys = top_hotkeys + low_hotkeys + random_selected

    # If we still don't have k, fill the remainder from candidates only
    # (FIX: the original refilled from all available hotkeys, bypassing
    # `exclude`).
    if len(mixed_hotkeys) < k:
        needed = k - len(mixed_hotkeys)
        not_selected = list(set(candidate_hotkeys) - set(mixed_hotkeys))
        if not_selected:
            mixed_hotkeys.extend(
                random.sample(not_selected, min(needed, len(not_selected)))
            )

    # Final trim in case we overshoot (unlikely, but just to be safe).
    return mixed_hotkeys[:k]
storb.util.message_signing import PieceMessage, sign_message -from storb.util.piece import piece_hash +from storb.util.piece import Piece, PieceType, piece_hash from storb.util.query import Payload from storb.validator.types import PieceTask, ProcessedPieceResponse @@ -89,7 +90,7 @@ def consume_piece_queue(self): logger.info("Exiting consume_piece_queue thread.") - async def process_pieces(self, pieces, hotkeys): + async def process_pieces(self, pieces: list[Piece], hotkeys: list[str]): """ Process each piece: compute piece hash, store them on the miners, update DB stats, update latencies, etc. @@ -100,7 +101,6 @@ async def process_pieces(self, pieces, hotkeys): piece_hashes = [] processed_pieces = [] - # Basic example of how you'd do batch queries: to_query = [] curr_batch_size = 0 uids = [] @@ -113,9 +113,10 @@ async def process_pieces(self, pieces, hotkeys): async def handle_batch_requests(): nonlocal to_query, latencies batch_responses = await asyncio.gather(*(t for _, t in to_query)) - + logger.info(f"Batch responses: {batch_responses}") for i, batch_result in enumerate(batch_responses): real_idx = to_query[i][0] + successful_miners = [] # batch_result is the list of (uid, payload_response) for uid, payload_resp in batch_result: miner_stats[uid]["store_attempts"] += 1 @@ -127,7 +128,13 @@ async def handle_batch_requests(): ) miner_stats[uid]["store_successes"] += 1 miner_stats[uid]["total_successes"] += 1 + successful_miners.append(uid) + + if len(successful_miners) == 0: + logger.error(f"No successful miners for piece {real_idx}") + return + logger.info(f"Successful miners: {successful_miners}") # Put piece info in the queue for DHT storing self.piece_miners.setdefault(piece_hashes[real_idx], []).extend( [uid for (uid, p) in batch_result if p and p.data] @@ -169,8 +176,69 @@ async def handle_batch_requests(): logger.info(f"Curr piece queue size: {self.piece_queue.qsize()}") to_query = [] + # Separate pieces into PARITY and DATA + data_pieces = [p for p in 
pieces if p.piece_type == PieceType.Data] + parity_pieces = [p for p in pieces if p.piece_type == PieceType.Parity] + + total_pieces = len(pieces) + top_miner_pieces_count = math.floor( + total_pieces * self.settings.validator.top_miner_ratio + ) + + top_miner_pieces = parity_pieces[:] + + if len(top_miner_pieces) < top_miner_pieces_count: + top_miner_pieces.extend( + data_pieces[: top_miner_pieces_count - len(top_miner_pieces)] + ) + + low_miner_pieces = data_pieces[top_miner_pieces_count:] + + top_miner_count = int(NUM_UIDS_QUERY * self.settings.validator.top_miner_ratio) + top_miners = hotkeys[:top_miner_count] + low_miners = hotkeys[top_miner_count:] + + logger.debug(f"Top miners: {top_miners}, Low miners: {low_miners}") + # Loop pieces - for idx, piece_info in enumerate(pieces): + for idx, piece_info in enumerate(top_miner_pieces): + p_hash = piece_hash(piece_info.data) + piece_hashes.append(p_hash) + processed_pieces.append( + protocol.ProcessedPieceInfo( + chunk_idx=piece_info.chunk_idx, + piece_type=piece_info.piece_type, + piece_idx=piece_info.piece_idx, + data=piece_info.data, + piece_id=p_hash, + ) + ) + # Create a store request + payload = Payload( + data=protocol.Store( + chunk_idx=piece_info.chunk_idx, + piece_type=piece_info.piece_type, + piece_idx=piece_info.piece_idx, + ), + file=piece_info.data, + ) + task = asyncio.create_task( + self.query_multiple_miners( + miner_hotkeys=top_miners, + endpoint="/store", + payload=payload, + method="POST", + ) + ) + to_query.append((idx, task)) + curr_batch_size += 1 + + # If batch is full, send + if curr_batch_size >= self.settings.validator.query.batch_size: + await handle_batch_requests() + curr_batch_size = 0 + + for idx, piece_info in enumerate(low_miner_pieces): p_hash = piece_hash(piece_info.data) piece_hashes.append(p_hash) processed_pieces.append( @@ -193,7 +261,7 @@ async def handle_batch_requests(): ) task = asyncio.create_task( self.query_multiple_miners( - miner_hotkeys=hotkeys, + 
miner_hotkeys=low_miners, endpoint="/store", payload=payload, method="POST", diff --git a/storb/validator/routes.py b/storb/validator/routes.py index 2c1e4e0..c4beff6 100644 --- a/storb/validator/routes.py +++ b/storb/validator/routes.py @@ -34,7 +34,7 @@ reconstruct_data_stream, ) from storb.util.query import Payload -from storb.util.uids import get_random_hotkeys +from storb.util.uids import get_ranked_hotkeys logger = get_logger(__name__) @@ -120,8 +120,12 @@ async def upload_file(self, file: UploadFile = File(...)) -> protocol.StoreRespo timestamp = str(datetime.now(UTC).timestamp()) - # TODO: Consider miner scores for selection, and not just their availability - hotkeys = get_random_hotkeys(self, NUM_UIDS_QUERY) + hotkeys = get_ranked_hotkeys( + self, + k=NUM_UIDS_QUERY, + top_fraction=self.settings.validator.top_miner_ratio, + low_fraction=(1 - self.settings.validator.top_miner_ratio), + ) chunk_hashes = [] piece_hashes = set() diff --git a/storb/validator/validator.py b/storb/validator/validator.py index 3a3bea8..409c492 100644 --- a/storb/validator/validator.py +++ b/storb/validator/validator.py @@ -7,6 +7,7 @@ import numpy as np import uvicorn +from fastapi.middleware.cors import CORSMiddleware from fiber.encrypted.miner.endpoints.handshake import ( factory_router as get_subnet_router, ) @@ -127,6 +128,13 @@ def app_init(self): self.app.add_middleware(LoggerMiddleware) self.app.add_middleware(FileSizeMiddleware) + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["Content-Disposition"], + ) self.app.add_api_route( "/status", @@ -162,7 +170,17 @@ def app_init(self): self.app.include_router(get_subnet_router()) - config = uvicorn.Config(self.app, host="0.0.0.0", port=self.settings.api_port) + config = uvicorn.Config( + self.app, + host="0.0.0.0", + port=self.settings.api_port, + ssl_certfile=self.settings.validator.ssl_certfile + if self.settings.validator.ssl_certfile + else 
None, + ssl_keyfile=self.settings.validator.ssl_keyfile + if self.settings.validator.ssl_keyfile + else None, + ) self.server = uvicorn.Server(config) assert self.server, "Uvicorn server must be initialised"