bitmind/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@
# DEALINGS IN THE SOFTWARE.


__version__ = "2.2.9"
__version__ = "2.2.10"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
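For reference, the version bump feeds the spec-version encoding whose first term is visible above. A minimal sketch of that computation, assuming the truncated remainder follows the standard Bittensor 1000/10/1 weighting (the diff cuts off after the first term, so the last two terms are an assumption):

# Assumed continuation of the truncated __spec_version__ expression.
__version__ = "2.2.10"
version_split = __version__.split(".")
__spec_version__ = (
    (1000 * int(version_split[0]))
    + (10 * int(version_split[1]))
    + (1 * int(version_split[2]))
)
print(__spec_version__)  # 2.2.10 -> 2000 + 20 + 10 = 2030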
bitmind/protocol.py (126 changes: 56 additions & 70 deletions)
@@ -31,6 +31,7 @@

from bitmind.validator.config import TARGET_IMAGE_SIZE
from bitmind.utils.image_transforms import get_base_transforms

base_transforms = get_base_transforms(TARGET_IMAGE_SIZE)


@@ -49,16 +50,17 @@
# predictions = dendrite.query( ImageSynapse( images = b64_images ) )
# assert len(predictions) == len(b64_images)


def prepare_synapse(input_data, modality):
if isinstance(input_data, torch.Tensor):
input_data = transforms.ToPILImage()(input_data.cpu().detach())
if isinstance(input_data, list) and isinstance(input_data[0], torch.Tensor):
for i, img in enumerate(input_data):
input_data[i] = transforms.ToPILImage()(img.cpu().detach())

-if modality == 'image':
+if modality == "image":
return prepare_image_synapse(input_data)
-elif modality == 'video':
+elif modality == "video":
return prepare_video_synapse(input_data)
else:
raise NotImplementedError(f"Unsupported modality: {modality}")
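As a usage sketch of the dispatcher above (the call site is hypothetical; prepare_synapse itself handles the tensor-to-PIL conversion):

import torch

# Single image tensor -> ImageSynapse
image = torch.rand(3, 256, 256)
synapse = prepare_synapse(image, modality="image")

# List of frame tensors -> VideoSynapse
frames = [torch.rand(3, 256, 256) for _ in range(8)]
synapse = prepare_synapse(frames, modality="video")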
@@ -80,32 +82,37 @@ def prepare_image_synapse(image: Image):
return ImageSynapse(image=b64_encoded_image)


-class ImageSynapse(bt.Synapse):
+def prepare_video_synapse(frames: List[Image.Image]):
"""
-This protocol helps in handling image/prediction request and response communication between
-the miner and the validator.
+Prepares video frames for use with VideoSynapse object.

-Attributes:
-- image: a bas64 encoded images
-- prediction: a float indicating the probabilty that the image is AI generated/modified.
->.5 is considered generated/modified, <= 0.5 is considered real.
+Args:
+frames (List[Image.Image]): The list of video frames to be prepared.
+
+Returns:
+VideoSynapse: An instance of VideoSynapse containing the encoded frames and a default prediction value.
"""
+frame_bytes = []
+for frame in frames:
+buffer = BytesIO()
+frame.save(buffer, format="JPEG")
+frame_bytes.append(buffer.getvalue())

-testnet_label: int = -1 # for easier miner eval on testnet
combined_bytes = b"".join(frame_bytes)
compressed_data = zlib.compress(combined_bytes)
encoded_data = base64.b85encode(compressed_data).decode("utf-8")
return VideoSynapse(video=encoded_data)
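A quick usage sketch of the encoding pipeline above (per-frame JPEG bytes, concatenated, zlib-compressed, base85-encoded); the solid-color frames are placeholders:

from PIL import Image

frames = [Image.new("RGB", (224, 224), (i * 30, 0, 0)) for i in range(8)]
synapse = prepare_video_synapse(frames)
assert isinstance(synapse.video, str)  # transport-safe text payload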

-# Required request input, filled by sending dendrite caller.
-image: str = pydantic.Field(
-title="Image",
-description="A base64 encoded image",
-default="",
-frozen=False
-)

+class MediaSynapse(bt.Synapse):

+testnet_label: int = -1 # for miners to monitor their performance on testnet

prediction: Union[float, List[float]] = pydantic.Field(
title="Prediction",
description="Probability vector for [real, synthetic, semi-synthetic] classes.",
-default=[-1., -1., -1.],
-frozen=False
+default=[-1.0, -1.0, -1.0],
+frozen=False,
)

def deserialize(self) -> np.ndarray:
@@ -118,73 +125,46 @@ def deserialize(self) -> np.ndarray:
p = self.prediction
if isinstance(p, float):
if p == -1:
-return np.array([-1., -1., -1.])
+return np.array([-1.0, -1.0, -1.0])
else:
-return np.array([1-p, p, 0.])
+return np.array([1 - p, p, 0.0])
elif isinstance(p, list):
+if len(p) == 2:
+p += [0.0] # assume 2-dim responses are [real, fake]
return np.array(p)
else:
raise ValueError(f"Unsupported prediction type: {type(p)}")
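The backwards-compatibility branches are easiest to read with concrete values; a sketch, assuming plain keyword construction of the synapse:

MediaSynapse(prediction=0.75).deserialize()        # legacy float -> [0.25, 0.75, 0.0]
MediaSynapse(prediction=-1.0).deserialize()        # no response  -> [-1.0, -1.0, -1.0]
MediaSynapse(prediction=[0.3, 0.7]).deserialize()  # 2-dim padded -> [0.3, 0.7, 0.0]
MediaSynapse(prediction=[0.2, 0.5, 0.3]).deserialize()  # passes through unchanged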


-def prepare_video_synapse(frames: List[Image.Image]):
+class ImageSynapse(MediaSynapse):
"""
This protocol helps in handling image/prediction request and response communication between
the miner and the validator.

Attributes:
- image: a base64 encoded image
- prediction: a float indicating the probability that the image is AI generated/modified.
>.5 is considered generated/modified, <= 0.5 is considered real.
"""
-frame_bytes = []
-for frame in frames:
-buffer = BytesIO()
-frame.save(buffer, format="JPEG")
-frame_bytes.append(buffer.getvalue())
-
-combined_bytes = b''.join(frame_bytes)
-compressed_data = zlib.compress(combined_bytes)
-encoded_data = base64.b85encode(compressed_data).decode('utf-8')
-return VideoSynapse(video=encoded_data)
+image: str = pydantic.Field(
+title="Image", description="A base64 encoded image", default="", frozen=False
+)
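On the miner side, the expected interaction is to fill prediction before returning the synapse; a hypothetical handler sketch (the detector call and its registration are assumptions, not part of this PR):

def forward(synapse: ImageSynapse) -> ImageSynapse:
    # run a detector on the base64 payload in synapse.image ...
    synapse.prediction = [0.10, 0.85, 0.05]  # [real, synthetic, semi-synthetic]
    return synapse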


-class VideoSynapse(bt.Synapse):
+class VideoSynapse(MediaSynapse):
"""
-Naive initial VideoSynapse
-Better option would be to modify the Dendrite interface to allow multipart/form-data here:
-https://github.com/opentensor/bittensor/blob/master/bittensor/core/dendrite.py#L533
-Another higher lift option would be to look into Epistula or Fiber
+Naive initial VideoSynapse (Epistula version coming soon I promise)
"""

-testnet_label: int = -1 # for easier miner eval on testnet

# Required request input, filled by sending dendrite caller.
video: str = pydantic.Field(
title="Video",
description="A wildly inefficient means of sending video data",
default="",
-frozen=False
-)
-
-# Optional request output, filled by receiving axon.
-prediction: Union[float, List[float]] = pydantic.Field(
-title="Prediction",
-description="Probability vector for [real, synthetic, semi-synthetic] classes.",
-default=[-1., -1., -1.],
-frozen=False
+frozen=False,
)

-def deserialize(self) -> np.ndarray:
-"""
-Deserialize the output. Backwards compatible with binary float outputs.
-
-Returns:
-- float: The deserialized miner prediction probabilities
-"""
-p = self.prediction
-if isinstance(p, float):
-if p == -1:
-return np.array([-1., -1., -1.])
-else:
-return np.array([1-p, p, 0.])
-elif isinstance(p, list):
-return np.array(p)
-else:
-raise ValueError(f"Unsupported prediction type: {type(p)}")


def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:
"""
@@ -196,7 +176,7 @@ def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:
Returns:
List of torch tensors, each representing a frame from the video
"""
-compressed_data = base64.b85decode(synapse.video.encode('utf-8'))
+compressed_data = base64.b85decode(synapse.video.encode("utf-8"))
combined_bytes = zlib.decompress(compressed_data)

# Split the combined bytes into individual JPEG files
@@ -208,7 +188,10 @@ def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:
while current_pos < data_length:
# Find start of JPEG (FF D8)
while current_pos < data_length - 1:
-if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD8:
+if (
+combined_bytes[current_pos] == 0xFF
+and combined_bytes[current_pos + 1] == 0xD8
+):
break
current_pos += 1

@@ -219,7 +202,10 @@ def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:

# Find end of JPEG (FF D9)
while current_pos < data_length - 1:
-if combined_bytes[current_pos] == 0xFF and combined_bytes[current_pos + 1] == 0xD9:
+if (
+combined_bytes[current_pos] == 0xFF
+and combined_bytes[current_pos + 1] == 0xD9
+):
current_pos += 2
break
current_pos += 1
@@ -234,10 +220,10 @@ def decode_video_synapse(synapse: VideoSynapse) -> List[torch.Tensor]:
print(f"Error processing frame: {e}")
continue

-bt.logging.info('transforming video inputs')
+bt.logging.info("transforming video inputs")
frames = base_transforms(frames)

frames = torch.stack(frames, dim=0)
frames = frames.unsqueeze(0)
-print(f'decoded video into tensor with shape {frames.shape}')
+print(f"decoded video into tensor with shape {frames.shape}")
return frames
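Taken together, prepare_video_synapse and decode_video_synapse form an encode/decode pair; a hedged roundtrip sketch (frame count and colors are arbitrary placeholders):

from PIL import Image

frames = [Image.new("RGB", (128, 128), (0, g * 16, 0)) for g in range(8)]
synapse = prepare_video_synapse(frames)
video = decode_video_synapse(synapse)
# After base_transforms, stack, and unsqueeze: (1, num_frames, C, H, W),
# with H and W set by TARGET_IMAGE_SIZE.
print(video.shape)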