# This program combines and refactors the earlier code snippets: when the button is pressed, it takes a picture with the Pi Camera, classifies the object in the image, and reads the top label aloud.
import torch
from torchvision import models
from torchvision import transforms
from PIL import Image
from gtts import gTTS
from io import BytesIO
import pygame
from picamera import PiCamera
from gpiozero import Button
from time import sleep
# Load a pretrained image-classification network from torch hub and put it in
# inference mode (Module.eval() returns the module itself, so it can be chained).
# MobileNetV2 is the active choice. Other models that can be swapped in here:
#   'pytorch/vision:v0.4.2'       -> squeezenet1_0, shufflenet_v2_x1_0, resnext50_32x4d
#   'facebookresearch/WSL-Images' -> resnext101_32x8d_wsl, resnext101_32x16d_wsl
net = torch.hub.load(
    'pytorch/vision:v0.4.2',
    'mobilenet_v2',
    pretrained=True,
).eval()
# Class names, one per line; line i is the name for network output index i.
with open('imagenet_classes.txt') as class_file:
    labels = [entry.strip() for entry in class_file]
# Preprocessing applied to every captured image before it is fed to the network:
# resize the shorter side to 256 px, take the central 224x224 crop, convert to
# a float tensor, then normalise each RGB channel with the ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
# Hardware set-up: a push button on pin 17 triggers each capture, and the
# camera shows a live preview overlay while the program waits.
button = Button(17)
camera = PiCamera()
camera.resolution = 1920, 1080   # capture size in pixels (width, height)
camera.awb_mode = 'fluorescent'  # fixed white-balance preset
camera.start_preview(alpha=200)  # semi-transparent preview (alpha range 0-255)
def capture_image(button_dumb, camera_dumb):
    """Wait for a button press, then capture one JPEG frame.

    Args:
        button_dumb: button whose ``wait_for_press()`` blocks until it is pressed.
        camera_dumb: camera whose ``capture(stream, format='jpeg')`` writes a
            JPEG into ``stream``.

    Returns:
        The captured frame opened with ``PIL.Image.open`` from an in-memory buffer.
    """
    button_dumb.wait_for_press()
    jpeg_buffer = BytesIO()
    camera_dumb.capture(jpeg_buffer, format='jpeg')
    # Rewind so PIL reads the JPEG from its first byte.
    jpeg_buffer.seek(0)
    return Image.open(jpeg_buffer)
def image_classify(pil_image, transform, net, top_k=5, class_labels=None):
    """Classify one image, print its best-scoring labels and return the ranking.

    Args:
        pil_image: the image to classify; passed unchanged to ``transform``.
        transform: preprocessing callable that turns ``pil_image`` into a
            tensor (e.g. the module-level ``transform`` pipeline).
        net: model called on a batch of one image; its output is treated as
            per-class scores along ``dim=1``.
        top_k: how many of the highest-scoring labels to print (default 5).
        class_labels: sequence of label names indexed by class id. Defaults to
            the module-level ``labels`` list.

    Returns:
        Tensor of class indices sorted by descending score, with the same
        shape as the model output (one row for the single image);
        ``indices[0][0]`` is the top prediction.
    """
    names = labels if class_labels is None else class_labels
    # Add a leading batch dimension: the model is called on a batch of images.
    batch_t = torch.unsqueeze(transform(pil_image), 0)
    # Inference only: no_grad skips building the autograd graph, saving
    # memory and time on the Raspberry Pi.
    with torch.no_grad():
        out = net(batch_t)
    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    _, indices = torch.sort(out, descending=True)
    for idx in indices[0][:top_k]:
        print(names[int(idx)], percentage[idx].item())
    return indices
def play_label(text, lang='en'):
    """Speak ``text`` aloud with Google Text-to-Speech, played through pygame.

    Args:
        text: the phrase to speak (here, the predicted class label).
        lang: gTTS language code for the voice (default ``'en'``).

    Blocks until playback has finished, so a following call cannot cut the
    current phrase off.
    """
    # Safe to call repeatedly; the mixer is only initialised once.
    pygame.mixer.init()
    mp3_fp = BytesIO()
    # Pass ``lang`` by keyword: in gTTS >= 2.1 the second positional parameter
    # is ``tld``, so gTTS(text, 'en') would query a non-existent
    # translate.google.en host.
    tts = gTTS(text=text, lang=lang)
    tts.write_to_fp(mp3_fp)
    # Rewind so pygame reads the MP3 data from the start.
    mp3_fp.seek(0)
    pygame.mixer.music.load(mp3_fp)
    pygame.mixer.music.play()
    # play() returns immediately. Wait for the whole clip rather than a fixed
    # 2 s, which could truncate long labels or waste time on short ones.
    while pygame.mixer.music.get_busy():
        sleep(0.1)
# Main loop: each button press captures, classifies and announces one image.
# try/finally guarantees the preview is stopped and the camera released
# however the loop ends: Ctrl+C, or an unexpected error such as a failed
# text-to-speech request. Previously only Ctrl+C cleaned up, and the camera
# was never closed.
try:
    while True:
        pil_image = capture_image(button, camera)
        indices = image_classify(pil_image, transform, net)
        play_label(labels[indices[0][0]])
except KeyboardInterrupt:
    pass  # Ctrl+C is the normal way to stop the program.
finally:
    camera.stop_preview()
    camera.close()