Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ If you find our work useful, please consider citing:

## Installation
- Clone repo and install submodules
```bash=
```bash
git clone --recurse-submodules https://github.com/tancik/StegaStamp.git
cd StegaStamp
```
- Install tensorflow (tested with tf 1.13)
- Python 3 required
- Download dependencies
```bash=
```bash
pip install -r requirements.txt
```

Expand All @@ -42,7 +42,7 @@ TRAIN_PATH = DIR_OF_DATASET_IMAGES
```

- Train model
```bash=
```bash
bash scripts/base.sh EXP_NAME
```
The training is performed in `train.py`. There are a number of hyperparameters, many corresponding to the augmentation parameters. `scripts/bash.sh` provides a good starting place.
Expand All @@ -52,15 +52,15 @@ The training code for the detector model (used to segment StegaStamps) is not in

### Tensorboard
To visualize the training run the following command and navigate to http://localhost:6006 in your browser.
```bash=
```bash
tensorboard --logdir logs
```

## Encoding a Message
The script `encode_image.py` can be used to encode a message into an image or a directory of images. The default model expects a utf-8 encoded secret that is <= 7 characters (100 bit message -> 56 bits after ECC).

Encode a message into an image:
```bash=
```bash
python encode_image.py \
saved_models/stegastamp_pretrained \
--image test_im.png \
Expand All @@ -69,11 +69,29 @@ python encode_image.py \
```
This will save both the StegaStamp and the residual that was applied to the original image.

## DCSS Forensic Mark (Real-time Video Encoding)
The script `encode_video.py` embeds DCSS forensic mark data into each video frame in a real-time, inline process. It computes a 16-bit timer index (15-minute increments across a 366-day cycle, repeating annually) and packs it with a 19/20-bit location identifier plus adapter bits into the 56-bit payload used by the StegaStamp model (BCH ECC is applied automatically).

Each payload is embedded across consecutive frames; by default the payload repeats for a 5-minute segment so that every 5-minute window includes all adapter bits. You can override the frame window if needed.

```bash
python encode_video.py \
saved_models/stegastamp_pretrained \
--video input.mp4 \
--save_video output.avi \
--location_id 12345 \
--location_bits 20 \
--adapter_bits 00000000000000000000 \
--start_time 2026-02-08T04:00:00+00:00
```

Adapter bit length must match the remaining payload space (20 bits when using a 20-bit location ID, 21 bits when using a 19-bit location ID).

## Decoding a Message
The script `decode_image.py` can be used to decode a message from a StegaStamp.

Example usage:
```bash=
```bash
python decode_image.py \
saved_models/stegastamp_pretrained \
--image out/test_hidden.png
Expand All @@ -84,7 +102,7 @@ The script `detector.py` can be used to detect and decode StegaStamps in an imag

To use the detector, make sure to download the detector model as described in the installation section. The recomended input video resolution is 1920x1080.

```bash=
```bash
python detector.py \
--detector_model detector_models/stegastamp_detector \
--decoder_model saved_models/stegastamp_pretrained \
Expand Down
133 changes: 133 additions & 0 deletions encode_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import argparse
from datetime import datetime, timedelta, timezone

import cv2
import numpy as np
from PIL import Image, ImageOps
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants

import forensic_mark

SECONDS_PER_MINUTE = 60


def parse_start_time(value):
if value is None:
return datetime.now(timezone.utc)
parsed = datetime.fromisoformat(value)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)


def load_model(model_path):
sess = tf.InteractiveSession(graph=tf.Graph())
model = tf.saved_model.loader.load(sess, [tag_constants.SERVING], model_path)

signature = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_secret = sess.graph.get_tensor_by_name(signature.inputs['secret'].name)
input_image = sess.graph.get_tensor_by_name(signature.inputs['image'].name)
output_stegastamp = sess.graph.get_tensor_by_name(signature.outputs['stegastamp'].name)
return sess, input_secret, input_image, output_stegastamp


def resize_for_model(frame_rgb, size):
resample = Image.Resampling.BILINEAR if hasattr(Image, "Resampling") else Image.BILINEAR
pil_image = Image.fromarray(frame_rgb)
fitted = ImageOps.fit(pil_image, size, method=resample)
return np.array(fitted, dtype=np.float32) / 255.0


