Move performance-critical Python functions to Rust: streaming I/O, multipart assembly, and AES-256-GCM encryption

This commit is contained in:
2026-02-27 22:55:20 +08:00
parent d4657c389d
commit 5536330aeb
8 changed files with 816 additions and 46 deletions

View File

@@ -19,6 +19,13 @@ from cryptography.hazmat.primitives import hashes
if sys.platform != "win32":
import fcntl
try:
import myfsio_core as _rc
_HAS_RUST = True
except ImportError:
_rc = None
_HAS_RUST = False
logger = logging.getLogger(__name__)
@@ -338,6 +345,69 @@ class StreamingEncryptor:
output.seek(0)
return output
def encrypt_file(self, input_path: str, output_path: str) -> EncryptionMetadata:
data_key, encrypted_data_key = self.provider.generate_data_key()
base_nonce = secrets.token_bytes(12)
if _HAS_RUST:
_rc.encrypt_stream_chunked(
input_path, output_path, data_key, base_nonce, self.chunk_size
)
else:
with open(input_path, "rb") as stream:
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
out.write(b"\x00\x00\x00\x00")
chunk_index = 0
while True:
chunk = stream.read(self.chunk_size)
if not chunk:
break
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
encrypted_chunk = aesgcm.encrypt(chunk_nonce, chunk, None)
out.write(len(encrypted_chunk).to_bytes(self.HEADER_SIZE, "big"))
out.write(encrypted_chunk)
chunk_index += 1
out.seek(0)
out.write(chunk_index.to_bytes(4, "big"))
return EncryptionMetadata(
algorithm="AES256",
key_id=self.provider.KEY_ID if hasattr(self.provider, "KEY_ID") else "local",
nonce=base_nonce,
encrypted_data_key=encrypted_data_key,
)
def decrypt_file(self, input_path: str, output_path: str,
metadata: EncryptionMetadata) -> None:
data_key = self.provider.decrypt_data_key(metadata.encrypted_data_key, metadata.key_id)
base_nonce = metadata.nonce
if _HAS_RUST:
_rc.decrypt_stream_chunked(input_path, output_path, data_key, base_nonce)
else:
with open(input_path, "rb") as stream:
chunk_count_bytes = stream.read(4)
if len(chunk_count_bytes) < 4:
raise EncryptionError("Invalid encrypted stream: missing header")
chunk_count = int.from_bytes(chunk_count_bytes, "big")
aesgcm = AESGCM(data_key)
with open(output_path, "wb") as out:
for chunk_index in range(chunk_count):
size_bytes = stream.read(self.HEADER_SIZE)
if len(size_bytes) < self.HEADER_SIZE:
raise EncryptionError(f"Invalid encrypted stream: truncated at chunk {chunk_index}")
chunk_size = int.from_bytes(size_bytes, "big")
encrypted_chunk = stream.read(chunk_size)
if len(encrypted_chunk) < chunk_size:
raise EncryptionError(f"Invalid encrypted stream: incomplete chunk {chunk_index}")
chunk_nonce = self._derive_chunk_nonce(base_nonce, chunk_index)
try:
decrypted_chunk = aesgcm.decrypt(chunk_nonce, encrypted_chunk, None)
out.write(decrypted_chunk)
except Exception as exc:
raise EncryptionError(f"Failed to decrypt chunk {chunk_index}: {exc}") from exc
class EncryptionManager:
"""Manages encryption providers and operations."""

View File

@@ -841,32 +841,61 @@ class ObjectStorage:
tmp_dir = self._system_root_path() / self.SYSTEM_TMP_DIR
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try:
with tmp_path.open("wb") as target:
shutil.copyfileobj(stream, target)
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
if _HAS_RUST:
tmp_path = None
try:
tmp_path_str, etag, new_size = _rc.stream_to_file_with_md5(
stream, str(tmp_dir)
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
tmp_path = Path(tmp_path_str)
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
shutil.move(str(tmp_path), str(destination))
finally:
if tmp_path:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
else:
tmp_path = tmp_dir / f"{uuid.uuid4().hex}.tmp"
try:
with tmp_path.open("wb") as target:
shutil.copyfileobj(stream, target)
new_size = tmp_path.stat().st_size
size_delta = new_size - existing_size
object_delta = 0 if is_overwrite else 1
if enforce_quota:
quota_check = self.check_quota(
bucket_name,
additional_bytes=max(0, size_delta),
additional_objects=object_delta,
)
if not quota_check["allowed"]:
raise QuotaExceededError(
quota_check["message"] or "Quota exceeded",
quota_check["quota"],
quota_check["usage"],
)
if _HAS_RUST:
etag = _rc.md5_file(str(tmp_path))
else:
checksum = hashlib.md5()
with tmp_path.open("rb") as f:
while True:
@@ -876,13 +905,12 @@ class ObjectStorage:
checksum.update(chunk)
etag = checksum.hexdigest()
shutil.move(str(tmp_path), str(destination))
finally:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
shutil.move(str(tmp_path), str(destination))
finally:
try:
tmp_path.unlink(missing_ok=True)
except OSError:
pass
stat = destination.stat()
@@ -1702,19 +1730,29 @@ class ObjectStorage:
if versioning_enabled and destination.exists():
archived_version_size = destination.stat().st_size
self._archive_current_version(bucket_id, safe_key, reason="overwrite")
checksum = hashlib.md5()
with destination.open("wb") as target:
if _HAS_RUST:
part_paths = []
for _, record in validated:
part_path = upload_root / record["filename"]
if not part_path.exists():
pp = upload_root / record["filename"]
if not pp.exists():
raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk:
while True:
data = chunk.read(1024 * 1024)
if not data:
break
checksum.update(data)
target.write(data)
part_paths.append(str(pp))
checksum_hex = _rc.assemble_parts_with_md5(part_paths, str(destination))
else:
checksum = hashlib.md5()
with destination.open("wb") as target:
for _, record in validated:
part_path = upload_root / record["filename"]
if not part_path.exists():
raise StorageError(f"Missing part file {record['filename']}")
with part_path.open("rb") as chunk:
while True:
data = chunk.read(1024 * 1024)
if not data:
break
checksum.update(data)
target.write(data)
checksum_hex = checksum.hexdigest()
except BlockingIOError:
raise StorageError("Another upload to this key is in progress")
@@ -1729,7 +1767,7 @@ class ObjectStorage:
)
stat = destination.stat()
etag = checksum.hexdigest()
etag = checksum_hex
metadata = manifest.get("metadata")
internal_meta = {"__etag__": etag, "__size__": str(stat.st_size)}