diff --git a/db/migrations/20241212075345_validator_db.sql b/db/migrations/20241212075345_validator_db.sql index 74eff2d..58e6e50 100644 --- a/db/migrations/20241212075345_validator_db.sql +++ b/db/migrations/20241212075345_validator_db.sql @@ -27,11 +27,11 @@ CREATE TABLE chunk ( signature TEXT -- Signature of the DHT entry by the validator ); --- Table for mapping piece_hash to piece metadata and miner_id +-- Table for mapping piece_hash to piece metadata and miner_uids CREATE TABLE piece ( piece_hash TEXT PRIMARY KEY, -- Piece ID validator_id INTEGER, -- ID of the validator - miner_id TEXT, -- IDs of the miner in a JSON Array + miner_uids TEXT, -- IDs of the miner in a JSON Array chunk_idx INTEGER, -- Index of the chunk in the file piece_idx INTEGER, -- Index of the piece in the chunk piece_type INTEGER CHECK (piece_type IN (0, 1)), -- Type of the piece (0: data, 1: parity) @@ -39,6 +39,13 @@ CREATE TABLE piece ( signature TEXT -- Signature of the DHT entry by the miner storing the piece ); +CREATE TABLE piece_challenge ( + miner_uids TEXT, -- IDs of the miner in a JSON Array + piece_hash TEXT, -- Piece ID + challenge_timestamp TEXT, -- Timestamp of the challenge + tag TEXT -- APDP Tag of the piece +); + -- Table for miner stats -- CREATE TABLE miner_stats ( miner_uid INTEGER PRIMARY KEY, @@ -64,3 +71,6 @@ DROP TABLE IF EXISTS chunk; -- Drop the `tracker` table DROP TABLE IF EXISTS tracker; + +-- Drop the `piece_challenge` table +DROP TABLE IF EXISTS piece_challenge; diff --git a/docs/challenge.md b/docs/challenge.md index cba63a8..e212eac 100644 --- a/docs/challenge.md +++ b/docs/challenge.md @@ -160,7 +160,7 @@ The Challenge System operates asynchronously to support distributed and scalable The validator selects a miner and a specific data piece to challenge. 
```python - await self.challenge_miner(miner_id, piece_id, tag) + await self.challenge_miner(miner_uids, piece_id, tag) ``` - **Process:** diff --git a/storb/db.py b/storb/db.py index fdd0c5f..373bbdd 100644 --- a/storb/db.py +++ b/storb/db.py @@ -10,6 +10,7 @@ from storb.dht.chunk_dht import ChunkDHTValue from storb.dht.piece_dht import PieceDHTValue from storb.dht.tracker_dht import TrackerDHTValue +from storb.protocol import PieceChallenge @asynccontextmanager @@ -229,16 +230,15 @@ async def delete_chunk_entry(conn: aiosqlite.Connection, chunk_hash: str): async def set_piece_entry(conn: aiosqlite.Connection, entry: PieceDHTValue): query = """ - INSERT INTO piece (piece_hash, validator_id, miner_id, chunk_idx, piece_idx, piece_type, tag, signature) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO piece (piece_hash, validator_id, miner_uids, chunk_idx, piece_idx, piece_type, signature) + VALUES (?, ?, ?, ?, ?, ?, ?) ON CONFLICT(piece_hash) DO UPDATE SET validator_id = excluded.validator_id, - miner_id = excluded.miner_id, + miner_uids = excluded.miner_uids, chunk_idx = excluded.chunk_idx, piece_idx = excluded.piece_idx, piece_type = excluded.piece_type, - tag = excluded.tag, signature = excluded.signature """ await conn.execute( @@ -246,11 +246,10 @@ async def set_piece_entry(conn: aiosqlite.Connection, entry: PieceDHTValue): ( entry.piece_hash, entry.validator_id, - entry.miner_id, + entry.miner_uids, entry.chunk_idx, entry.piece_idx, entry.piece_type, - entry.tag, entry.signature, ), ) @@ -270,12 +269,11 @@ async def get_piece_entry( return PieceDHTValue( piece_hash=row[0], validator_id=row[1], - miner_id=row[2], + miner_uids=row[2], chunk_idx=row[3], piece_idx=row[4], piece_type=row[5], - tag=row[6], - signature=row[7], + signature=row[6], ) return None # Entry not found @@ -289,43 +287,68 @@ async def delete_piece_entry(conn: aiosqlite.Connection, piece_hash: str): await conn.commit() -async def get_random_piece( - conn: aiosqlite.Connection, validator_id: int 
-) -> PieceDHTValue | None: - """Randomly selects a piece from the `piece` table for a given validator. +async def set_piece_challenge_entry(conn: aiosqlite.Connection, entry: PieceChallenge): + query = """ + INSERT INTO piece_challenge (miner_uids, piece_hash, challenge_timestamp, tag) + VALUES (?, ?, ?, ?) + """ + await conn.execute( + query, + ( + entry.miner_uids, + entry.piece_hash, + entry.challenge_timestamp, + entry.tag, + ), + ) + await conn.commit() + + +async def get_piece_challenge_entry(conn: aiosqlite.Connection, piece_id: str): + query = """ + SELECT * FROM piece_challenge + WHERE piece_hash = ? + """ + async with conn.execute(query, (piece_id,)) as cursor: + row = await cursor.fetchone() + if row: + return PieceChallenge( + miner_uids=row[0], + piece_hash=row[1], + challenge_timestamp=row[2], + tag=row[3], + ) + return None + + +async def get_random_piece_challenge_entry( + conn: aiosqlite.Connection, +) -> PieceChallenge | None: + """Randomly selects a piece from the `piece_challenge` table. Parameters ---------- conn : aiosqlite.Connection The database connection. - validator_id : int - The validator ID to query pieces for. Returns ------- - PieceEntry or None - A random PieceEntry object if a piece is found, or None if the table is empty. + PieceChallenge or None + A random PieceChallenge object if a piece is found, or None if the table is empty. """ query = """ - SELECT * FROM piece - WHERE validator_id = ? 
+ SELECT * FROM piece_challenge ORDER BY RANDOM() LIMIT 1 """ - - async with conn.execute(query, (validator_id,)) as cursor: + async with conn.execute(query) as cursor: row = await cursor.fetchone() if row: - miner_ids = [int(i) for i in row[2].split(",")] - return PieceDHTValue( - piece_hash=row[0], - validator_id=row[1], - miner_id=miner_ids, - chunk_idx=row[3], - piece_idx=row[4], - piece_type=row[5], - tag=row[6], - signature=row[7], + return PieceChallenge( + miner_uids=row[0], + piece_hash=row[1], + challenge_timestamp=row[2], + tag=row[3], ) - return None # No pieces found + return None diff --git a/storb/dht/piece_dht.py b/storb/dht/piece_dht.py index 952217d..3cc0f32 100644 --- a/storb/dht/piece_dht.py +++ b/storb/dht/piece_dht.py @@ -6,11 +6,10 @@ class PieceDHTValue(BaseModel): piece_hash: str validator_id: int - miner_id: set[int] | str + miner_uids: set[int] | str chunk_idx: int piece_idx: int piece_type: PieceType - tag: str signature: str def to_dict(self) -> dict: diff --git a/storb/dht/storage.py b/storb/dht/storage.py index 9560864..e5453e3 100644 --- a/storb/dht/storage.py +++ b/storb/dht/storage.py @@ -253,15 +253,14 @@ async def _db_write_data(self, namespace: str, key: str, value: bytes): case "piece": val = PieceDHTValue.model_validate_json(value) - miner_ids = ",".join(str(i) for i in val.miner_id) + miner_uids = ",".join(str(i) for i in val.miner_uids) entry = PieceDHTValue( piece_hash=key, validator_id=val.validator_id, - miner_id=miner_ids, + miner_uids=miner_uids, chunk_idx=val.chunk_idx, piece_idx=val.piece_idx, piece_type=val.piece_type, - tag=val.tag, signature=val.signature, ) logger.debug(f"flushing piece entry {entry} to disk") @@ -342,16 +341,15 @@ async def _db_read_data(self, key: bytes) -> DHTValue: entry = await db.get_piece_entry(conn, db_key) if entry is None: return None - miner_ids = [int(i) for i in entry.miner_id.split(",")] + miner_uids = [int(i) for i in entry.miner_uids.split(",")] return ( PieceDHTValue( 
piece_hash=entry.piece_hash, validator_id=entry.validator_id, - miner_id=miner_ids, + miner_uids=miner_uids, chunk_idx=entry.chunk_idx, piece_idx=entry.piece_idx, piece_type=entry.piece_type, - tag=entry.tag, signature=entry.signature, ) .model_dump_json() diff --git a/storb/miner/miner.py b/storb/miner/miner.py index e0b479d..dfa480f 100644 --- a/storb/miner/miner.py +++ b/storb/miner/miner.py @@ -11,9 +11,9 @@ ) from pydantic import ValidationError -from storb import protocol from storb.constants import QUERY_TIMEOUT, NeuronType from storb.neuron import Neuron +from storb.protocol import AckChallenge, NewChallenge, ProofResponse, Retrieve, Store from storb.util.logging import get_logger from storb.util.message_signing import verify_message from storb.util.middleware import LoggerMiddleware @@ -30,7 +30,7 @@ def __init__(self): self.object_store = ObjectStore(store_dir=self.settings.store_dir) self.challenge_queue: asyncio.PriorityQueue[ - tuple[datetime.datetime, protocol.NewChallenge] + tuple[datetime.datetime, NewChallenge] ] = asyncio.PriorityQueue() async def start(self): @@ -81,7 +81,7 @@ def app_init(self): "/store", self.store_piece, methods=["POST"], - response_model=protocol.Store, + response_model=Store, # TODO: bring this back # dependencies=[Depends(blacklist_low_stake), Depends(verify_request)], ) @@ -90,7 +90,7 @@ def app_init(self): "/retrieve", self.get_piece, methods=["POST"], - # response_model=protocol.Retrieve, + # response_model=Retrieve, # dependencies=[Depends(blacklist_low_stake), Depends(verify_request)], ) @@ -98,7 +98,7 @@ def app_init(self): "/challenge", self.ack_challenge, methods=["POST"], - response_model=protocol.AckChallenge, + response_model=AckChallenge, ) self.app.include_router(get_subnet_router()) @@ -115,7 +115,7 @@ async def status(self) -> str: async def store_piece( self, json_data: str = Form(None), file: UploadFile = None - ) -> protocol.Store: + ) -> Store: """Stores a piece which is received from a validator. 
Parameters @@ -127,13 +127,13 @@ async def store_piece( Returns ------- - protocol.Store + Store The response object containing details of the stored piece. """ logger.info("Received store request") - piece_info = protocol.Store.model_validate_json(json_data) + piece_info = Store.model_validate_json(json_data) piece_bytes = await file.read() piece_id = piece_hash(piece_bytes) logger.debug( @@ -142,7 +142,7 @@ async def store_piece( await self.object_store.write(piece_id, piece_bytes) - response = protocol.Store( + response = Store( piece_id=piece_id, chunk_idx=piece_info.chunk_idx, piece_idx=piece_info.piece_idx, @@ -151,12 +151,12 @@ async def store_piece( return response - async def get_piece(self, request: protocol.Retrieve): + async def get_piece(self, request: Retrieve): """Returns a piece from storage as JSON metadata and a file. Parameters ---------- - request : protocol.Retrieve + request : Retrieve The request object containing the piece_id to retrieve. """ @@ -167,7 +167,7 @@ async def get_piece(self, request: protocol.Retrieve): piece = await self.object_store.read(request.piece_id) # Create the JSON metadata response - metadata = protocol.Retrieve(piece_id=request.piece_id).model_dump_json() + metadata = Retrieve(piece_id=request.piece_id).model_dump_json() # Boundary definition boundary = str(uuid.uuid4()) @@ -194,17 +194,17 @@ async def iter_response(): iter_response(), media_type=f"multipart/mixed; boundary={boundary}" ) - async def ack_challenge(self, request: protocol.NewChallenge): + async def ack_challenge(self, request: NewChallenge): """Acknowledges a challenge from a validator, verifies it, and enqueues it upon success. Parameters ---------- - request : protocol.NewChallenge + request : NewChallenge The challenge request object. Returns ------- - protocol.AckChallenge + AckChallenge The response object containing the result of the challenge acknowledgement. 
""" @@ -223,9 +223,7 @@ async def ack_challenge(self, request: protocol.NewChallenge): logger.error( f"Failed to verify challenge {request.challenge_id} with validator {request.validator_id}" ) - return protocol.AckChallenge( - challenge_id=request.challenge_id, accept=False - ) + return AckChallenge(challenge_id=request.challenge_id, accept=False) try: deadline_dt = datetime.datetime.fromisoformat(request.challenge_deadline) @@ -233,16 +231,14 @@ async def ack_challenge(self, request: protocol.NewChallenge): logger.error( "Invalid challenge_deadline format. Must be valid ISO 8601 string." ) - return protocol.AckChallenge( - challenge_id=request.challenge_id, accept=False - ) + return AckChallenge(challenge_id=request.challenge_id, accept=False) await self.challenge_queue.put((deadline_dt, request)) logger.info( f"Challenge {request.challenge_id} enqueued with deadline {request.challenge_deadline}" ) - return protocol.AckChallenge(challenge_id=request.challenge_id, accept=True) + return AckChallenge(challenge_id=request.challenge_id, accept=True) async def consume_challenges(self): """Consumes challenges from the min-heap and processes them.""" @@ -302,7 +298,7 @@ async def consume_challenges(self): try: payload = Payload( - data=protocol.ProofResponse( + data=ProofResponse( challenge_id=challenge.challenge_id, piece_id=challenge.piece_id, proof=proof, diff --git a/storb/protocol.py b/storb/protocol.py index a9df8d0..b784758 100644 --- a/storb/protocol.py +++ b/storb/protocol.py @@ -36,7 +36,7 @@ class NewChallenge(BaseModel): challenge_id: str piece_id: str validator_id: int - miner_id: int + miner_uids: int challenge_deadline: str public_key: int public_exponent: int @@ -77,3 +77,10 @@ class GetMinersBase(BaseModel): pieces_metadata: Optional[list[list[PieceDHTValue]]] = Field( default=None ) # multi dimensional array of piece metadata. 
each row corresponds to a chunk + + +class PieceChallenge(BaseModel): + miner_uids: str + piece_hash: str + challenge_timestamp: str + tag: str diff --git a/storb/validator/challenge.py b/storb/validator/challenge.py index 3846e8b..7862a54 100644 --- a/storb/validator/challenge.py +++ b/storb/validator/challenge.py @@ -12,12 +12,12 @@ class ValiChallengeMixin: - async def challenge_miner(self, miner_id: int, piece_id: str, tag: str): + async def challenge_miner(self, miner_uids: int, piece_id: str, tag: str): """Challenge the miners to verify they are storing the pieces Parameters ---------- - miner_id : int + miner_uids : int The ID of the miner to challenge piece_id : str The ID of the piece to challenge the miner for @@ -25,7 +25,7 @@ async def challenge_miner(self, miner_id: int, piece_id: str, tag: str): The tag of the piece to challenge the miner for """ - logger.debug(f"Challenging miner {miner_id} for piece {piece_id}") + logger.debug(f"Challenging miner {miner_uids} for piece {piece_id}") # Create the challenge message challenge = self.challenge.issue_challenge(tag) try: @@ -41,7 +41,7 @@ async def challenge_miner(self, miner_id: int, piece_id: str, tag: str): challenge_id=uuid.uuid4().hex, piece_id=piece_id, validator_id=self.uid, - miner_id=miner_id, + miner_uids=miner_uids, challenge_deadline=challenge_deadline, public_key=self.challenge.key.rsa.public_key().public_numbers().n, public_exponent=self.challenge.key.rsa.public_key().public_numbers().e, @@ -51,9 +51,9 @@ async def challenge_miner(self, miner_id: int, piece_id: str, tag: str): logger.debug(f"Challenge message: {challenge_message}") # Send the challenge to the miner - miner_hotkey = list(self.metagraph.nodes.keys())[miner_id] + miner_hotkey = list(self.metagraph.nodes.keys())[miner_uids] if miner_hotkey is None: - logger.error(f"Miner {miner_id} not found in metagraph") + logger.error(f"Miner {miner_uids} not found in metagraph") return payload = Payload( @@ -62,16 +62,16 @@ async def 
challenge_miner(self, miner_id: int, piece_id: str, tag: str): time_elapsed=0, ) logger.info( - f"Sent challenge {challenge_message.challenge_id} to miner {miner_id} for piece {piece_id}" + f"Sent challenge {challenge_message.challenge_id} to miner {miner_uids} for piece {piece_id}" ) async with db.get_db_connection(db_dir=self.db_dir) as conn: miner_stats = await db.get_miner_stats( - conn=conn, miner_uid=challenge_message.miner_id + conn=conn, miner_uid=challenge_message.miner_uids ) miner_stats["challenge_attempts"] += 1 await db.update_stats( - conn=conn, miner_uid=challenge_message.miner_id, stats=miner_stats + conn=conn, miner_uid=challenge_message.miner_uids, stats=miner_stats ) logger.debug(f"PRF KEY: {payload.data.challenge.prf_key}") @@ -80,13 +80,13 @@ async def challenge_miner(self, miner_id: int, piece_id: str, tag: str): ) if response is None: - logger.error(f"Failed to challenge miner {miner_id}") + logger.error(f"Failed to challenge miner {miner_uids}") return self.challenges[challenge_message.challenge_id] = challenge_message - logger.debug(f"Received response from miner {miner_id}, response: {response}") - logger.info(f"Successfully challenged miner {miner_id}") + logger.debug(f"Received response from miner {miner_uids}, response: {response}") + logger.info(f"Successfully challenged miner {miner_uids}") async def remove_expired_challenges(self): """ @@ -134,7 +134,7 @@ async def verify_challenge(self, challenge_request: protocol.ProofResponse) -> b return False async with db.get_db_connection(db_dir=self.db_dir) as conn: - miner_stats = await db.get_miner_stats(conn, challenge.miner_id) + miner_stats = await db.get_miner_stats(conn, challenge.miner_uids) proof = challenge_request.proof try: @@ -174,7 +174,7 @@ async def verify_challenge(self, challenge_request: protocol.ProofResponse) -> b async with db.get_db_connection(db_dir=self.db_dir) as conn: miner_stats["challenge_successes"] += 1 await db.update_stats( - conn=conn, 
miner_uid=challenge.miner_id, stats=miner_stats + conn=conn, miner_uid=challenge.miner_uids, stats=miner_stats ) # remove challenge from memory diff --git a/storb/validator/piece_processing.py b/storb/validator/piece_processing.py index 8b6d3a9..daedddb 100644 --- a/storb/validator/piece_processing.py +++ b/storb/validator/piece_processing.py @@ -1,84 +1,93 @@ import asyncio +import json +from datetime import datetime import numpy as np from storb import db, protocol +from storb.challenge import APDPTag from storb.constants import QUERY_TIMEOUT +from storb.dht.piece_dht import PieceDHTValue +from storb.protocol import PieceChallenge +from storb.util.logging import get_logger from storb.util.message_signing import PieceMessage, sign_message from storb.util.piece import piece_hash from storb.util.query import Payload from storb.validator.types import PieceTask, ProcessedPieceResponse +logger = get_logger(__name__) class PieceProcessingMixin: + async def store_piece_challenge( + self, miners: list[int], piece_task: PieceTask, tag: str + ): + """Store a piece challenge entry in the DB. + + Parameters + ---------- + miners : list[int] | set[int] + The list of miner UIDs that stored the piece. + piece_task : PieceTask + The piece task that was processed. + tag : str + The cryptographic tag of the piece. + """ + + miners = json.dumps(sorted(miners)) + async with db.get_db_connection(self.db_dir) as conn: + await db.set_piece_challenge_entry( + conn, + entry=PieceChallenge( + miner_uids=miners, + piece_hash=piece_task.piece_hash, + challenge_timestamp=datetime.now().isoformat(), + tag=tag, + ), + ) + def consume_piece_queue(self): """ Continuously consume the self.piece_queue, process each `PieceTask`, - and store piece metadata into the DHT. + and store piece metadata into the DHT (or DB) via an async method. + + Since this is running in a *thread* (not in an asyncio event loop), + we call 'asyncio.run(...)' to synchronously execute the async code. 
""" + while True: + piece_task = self.piece_queue.get() + + if piece_task is None: + self.piece_queue.task_done() + break + try: - piece_task = self.piece_queue.get() - if piece_task is None: - # Poison pill or sentinel - self.piece_queue.task_done() - break - # --------------- - # Generate APDP tag (if needed) - # --------------- - # try: - # tag = self.process_tag(piece_task.data) - # except Exception as e: - # self.logger.error(f"Error processing tag: {e}") - # self.piece_queue.task_done() - # continue - - # --------------- - # Sign piece and store to DHT - # --------------- - try: - message = PieceMessage( - piece_hash=piece_task.piece_hash, - chunk_idx=piece_task.chunk_idx, - piece_idx=piece_task.piece_idx, - piece_type=piece_task.piece_type, - ) - signature = sign_message(message, self.keypair) - except Exception as e: - self.logger.error(f"Error signing piece: {e}") - self.piece_queue.task_done() - continue + # Log that we received a task + logger.info(f"Consuming piece task: {piece_task.piece_hash}") + + # Generate the tag (synchronous) + tag_obj = self.challenge.generate_tag(piece_task.data) + tag_json = tag_obj.model_dump_json() - # get the set/list of miners that stored this piece + # Retrieve which miners stored this piece miners = self.piece_miners.get(piece_task.piece_hash, set()) - # store piece metadata in DHT - self.dht.store_piece_entry( - piece_hash=piece_task.piece_hash, - value=self.piece_dht_value_class( # e.g. 
PieceDHTValue - piece_hash=piece_task.piece_hash, - validator_id=self.uid, - miner_id=miners, - chunk_idx=piece_task.chunk_idx, - piece_idx=piece_task.piece_idx, - piece_type=piece_task.piece_type, - signature=signature, - ), + asyncio.run( + self.store_piece_challenge( + tag=tag_json, + miners=miners, + piece_task=piece_task, + ) ) - self.piece_queue.task_done() except Exception as e: - self.logger.error(f"Error consuming piece task: {e}") + logger.error(f"Error consuming piece task: {e}", exc_info=True) + + finally: self.piece_queue.task_done() - continue - def process_tag(self, data: bytes) -> str: - """ - Generate a cryptographic tag for a piece of data (APDP or another scheme). - """ - # Example: store a JSON string, or something - tag = self.challenge.generate_tag(data) - return tag.model_dump_json() + logger.info("Exiting consume_piece_queue thread.") async def process_pieces(self, pieces, hotkeys): """ @@ -123,15 +132,41 @@ async def handle_batch_requests(): self.piece_miners.setdefault(piece_hashes[real_idx], []).extend( [uid for (uid, p) in batch_result if p and p.data] ) + + try: + message = PieceMessage( + piece_hash=piece_hashes[real_idx], + chunk_idx=pieces[real_idx].chunk_idx, + piece_idx=pieces[real_idx].piece_idx, + piece_type=pieces[real_idx].piece_type, + ) + + signature = sign_message(message, self.keypair) + except Exception as e: + logger.error(f"Error signing piece: {e}") + continue + + miners = self.piece_miners.get(piece_hashes[real_idx], set()) + self.dht.store_piece_entry( + piece_hash=piece_hashes[real_idx], + value=PieceDHTValue( + piece_hash=piece_hashes[real_idx], + validator_id=self.uid, + miner_uids=miners, + chunk_idx=pieces[real_idx].chunk_idx, + piece_idx=pieces[real_idx].piece_idx, + piece_type=pieces[real_idx].piece_type, + signature=signature, + ), + ) + self.piece_queue.put( PieceTask( - piece_idx=real_idx, piece_hash=piece_hashes[real_idx], data=pieces[real_idx].data, - chunk_idx=pieces[real_idx].chunk_idx, - 
piece_type=pieces[real_idx].piece_type, ) ) + logger.info(f"Curr piece queue size: {self.piece_queue.qsize()}") to_query = [] # Loop pieces diff --git a/storb/validator/query.py b/storb/validator/query.py index 8a331a4..5ea599b 100644 --- a/storb/validator/query.py +++ b/storb/validator/query.py @@ -44,9 +44,9 @@ async def query_miner( server_addr = f"http://{node.ip}:{node.port}" logger.info(f"Querying miner at {server_addr}") symmetric_key = self.symmetric_keys.get(node.node_id) - # if not symmetric_key: - # logger.warning(f"Entry for node ID {node.node_id} not found") - # return node.node_id, None + if not symmetric_key: + logger.warning(f"Entry for node ID {node.node_id} not found") + return node.node_id, None _, symmetric_key_uuid = symmetric_key try: logger.info( diff --git a/storb/validator/routes.py b/storb/validator/routes.py index 5ae3442..2c1e4e0 100644 --- a/storb/validator/routes.py +++ b/storb/validator/routes.py @@ -318,12 +318,12 @@ async def get_file(self, infohash: str): latencies = np.full(len(self.metagraph.nodes), QUERY_TIMEOUT, dtype=np.float32) # TODO: check if the lengths of the chunk ids and chunks_metadata are the same - async def get_piece(piece_id, miner_id): + async def get_piece(piece_id, miner_uids): synapse = protocol.Retrieve(piece_id=piece_id) payload = Payload(data=synapse) try: response = await self.query_miner( - miner_hotkey=list(self.metagraph.nodes.keys())[miner_id], + miner_hotkey=list(self.metagraph.nodes.keys())[miner_uids], endpoint="/retrieve", method="POST", payload=payload, @@ -332,9 +332,9 @@ async def get_piece(piece_id, miner_id): return response except Exception as e: logger.error( - f"Error querying miner {miner_id} for piece {piece_id}: {e}" + f"Error querying miner {miner_uids} for piece {piece_id}: {e}" ) - return (list(self.metagraph.nodes.keys())[miner_id], None) + return (list(self.metagraph.nodes.keys())[miner_uids], None) for idx, _ in enumerate(tracker.chunk_ids): chunks_metadata: ChunkDHTValue = 
tracker.chunks_metadata[idx] @@ -351,16 +351,16 @@ async def get_piece(piece_id, miner_id): chunk_pieces_metadata = tracker.pieces_metadata[idx] for piece_idx, piece_id in enumerate(chunks_metadata.piece_hashes): - miner_ids = chunk_pieces_metadata[piece_idx].miner_id - if not miner_ids: + miner_uids = chunk_pieces_metadata[piece_idx].miner_uids + if not miner_uids: logger.error( f"No miners available for piece {piece_id} in chunk {idx}" ) piece_ids_match = False continue tasks = [ - asyncio.create_task(get_piece(piece_id, miner_id)) - for miner_id in miner_ids + asyncio.create_task(get_piece(piece_id, miner_uid)) + for miner_uid in miner_uids ] piece_found = False diff --git a/storb/validator/types.py b/storb/validator/types.py index b6c37d4..1c7a876 100644 --- a/storb/validator/types.py +++ b/storb/validator/types.py @@ -3,11 +3,8 @@ @dataclass class PieceTask: - piece_idx: int piece_hash: str data: bytes - chunk_idx: int - piece_type: int # or enum if you have from storb.util.piece import PieceType @dataclass diff --git a/storb/validator/validator.py b/storb/validator/validator.py index eedd120..3a3bea8 100644 --- a/storb/validator/validator.py +++ b/storb/validator/validator.py @@ -1,4 +1,5 @@ import asyncio +import json import queue import threading import time @@ -84,8 +85,9 @@ def __init__(self): self.loop = asyncio.get_event_loop() self.piece_queue = queue.Queue() + self.piece_consumer_threads: list[threading.Thread] = [] self.piece_miners: dict[str, set[int]] = {} - self.consumer_threads: list[threading.Thread] = [] + self.start_piece_consumers() def start_piece_consumers(self, num_consumers: int = PIECE_CONSUMERS): @@ -97,12 +99,10 @@ def start_piece_consumers(self, num_consumers: int = PIECE_CONSUMERS): Number of threads to start, by default PIECE_CONSUMERS """ - for i in range(num_consumers): - t = threading.Thread( - target=self.consume_piece_queue, name=f"PieceConsumer-{i+1}" - ) - t.start() - self.consumer_threads.append(t) + for _ in 
range(num_consumers): + thread = threading.Thread(target=self.consume_piece_queue) + thread.start() + self.piece_consumer_threads.append(thread) async def start(self): self.app_init() @@ -218,6 +218,8 @@ async def forward(self): - Updating the scores """ + logger.info(f"current tasks: {self.piece_queue.qsize()}") + # TODO: should we lock the db when scoring? # remove expired challenges @@ -228,17 +230,20 @@ async def forward(self): async with db.get_db_connection(self.db_dir) as conn: miner_stats = await db.get_all_miner_stats(conn) try: - challenge_piece = await db.get_random_piece(conn, self.uid) + challenge_piece = await db.get_random_piece_challenge_entry(conn) + if challenge_piece is None: + logger.info("No challenge pieces found") + return except Exception as e: logger.error(f"Failed to get random piece: {e}") return - - if isinstance(challenge_piece.miner_id, set): - random_miner = np.random.choice(list(challenge_piece.miner_id)) - challenge_piece.miner_id = random_miner + challenge_piece.miner_uids = json.loads(challenge_piece.miner_uids) + if isinstance(challenge_piece.miner_uids, list): + random_miner = np.random.choice(challenge_piece.miner_uids) + challenge_piece.miner_uids = random_miner else: logger.error( - f"Miner ID is not a list: {challenge_piece.miner_id} it is {type(challenge_piece.miner_id)}" + f"Miner ID is not a list: {challenge_piece.miner_uids} it is {type(challenge_piece.miner_uids)}" ) # skip challenge logger.warning( @@ -249,7 +254,7 @@ async def forward(self): if challenge_piece is not None: try: await self.challenge_miner( - challenge_piece.miner_id, + challenge_piece.miner_uids, challenge_piece.piece_hash, challenge_piece.tag, )