From 23ea164215f8353b9e5095d0c007fa4014c75916 Mon Sep 17 00:00:00 2001 From: kqjy Date: Sat, 24 Jan 2026 19:38:17 +0800 Subject: [PATCH] Add bi-directional site replication with LWW conflict resolution --- app/__init__.py | 14 ++ app/config.py | 14 +- app/replication.py | 15 +- app/s3_api.py | 11 +- app/site_sync.py | 396 ++++++++++++++++++++++++++++++++++ tests/test_site_sync.py | 461 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 902 insertions(+), 9 deletions(-) create mode 100644 app/site_sync.py create mode 100644 tests/test_site_sync.py diff --git a/app/__init__.py b/app/__init__.py index 02c6472..2968c03 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -208,6 +208,20 @@ def create_app( system_metrics_collector.set_storage(storage) app.extensions["system_metrics"] = system_metrics_collector + site_sync_worker = None + if app.config.get("SITE_SYNC_ENABLED", False): + from .site_sync import SiteSyncWorker + site_sync_worker = SiteSyncWorker( + storage=storage, + connections=connections, + replication_manager=replication, + storage_root=storage_root, + interval_seconds=app.config.get("SITE_SYNC_INTERVAL_SECONDS", 60), + batch_size=app.config.get("SITE_SYNC_BATCH_SIZE", 100), + ) + site_sync_worker.start() + app.extensions["site_sync"] = site_sync_worker + @app.errorhandler(500) def internal_error(error): return render_template('500.html'), 500 diff --git a/app/config.py b/app/config.py index 8e04441..b39000f 100644 --- a/app/config.py +++ b/app/config.py @@ -94,6 +94,9 @@ class AppConfig: server_connection_limit: int server_backlog: int server_channel_timeout: int + site_sync_enabled: bool + site_sync_interval_seconds: int + site_sync_batch_size: int @classmethod def from_env(cls, overrides: Optional[Dict[str, Any]] = None) -> "AppConfig": @@ -201,6 +204,9 @@ class AppConfig: server_connection_limit = int(_get("SERVER_CONNECTION_LIMIT", 100)) server_backlog = int(_get("SERVER_BACKLOG", 1024)) server_channel_timeout = 
int(_get("SERVER_CHANNEL_TIMEOUT", 120)) + site_sync_enabled = str(_get("SITE_SYNC_ENABLED", "0")).lower() in {"1", "true", "yes", "on"} + site_sync_interval_seconds = int(_get("SITE_SYNC_INTERVAL_SECONDS", 60)) + site_sync_batch_size = int(_get("SITE_SYNC_BATCH_SIZE", 100)) return cls(storage_root=storage_root, max_upload_size=max_upload_size, @@ -249,7 +255,10 @@ class AppConfig: server_threads=server_threads, server_connection_limit=server_connection_limit, server_backlog=server_backlog, - server_channel_timeout=server_channel_timeout) + server_channel_timeout=server_channel_timeout, + site_sync_enabled=site_sync_enabled, + site_sync_interval_seconds=site_sync_interval_seconds, + site_sync_batch_size=site_sync_batch_size) def validate_and_report(self) -> list[str]: """Validate configuration and return a list of warnings/issues. @@ -420,4 +429,7 @@ class AppConfig: "SERVER_CONNECTION_LIMIT": self.server_connection_limit, "SERVER_BACKLOG": self.server_backlog, "SERVER_CHANNEL_TIMEOUT": self.server_channel_timeout, + "SITE_SYNC_ENABLED": self.site_sync_enabled, + "SITE_SYNC_INTERVAL_SECONDS": self.site_sync_interval_seconds, + "SITE_SYNC_BATCH_SIZE": self.site_sync_batch_size, } diff --git a/app/replication.py b/app/replication.py index 4eacdef..9cab869 100644 --- a/app/replication.py +++ b/app/replication.py @@ -27,6 +27,7 @@ STREAMING_THRESHOLD_BYTES = 10 * 1024 * 1024 REPLICATION_MODE_NEW_ONLY = "new_only" REPLICATION_MODE_ALL = "all" +REPLICATION_MODE_BIDIRECTIONAL = "bidirectional" def _create_s3_client(connection: RemoteConnection, *, health_check: bool = False) -> Any: @@ -127,10 +128,12 @@ class ReplicationRule: target_connection_id: str target_bucket: str enabled: bool = True - mode: str = REPLICATION_MODE_NEW_ONLY + mode: str = REPLICATION_MODE_NEW_ONLY created_at: Optional[float] = None stats: ReplicationStats = field(default_factory=ReplicationStats) - + sync_deletions: bool = True + last_pull_at: Optional[float] = None + def to_dict(self) -> dict: 
return { "bucket_name": self.bucket_name, @@ -140,8 +143,10 @@ class ReplicationRule: "mode": self.mode, "created_at": self.created_at, "stats": self.stats.to_dict(), + "sync_deletions": self.sync_deletions, + "last_pull_at": self.last_pull_at, } - + @classmethod def from_dict(cls, data: dict) -> "ReplicationRule": stats_data = data.pop("stats", {}) @@ -149,6 +154,10 @@ class ReplicationRule: data["mode"] = REPLICATION_MODE_NEW_ONLY if "created_at" not in data: data["created_at"] = None + if "sync_deletions" not in data: + data["sync_deletions"] = True + if "last_pull_at" not in data: + data["last_pull_at"] = None rule = cls(**data) rule.stats = ReplicationStats.from_dict(stats_data) if stats_data else ReplicationStats() return rule diff --git a/app/s3_api.py b/app/s3_api.py index f576c32..1f49e15 100644 --- a/app/s3_api.py +++ b/app/s3_api.py @@ -2446,7 +2446,8 @@ def object_handler(bucket_name: str, object_key: str): operation="Put", ) - if "S3ReplicationAgent" not in request.headers.get("User-Agent", ""): + user_agent = request.headers.get("User-Agent", "") + if "S3ReplicationAgent" not in user_agent and "SiteSyncAgent" not in user_agent: _replication_manager().trigger_replication(bucket_name, object_key, action="write") return response @@ -2592,7 +2593,7 @@ def object_handler(bucket_name: str, object_key: str): ) user_agent = request.headers.get("User-Agent", "") - if "S3ReplicationAgent" not in user_agent: + if "S3ReplicationAgent" not in user_agent and "SiteSyncAgent" not in user_agent: _replication_manager().trigger_replication(bucket_name, object_key, action="delete") return Response(status=204) @@ -2826,9 +2827,9 @@ def _copy_object(dest_bucket: str, dest_key: str, copy_source: str) -> Response: ) user_agent = request.headers.get("User-Agent", "") - if "S3ReplicationAgent" not in user_agent: + if "S3ReplicationAgent" not in user_agent and "SiteSyncAgent" not in user_agent: _replication_manager().trigger_replication(dest_bucket, dest_key, action="write") - 
+ root = Element("CopyObjectResult") SubElement(root, "LastModified").text = meta.last_modified.isoformat() if meta.etag: @@ -3040,7 +3041,7 @@ def _complete_multipart_upload(bucket_name: str, object_key: str) -> Response: return _error_response("InvalidPart", str(exc), 400) user_agent = request.headers.get("User-Agent", "") - if "S3ReplicationAgent" not in user_agent: + if "S3ReplicationAgent" not in user_agent and "SiteSyncAgent" not in user_agent: _replication_manager().trigger_replication(bucket_name, object_key, action="write") root = Element("CompleteMultipartUploadResult") diff --git a/app/site_sync.py b/app/site_sync.py new file mode 100644 index 0000000..306ac28 --- /dev/null +++ b/app/site_sync.py @@ -0,0 +1,396 @@ +from __future__ import annotations + +import json +import logging +import tempfile +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import boto3 +from botocore.config import Config +from botocore.exceptions import ClientError + +if TYPE_CHECKING: + from .connections import ConnectionStore, RemoteConnection + from .replication import ReplicationManager, ReplicationRule + from .storage import ObjectStorage + +logger = logging.getLogger(__name__) + +SITE_SYNC_USER_AGENT = "SiteSyncAgent/1.0" +SITE_SYNC_CONNECT_TIMEOUT = 10 +SITE_SYNC_READ_TIMEOUT = 120 +CLOCK_SKEW_TOLERANCE_SECONDS = 1.0 + + +@dataclass +class SyncedObjectInfo: + last_synced_at: float + remote_etag: str + source: str + + def to_dict(self) -> Dict[str, Any]: + return { + "last_synced_at": self.last_synced_at, + "remote_etag": self.remote_etag, + "source": self.source, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SyncedObjectInfo": + return cls( + last_synced_at=data["last_synced_at"], + remote_etag=data["remote_etag"], + source=data["source"], + ) + + +@dataclass +class SyncState: + synced_objects: 
Dict[str, SyncedObjectInfo] = field(default_factory=dict) + last_full_sync: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "synced_objects": {k: v.to_dict() for k, v in self.synced_objects.items()}, + "last_full_sync": self.last_full_sync, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SyncState": + synced_objects = {} + for k, v in data.get("synced_objects", {}).items(): + synced_objects[k] = SyncedObjectInfo.from_dict(v) + return cls( + synced_objects=synced_objects, + last_full_sync=data.get("last_full_sync"), + ) + + +@dataclass +class SiteSyncStats: + last_sync_at: Optional[float] = None + objects_pulled: int = 0 + objects_skipped: int = 0 + conflicts_resolved: int = 0 + deletions_applied: int = 0 + errors: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return { + "last_sync_at": self.last_sync_at, + "objects_pulled": self.objects_pulled, + "objects_skipped": self.objects_skipped, + "conflicts_resolved": self.conflicts_resolved, + "deletions_applied": self.deletions_applied, + "errors": self.errors, + } + + +@dataclass +class RemoteObjectMeta: + key: str + size: int + last_modified: datetime + etag: str + + @classmethod + def from_s3_object(cls, obj: Dict[str, Any]) -> "RemoteObjectMeta": + return cls( + key=obj["Key"], + size=obj.get("Size", 0), + last_modified=obj["LastModified"], + etag=obj.get("ETag", "").strip('"'), + ) + + +def _create_sync_client(connection: "RemoteConnection") -> Any: + config = Config( + user_agent_extra=SITE_SYNC_USER_AGENT, + connect_timeout=SITE_SYNC_CONNECT_TIMEOUT, + read_timeout=SITE_SYNC_READ_TIMEOUT, + retries={"max_attempts": 2}, + signature_version="s3v4", + s3={"addressing_style": "path"}, + request_checksum_calculation="when_required", + response_checksum_validation="when_required", + ) + return boto3.client( + "s3", + endpoint_url=connection.endpoint_url, + aws_access_key_id=connection.access_key, + aws_secret_access_key=connection.secret_key, + 
region_name=connection.region or "us-east-1", + config=config, + ) + + +class SiteSyncWorker: + def __init__( + self, + storage: "ObjectStorage", + connections: "ConnectionStore", + replication_manager: "ReplicationManager", + storage_root: Path, + interval_seconds: int = 60, + batch_size: int = 100, + ): + self.storage = storage + self.connections = connections + self.replication_manager = replication_manager + self.storage_root = storage_root + self.interval_seconds = interval_seconds + self.batch_size = batch_size + self._lock = threading.Lock() + self._shutdown = threading.Event() + self._sync_thread: Optional[threading.Thread] = None + self._bucket_stats: Dict[str, SiteSyncStats] = {} + + def start(self) -> None: + if self._sync_thread is not None and self._sync_thread.is_alive(): + return + self._shutdown.clear() + self._sync_thread = threading.Thread( + target=self._sync_loop, name="site-sync-worker", daemon=True + ) + self._sync_thread.start() + logger.info("Site sync worker started (interval=%ds)", self.interval_seconds) + + def shutdown(self) -> None: + self._shutdown.set() + if self._sync_thread is not None: + self._sync_thread.join(timeout=10.0) + logger.info("Site sync worker shut down") + + def trigger_sync(self, bucket_name: str) -> Optional[SiteSyncStats]: + from .replication import REPLICATION_MODE_BIDIRECTIONAL + rule = self.replication_manager.get_rule(bucket_name) + if not rule or rule.mode != REPLICATION_MODE_BIDIRECTIONAL or not rule.enabled: + return None + return self._sync_bucket(rule) + + def get_stats(self, bucket_name: str) -> Optional[SiteSyncStats]: + with self._lock: + return self._bucket_stats.get(bucket_name) + + def _sync_loop(self) -> None: + while not self._shutdown.is_set(): + self._shutdown.wait(timeout=self.interval_seconds) + if self._shutdown.is_set(): + break + self._run_sync_cycle() + + def _run_sync_cycle(self) -> None: + from .replication import REPLICATION_MODE_BIDIRECTIONAL + for bucket_name, rule in 
list(self.replication_manager._rules.items()): + if self._shutdown.is_set(): + break + if rule.mode != REPLICATION_MODE_BIDIRECTIONAL or not rule.enabled: + continue + try: + stats = self._sync_bucket(rule) + with self._lock: + self._bucket_stats[bucket_name] = stats + except Exception as e: + logger.exception("Site sync failed for bucket %s: %s", bucket_name, e) + + def _sync_bucket(self, rule: "ReplicationRule") -> SiteSyncStats: + stats = SiteSyncStats() + connection = self.connections.get(rule.target_connection_id) + if not connection: + logger.warning("Connection %s not found for bucket %s", rule.target_connection_id, rule.bucket_name) + stats.errors += 1 + return stats + + try: + local_objects = self._list_local_objects(rule.bucket_name) + except Exception as e: + logger.error("Failed to list local objects for %s: %s", rule.bucket_name, e) + stats.errors += 1 + return stats + + try: + remote_objects = self._list_remote_objects(rule, connection) + except Exception as e: + logger.error("Failed to list remote objects for %s: %s", rule.bucket_name, e) + stats.errors += 1 + return stats + + sync_state = self._load_sync_state(rule.bucket_name) + local_keys = set(local_objects.keys()) + remote_keys = set(remote_objects.keys()) + + to_pull = [] + for key in remote_keys: + remote_meta = remote_objects[key] + local_meta = local_objects.get(key) + if local_meta is None: + to_pull.append(key) + else: + resolution = self._resolve_conflict(local_meta, remote_meta) + if resolution == "pull": + to_pull.append(key) + stats.conflicts_resolved += 1 + else: + stats.objects_skipped += 1 + + pulled_count = 0 + for key in to_pull: + if self._shutdown.is_set(): + break + if pulled_count >= self.batch_size: + break + remote_meta = remote_objects[key] + success = self._pull_object(rule, key, connection, remote_meta) + if success: + stats.objects_pulled += 1 + pulled_count += 1 + sync_state.synced_objects[key] = SyncedObjectInfo( + last_synced_at=time.time(), + 
remote_etag=remote_meta.etag, + source="remote", + ) + else: + stats.errors += 1 + + if rule.sync_deletions: + for key in list(sync_state.synced_objects.keys()): + if key not in remote_keys and key in local_keys: + tracked = sync_state.synced_objects[key] + if tracked.source == "remote": + local_meta = local_objects.get(key) + if local_meta and local_meta.last_modified.timestamp() <= tracked.last_synced_at: + success = self._apply_remote_deletion(rule.bucket_name, key) + if success: + stats.deletions_applied += 1 + del sync_state.synced_objects[key] + + sync_state.last_full_sync = time.time() + self._save_sync_state(rule.bucket_name, sync_state) + + with self.replication_manager._stats_lock: + rule.last_pull_at = time.time() + self.replication_manager.save_rules() + + stats.last_sync_at = time.time() + logger.info( + "Site sync completed for %s: pulled=%d, skipped=%d, conflicts=%d, deletions=%d, errors=%d", + rule.bucket_name, + stats.objects_pulled, + stats.objects_skipped, + stats.conflicts_resolved, + stats.deletions_applied, + stats.errors, + ) + return stats + + def _list_local_objects(self, bucket_name: str) -> Dict[str, Any]: + from .storage import ObjectMeta + objects = self.storage.list_objects_all(bucket_name) + return {obj.key: obj for obj in objects} + + def _list_remote_objects(self, rule: "ReplicationRule", connection: "RemoteConnection") -> Dict[str, RemoteObjectMeta]: + s3 = _create_sync_client(connection) + result: Dict[str, RemoteObjectMeta] = {} + paginator = s3.get_paginator("list_objects_v2") + try: + for page in paginator.paginate(Bucket=rule.target_bucket): + for obj in page.get("Contents", []): + meta = RemoteObjectMeta.from_s3_object(obj) + result[meta.key] = meta + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchBucket": + return {} + raise + return result + + def _resolve_conflict(self, local_meta: Any, remote_meta: RemoteObjectMeta) -> str: + local_ts = local_meta.last_modified.timestamp() + remote_ts = 
remote_meta.last_modified.timestamp() + + if abs(remote_ts - local_ts) < CLOCK_SKEW_TOLERANCE_SECONDS: + local_etag = local_meta.etag or "" + if remote_meta.etag == local_etag: + return "skip" + return "pull" if remote_meta.etag > local_etag else "keep" + + return "pull" if remote_ts > local_ts else "keep" + + def _pull_object( + self, + rule: "ReplicationRule", + object_key: str, + connection: "RemoteConnection", + remote_meta: RemoteObjectMeta, + ) -> bool: + s3 = _create_sync_client(connection) + tmp_path = None + try: + tmp_dir = self.storage_root / ".myfsio.sys" / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp_file: + tmp_path = Path(tmp_file.name) + + s3.download_file(rule.target_bucket, object_key, str(tmp_path)) + + head_response = s3.head_object(Bucket=rule.target_bucket, Key=object_key) + user_metadata = head_response.get("Metadata", {}) + + with open(tmp_path, "rb") as f: + self.storage.put_object( + rule.bucket_name, + object_key, + f, + metadata=user_metadata if user_metadata else None, + ) + + logger.debug("Pulled object %s/%s from remote", rule.bucket_name, object_key) + return True + + except ClientError as e: + logger.error("Failed to pull %s/%s: %s", rule.bucket_name, object_key, e) + return False + except Exception as e: + logger.error("Failed to store pulled object %s/%s: %s", rule.bucket_name, object_key, e) + return False + finally: + if tmp_path and tmp_path.exists(): + try: + tmp_path.unlink() + except OSError: + pass + + def _apply_remote_deletion(self, bucket_name: str, object_key: str) -> bool: + try: + self.storage.delete_object(bucket_name, object_key) + logger.debug("Applied remote deletion for %s/%s", bucket_name, object_key) + return True + except Exception as e: + logger.error("Failed to apply remote deletion for %s/%s: %s", bucket_name, object_key, e) + return False + + def _sync_state_path(self, bucket_name: str) -> Path: + return self.storage_root / 
".myfsio.sys" / "buckets" / bucket_name / "site_sync_state.json" + + def _load_sync_state(self, bucket_name: str) -> SyncState: + path = self._sync_state_path(bucket_name) + if not path.exists(): + return SyncState() + try: + data = json.loads(path.read_text(encoding="utf-8")) + return SyncState.from_dict(data) + except (json.JSONDecodeError, OSError, KeyError) as e: + logger.warning("Failed to load sync state for %s: %s", bucket_name, e) + return SyncState() + + def _save_sync_state(self, bucket_name: str, state: SyncState) -> None: + path = self._sync_state_path(bucket_name) + path.parent.mkdir(parents=True, exist_ok=True) + try: + path.write_text(json.dumps(state.to_dict(), indent=2), encoding="utf-8") + except OSError as e: + logger.warning("Failed to save sync state for %s: %s", bucket_name, e) diff --git a/tests/test_site_sync.py b/tests/test_site_sync.py new file mode 100644 index 0000000..4975375 --- /dev/null +++ b/tests/test_site_sync.py @@ -0,0 +1,461 @@ +import io +import json +import time +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from app.connections import ConnectionStore, RemoteConnection +from app.replication import ( + ReplicationManager, + ReplicationRule, + REPLICATION_MODE_BIDIRECTIONAL, + REPLICATION_MODE_NEW_ONLY, +) +from app.site_sync import ( + SiteSyncWorker, + SyncState, + SyncedObjectInfo, + SiteSyncStats, + RemoteObjectMeta, + CLOCK_SKEW_TOLERANCE_SECONDS, +) +from app.storage import ObjectStorage + + +@pytest.fixture +def storage(tmp_path: Path): + storage_root = tmp_path / "data" + storage_root.mkdir(parents=True) + return ObjectStorage(storage_root) + + +@pytest.fixture +def connections(tmp_path: Path): + connections_path = tmp_path / "connections.json" + store = ConnectionStore(connections_path) + conn = RemoteConnection( + id="test-conn", + name="Test Remote", + endpoint_url="http://localhost:9000", + access_key="remote-access", + 
        secret_key="remote-secret",
        region="us-east-1",
    )
    store.add(conn)
    return store


@pytest.fixture
def replication_manager(storage, connections, tmp_path):
    # Real manager over a temp rules file; shut down without waiting so the
    # test suite stays fast.
    rules_path = tmp_path / "replication_rules.json"
    storage_root = tmp_path / "data"
    storage_root.mkdir(exist_ok=True)
    manager = ReplicationManager(storage, connections, rules_path, storage_root)
    yield manager
    manager.shutdown(wait=False)


@pytest.fixture
def site_sync_worker(storage, connections, replication_manager, tmp_path):
    # Worker is constructed but NOT started; tests drive it synchronously.
    storage_root = tmp_path / "data"
    worker = SiteSyncWorker(
        storage=storage,
        connections=connections,
        replication_manager=replication_manager,
        storage_root=storage_root,
        interval_seconds=60,
        batch_size=100,
    )
    yield worker
    worker.shutdown()


class TestSyncedObjectInfo:
    """Round-trip serialization of the per-object sync bookkeeping."""

    def test_to_dict(self):
        info = SyncedObjectInfo(
            last_synced_at=1234567890.0,
            remote_etag="abc123",
            source="remote",
        )
        result = info.to_dict()
        assert result["last_synced_at"] == 1234567890.0
        assert result["remote_etag"] == "abc123"
        assert result["source"] == "remote"

    def test_from_dict(self):
        data = {
            "last_synced_at": 9876543210.0,
            "remote_etag": "def456",
            "source": "local",
        }
        info = SyncedObjectInfo.from_dict(data)
        assert info.last_synced_at == 9876543210.0
        assert info.remote_etag == "def456"
        assert info.source == "local"


class TestSyncState:
    """Round-trip serialization of the whole per-bucket sync state."""

    def test_to_dict(self):
        state = SyncState(
            synced_objects={
                "test.txt": SyncedObjectInfo(
                    last_synced_at=1000.0,
                    remote_etag="etag1",
                    source="remote",
                )
            },
            last_full_sync=2000.0,
        )
        result = state.to_dict()
        assert "test.txt" in result["synced_objects"]
        assert result["synced_objects"]["test.txt"]["remote_etag"] == "etag1"
        assert result["last_full_sync"] == 2000.0

    def test_from_dict(self):
        data = {
            "synced_objects": {
                "file.txt": {
                    "last_synced_at": 3000.0,
                    "remote_etag": "etag2",
                    "source": "remote",
                }
            },
            "last_full_sync": 4000.0,
        }
        state = SyncState.from_dict(data)
        assert "file.txt" in state.synced_objects
        assert state.synced_objects["file.txt"].remote_etag == "etag2"
        assert state.last_full_sync == 4000.0

    def test_from_dict_empty(self):
        # Missing keys must default, not raise.
        state = SyncState.from_dict({})
        assert state.synced_objects == {}
        assert state.last_full_sync is None


class TestSiteSyncStats:
    """Stats counters serialize with all fields present."""

    def test_to_dict(self):
        stats = SiteSyncStats(
            last_sync_at=1234567890.0,
            objects_pulled=10,
            objects_skipped=5,
            conflicts_resolved=2,
            deletions_applied=1,
            errors=0,
        )
        result = stats.to_dict()
        assert result["objects_pulled"] == 10
        assert result["objects_skipped"] == 5
        assert result["conflicts_resolved"] == 2
        assert result["deletions_applied"] == 1
        assert result["errors"] == 0


class TestRemoteObjectMeta:
    """Parsing of a raw S3 ListObjectsV2 entry."""

    def test_from_s3_object(self):
        obj = {
            "Key": "test/file.txt",
            "Size": 1024,
            "LastModified": datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            "ETag": '"abc123def456"',
        }
        meta = RemoteObjectMeta.from_s3_object(obj)
        assert meta.key == "test/file.txt"
        assert meta.size == 1024
        assert meta.last_modified == datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        # ETag quoting must be stripped.
        assert meta.etag == "abc123def456"


class TestReplicationRuleBidirectional:
    """New bidirectional fields on ReplicationRule (de)serialize correctly."""

    def test_rule_with_bidirectional_mode(self):
        rule = ReplicationRule(
            bucket_name="sync-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            enabled=True,
            mode=REPLICATION_MODE_BIDIRECTIONAL,
            sync_deletions=True,
        )
        assert rule.mode == REPLICATION_MODE_BIDIRECTIONAL
        assert rule.sync_deletions is True
        assert rule.last_pull_at is None

    def test_rule_to_dict_includes_new_fields(self):
        rule = ReplicationRule(
            bucket_name="sync-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_BIDIRECTIONAL,
            sync_deletions=False,
            last_pull_at=1234567890.0,
        )
        result = rule.to_dict()
        assert result["mode"] == REPLICATION_MODE_BIDIRECTIONAL
        assert result["sync_deletions"] is False
        assert result["last_pull_at"] == 1234567890.0

    def test_rule_from_dict_with_new_fields(self):
        data = {
            "bucket_name": "sync-bucket",
            "target_connection_id": "test-conn",
            "target_bucket": "remote-bucket",
            "mode": REPLICATION_MODE_BIDIRECTIONAL,
            "sync_deletions": False,
            "last_pull_at": 1234567890.0,
        }
        rule = ReplicationRule.from_dict(data)
        assert rule.mode == REPLICATION_MODE_BIDIRECTIONAL
        assert rule.sync_deletions is False
        assert rule.last_pull_at == 1234567890.0

    def test_rule_from_dict_defaults_new_fields(self):
        # Rules persisted before the bidirectional feature must load with
        # backward-compatible defaults.
        data = {
            "bucket_name": "sync-bucket",
            "target_connection_id": "test-conn",
            "target_bucket": "remote-bucket",
        }
        rule = ReplicationRule.from_dict(data)
        assert rule.sync_deletions is True
        assert rule.last_pull_at is None


class TestSiteSyncWorker:
    """Unit tests for worker lifecycle, gating, and conflict resolution."""

    def test_start_and_shutdown(self, site_sync_worker):
        site_sync_worker.start()
        assert site_sync_worker._sync_thread is not None
        assert site_sync_worker._sync_thread.is_alive()
        site_sync_worker.shutdown()
        assert not site_sync_worker._sync_thread.is_alive()

    def test_trigger_sync_no_rule(self, site_sync_worker):
        result = site_sync_worker.trigger_sync("nonexistent-bucket")
        assert result is None

    def test_trigger_sync_wrong_mode(self, site_sync_worker, replication_manager):
        # Only bidirectional rules are eligible for pull sync.
        rule = ReplicationRule(
            bucket_name="new-only-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_NEW_ONLY,
            enabled=True,
        )
        replication_manager.set_rule(rule)
        result = site_sync_worker.trigger_sync("new-only-bucket")
        assert result is None

    def test_trigger_sync_disabled_rule(self, site_sync_worker, replication_manager):
        rule = ReplicationRule(
            bucket_name="disabled-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_BIDIRECTIONAL,
            enabled=False,
        )
        replication_manager.set_rule(rule)
        result = site_sync_worker.trigger_sync("disabled-bucket")
        assert result is None

    def test_get_stats_no_sync(self, site_sync_worker):
        stats = site_sync_worker.get_stats("nonexistent")
        assert stats is None

    def test_resolve_conflict_remote_newer(self, site_sync_worker):
        local_meta = MagicMock()
        local_meta.last_modified = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        local_meta.etag = "local123"

        remote_meta = RemoteObjectMeta(
            key="test.txt",
            size=100,
            last_modified=datetime(2025, 1, 2, 12, 0, 0, tzinfo=timezone.utc),
            etag="remote456",
        )

        result = site_sync_worker._resolve_conflict(local_meta, remote_meta)
        assert result == "pull"

    def test_resolve_conflict_local_newer(self, site_sync_worker):
        local_meta = MagicMock()
        local_meta.last_modified = datetime(2025, 1, 2, 12, 0, 0, tzinfo=timezone.utc)
        local_meta.etag = "local123"

        remote_meta = RemoteObjectMeta(
            key="test.txt",
            size=100,
            last_modified=datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
            etag="remote456",
        )

        result = site_sync_worker._resolve_conflict(local_meta, remote_meta)
        assert result == "keep"

    def test_resolve_conflict_same_time_same_etag(self, site_sync_worker):
        # Within the skew window and identical ETags -> nothing to do.
        ts = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        local_meta = MagicMock()
        local_meta.last_modified = ts
        local_meta.etag = "same123"

        remote_meta = RemoteObjectMeta(
            key="test.txt",
            size=100,
            last_modified=ts,
            etag="same123",
        )

        result = site_sync_worker._resolve_conflict(local_meta, remote_meta)
        assert result == "skip"

    def test_resolve_conflict_same_time_different_etag(self, site_sync_worker):
        # Within the skew window the lexicographically larger ETag wins.
        ts = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        local_meta = MagicMock()
        local_meta.last_modified = ts
        local_meta.etag = "aaa"

        remote_meta = RemoteObjectMeta(
            key="test.txt",
            size=100,
            last_modified=ts,
            etag="zzz",
        )

        result = site_sync_worker._resolve_conflict(local_meta, remote_meta)
        assert result == "pull"

    def test_sync_state_persistence(self, site_sync_worker, tmp_path):
        bucket_name = "test-bucket"
        state = SyncState(
            synced_objects={
                "file1.txt": SyncedObjectInfo(
                    last_synced_at=time.time(),
                    remote_etag="etag1",
                    source="remote",
                )
            },
            last_full_sync=time.time(),
        )

        site_sync_worker._save_sync_state(bucket_name, state)

        loaded = site_sync_worker._load_sync_state(bucket_name)
        assert "file1.txt" in loaded.synced_objects
        assert loaded.synced_objects["file1.txt"].remote_etag == "etag1"

    def test_load_sync_state_nonexistent(self, site_sync_worker):
        state = site_sync_worker._load_sync_state("nonexistent-bucket")
        assert state.synced_objects == {}
        assert state.last_full_sync is None

    @patch("app.site_sync._create_sync_client")
    def test_list_remote_objects(self, mock_create_client, site_sync_worker, connections, replication_manager):
        # Mock the paginator so no network is touched.
        mock_client = MagicMock()
        mock_paginator = MagicMock()
        mock_paginator.paginate.return_value = [
            {
                "Contents": [
                    {
                        "Key": "file1.txt",
                        "Size": 100,
                        "LastModified": datetime(2025, 1, 1, tzinfo=timezone.utc),
                        "ETag": '"etag1"',
                    },
                    {
                        "Key": "file2.txt",
                        "Size": 200,
                        "LastModified": datetime(2025, 1, 2, tzinfo=timezone.utc),
                        "ETag": '"etag2"',
                    },
                ]
            }
        ]
        mock_client.get_paginator.return_value = mock_paginator
        mock_create_client.return_value = mock_client

        rule = ReplicationRule(
            bucket_name="local-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_BIDIRECTIONAL,
        )
        conn = connections.get("test-conn")

        result = site_sync_worker._list_remote_objects(rule, conn)

        assert "file1.txt" in result
        assert "file2.txt" in result
        assert result["file1.txt"].size == 100
        assert result["file2.txt"].size == 200

    def test_list_local_objects(self, site_sync_worker, storage):
        storage.create_bucket("test-bucket")
        storage.put_object("test-bucket", "file1.txt", io.BytesIO(b"content1"))
        storage.put_object("test-bucket", "file2.txt", io.BytesIO(b"content2"))

        result = site_sync_worker._list_local_objects("test-bucket")

        assert "file1.txt" in result
        assert "file2.txt" in result

    @patch("app.site_sync._create_sync_client")
    def test_sync_bucket_connection_not_found(self, mock_create_client, site_sync_worker, replication_manager):
        rule = ReplicationRule(
            bucket_name="test-bucket",
            target_connection_id="missing-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_BIDIRECTIONAL,
            enabled=True,
        )
        replication_manager.set_rule(rule)

        stats = site_sync_worker._sync_bucket(rule)
        assert stats.errors == 1


class TestSiteSyncIntegration:
    """End-to-end pull cycle against a fully mocked remote."""

    @patch("app.site_sync._create_sync_client")
    def test_full_sync_cycle(self, mock_create_client, site_sync_worker, storage, connections, replication_manager):
        storage.create_bucket("sync-bucket")
        storage.put_object("sync-bucket", "local-only.txt", io.BytesIO(b"local content"))

        mock_client = MagicMock()
        mock_paginator = MagicMock()
        mock_paginator.paginate.return_value = [
            {
                "Contents": [
                    {
                        "Key": "remote-only.txt",
                        "Size": 100,
                        "LastModified": datetime(2025, 1, 15, tzinfo=timezone.utc),
                        "ETag": '"remoteetag"',
                    },
                ]
            }
        ]
        mock_client.get_paginator.return_value = mock_paginator
        mock_client.head_object.return_value = {"Metadata": {}}

        # download_file side effect writes the "remote" bytes locally.
        def mock_download(bucket, key, path):
            Path(path).write_bytes(b"remote content")

        mock_client.download_file.side_effect = mock_download
        mock_create_client.return_value = mock_client

        rule = ReplicationRule(
            bucket_name="sync-bucket",
            target_connection_id="test-conn",
            target_bucket="remote-bucket",
            mode=REPLICATION_MODE_BIDIRECTIONAL,
            enabled=True,
        )
        replication_manager.set_rule(rule)

        stats = site_sync_worker._sync_bucket(rule)

        assert stats.objects_pulled == 1
        assert stats.errors == 0

        # Local-only objects survive; the remote object now exists locally.
        objects = site_sync_worker._list_local_objects("sync-bucket")
        assert "local-only.txt" in objects
        assert "remote-only.txt" in objects