def resize_for_output(stegastamp, size):
resample = Image.Resampling.BILINEAR if hasattr(Image, "Resampling") else Image.BILINEAR
stega_img = Image.fromarray(stegastamp)
resized = stega_img.resize(size, resample=resample)
return np.array(resized)


def main():
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str)
parser.add_argument('--video', type=str, required=True)
parser.add_argument('--save_video', type=str, required=True)
parser.add_argument('--location_id', type=int, required=True)
parser.add_argument('--location_bits', type=int, choices=[19, 20], default=20)
parser.add_argument('--adapter_bits', type=str, default=None)
parser.add_argument('--start_time', type=str, default=None,
help="ISO-8601 start time (defaults to now in UTC)")
parser.add_argument('--fourcc', type=str, default="XVID",
help="FourCC codec for output video (default: XVID)")
parser.add_argument('--segment_minutes', type=int, default=5)
parser.add_argument('--frames_per_payload', type=int, default=None,
help="Override frames per payload window (default: segment length)")
args = parser.parse_args()

start_time = parse_start_time(args.start_time)
sess, input_secret, input_image, output_stegastamp = load_model(args.model)

cap = cv2.VideoCapture(args.video)
if not cap.isOpened():
raise ValueError("Unable to open input video.")
fps = cap.get(cv2.CAP_PROP_FPS)
if fps <= 0:
fps = 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

if len(args.fourcc) != 4:
raise ValueError("FourCC codec must be four characters.")
fourcc = cv2.VideoWriter_fourcc(*args.fourcc)
out = cv2.VideoWriter(args.save_video, fourcc, fps, (width, height))
if not out.isOpened():
raise ValueError("Unable to open output video writer.")

frames_per_payload = args.frames_per_payload
if frames_per_payload is None:
frames_per_payload = max(1, int(round(fps * args.segment_minutes * SECONDS_PER_MINUTE)))
if frames_per_payload <= 0:
raise ValueError("Frames per payload must be positive.")

frame_index = 0
current_secret = None
last_timer_index = None
last_payload_frame = -frames_per_payload

while True:
ret, frame = cap.read()
if not ret:
break

elapsed = timedelta(seconds=frame_index / fps)
frame_time = start_time + elapsed
timer_index = forensic_mark.timer_index_for_datetime(frame_time)

if timer_index != last_timer_index or frame_index - last_payload_frame >= frames_per_payload:
current_secret = forensic_mark.build_secret_bits(
timer_index=timer_index,
location_id=args.location_id,
location_bits=args.location_bits,
adapter_bits=args.adapter_bits,
)
last_timer_index = timer_index
last_payload_frame = frame_index

frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
model_input = resize_for_model(frame_rgb, (400, 400))
feed_dict = {input_secret: [current_secret], input_image: [model_input]}
hidden_img = sess.run(output_stegastamp, feed_dict=feed_dict)[0]

rescaled = (hidden_img * 255).astype(np.uint8)
output_rgb = resize_for_output(rescaled, (width, height))
output_bgr = cv2.cvtColor(output_rgb, cv2.COLOR_RGB2BGR)
out.write(output_bgr)

frame_index += 1

cap.release()
out.release()


if __name__ == "__main__":
main()
131 changes: 131 additions & 0 deletions forensic_mark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import bchlib
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

BCH_POLYNOMIAL = 137
BCH_BITS = 5

TIMER_INTERVAL_MINUTES = 15
TIMERS_PER_DAY = 24 * 60 // TIMER_INTERVAL_MINUTES
# DCSS timer indices use a fixed 366-day annual cycle.
DAYS_PER_YEAR = 366
TOTAL_TIMERS = DAYS_PER_YEAR * TIMERS_PER_DAY
TIMER_BITS = 16
DATA_BITS = 56
DATA_BYTES = DATA_BITS // 8
SECRET_BITS = 100
FOOTER_BITS = 4


@dataclass(frozen=True)
class ForensicMark:
timer_index: int
location_id: int
adapter_bits: str
location_bits: int


def normalize_datetime(value):
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)


def timer_index_for_datetime(value):
"""Return the DCSS timer index using a 366-day repeating cycle."""
value = normalize_datetime(value)
day_index = (value.timetuple().tm_yday - 1) % DAYS_PER_YEAR
quarter_hour = value.minute // TIMER_INTERVAL_MINUTES
timer_index = day_index * TIMERS_PER_DAY + value.hour * (60 // TIMER_INTERVAL_MINUTES) + quarter_hour
return timer_index % TOTAL_TIMERS


def datetime_for_timer_index(value, year=None):
if year is None:
year = datetime.now(timezone.utc).year
day_index, remainder = divmod(value % TOTAL_TIMERS, TIMERS_PER_DAY)
hour, quarter = divmod(remainder, 60 // TIMER_INTERVAL_MINUTES)
minute = quarter * TIMER_INTERVAL_MINUTES
base = datetime(year, 1, 1, tzinfo=timezone.utc) + timedelta(days=day_index)
return base.replace(hour=hour, minute=minute, second=0, microsecond=0)


def adapter_bit_length(location_bits):
"""Return the number of adapter bits available for the payload."""
return DATA_BITS - TIMER_BITS - location_bits


def _coerce_adapter_bits(adapter_bits, length):
if adapter_bits is None:
return "0" * length
if isinstance(adapter_bits, int):
if adapter_bits < 0 or adapter_bits >= (1 << length):
raise ValueError(f"Adapter value must fit in {length} bits.")
return format(adapter_bits, f"0{length}b")
candidate = str(adapter_bits).strip()
if candidate.startswith("0b"):
candidate = candidate[2:]
if set(candidate) <= {"0", "1"}:
if len(candidate) != length:
raise ValueError(f"Adapter bits must be {length} bits long.")
return candidate
value = int(candidate)
if value < 0 or value >= (1 << length):
raise ValueError(f"Adapter value must fit in {length} bits.")
return format(value, f"0{length}b")


def build_payload_bits(timer_index, location_id, location_bits=20, adapter_bits=None):
if location_bits not in (19, 20):
raise ValueError("Location bits must be 19 or 20.")
if timer_index < 0 or timer_index >= TOTAL_TIMERS:
raise ValueError(f"Timer index must be in [0, {TOTAL_TIMERS}).")
if location_id < 0 or location_id >= (1 << location_bits):
raise ValueError(f"Location id must fit in {location_bits} bits.")
adapter_len = adapter_bit_length(location_bits)
adapter_bits = _coerce_adapter_bits(adapter_bits, adapter_len)
payload_value = ((timer_index << (location_bits + adapter_len)) |
(location_id << adapter_len) |
int(adapter_bits, 2))
return format(payload_value, f"0{DATA_BITS}b")


def build_payload_bytes(timer_index, location_id, location_bits=20, adapter_bits=None):
payload_bits = build_payload_bits(timer_index, location_id, location_bits, adapter_bits)
payload_value = int(payload_bits, 2)
return payload_value.to_bytes(DATA_BYTES, "big")


def build_secret_bits(timer_index, location_id, location_bits=20, adapter_bits=None):
payload = bytearray(build_payload_bytes(timer_index, location_id, location_bits, adapter_bits))
bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
ecc = bch.encode(payload)
packet = payload + ecc
packet_binary = ''.join(format(x, '08b') for x in packet)
secret = [int(x) for x in packet_binary]
secret.extend([0] * FOOTER_BITS)
if len(secret) != SECRET_BITS:
raise ValueError("Secret bit length mismatch.")
return secret


def parse_payload_bits(payload_bits, location_bits=20):
if location_bits not in (19, 20):
raise ValueError("Location bits must be 19 or 20.")
payload_bits = payload_bits.zfill(DATA_BITS)
timer_bits = payload_bits[:TIMER_BITS]
location_start = TIMER_BITS
location_end = TIMER_BITS + location_bits
location_bits_str = payload_bits[location_start:location_end]
adapter_bits = payload_bits[location_end:]
return ForensicMark(
timer_index=int(timer_bits, 2),
location_id=int(location_bits_str, 2),
adapter_bits=adapter_bits,
location_bits=location_bits,
)


def parse_payload_bytes(payload, location_bits=20):
payload_bits = format(int.from_bytes(payload, "big"), f"0{DATA_BITS}b")
return parse_payload_bits(payload_bits, location_bits=location_bits)