diff --git a/README.md b/README.md index efe3d37..a611719 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,19 @@ +--- +runme: + id: 01HXQM38H0ZX9BSHH89DJJB5PA + version: v3 +--- + # Atlantis engine ![1715102710368](example/docs/image/Atlantis.png) ## Thành quả phát triển -* [X] Thiết kế API -* [X] Công khai toàn bộ mã nguồn -* [X] Tạo các code mẫu để áp dụng -* [ ] đang cập nhật tiếp... +- [x] Thiết kế API +- [x] Công khai toàn bộ mã nguồn +- [x] Tạo các code mẫu để áp dụng +- [ ] đang cập nhật tiếp... ## Thử nghiệm @@ -25,15 +31,15 @@ Giao tiếp có vai trò vô cùng quan trọng đối với mỗi chúng ta, n ### Công nghệ sử dụng -* **Thị giác máy tính** +- **Thị giác máy tính** Thị giác máy tính là gì? Thị giác máy tính là một công nghệ mà máy sử dụng để tự động nhận biết và mô tả hình ảnh một cách chính xác và hiệu quả. Ngày nay, các hệ thống máy tính có quyền truy cập vào khối lượng lớn hình ảnh và dữ liệu video bắt nguồn từ hoặc được tạo bằng điện thoại thông minh, camera giao thông, hệ thống bảo mật và các thiết bị khác. Ứng dụng thị giác máy tính sử dụng trí tuệ nhân tạo và máy học (AI/ML) để xử lý những dữ liệu này một cách chính xác nhằm xác định đối tượng và nhận diện khuôn mặt, cũng như phân loại, đề xuất, giám sát và phát hiện. -* **Máy học** +- **Máy học** Máy học (machine learning) là một lĩnh vực của trí tuệ nhân tạo (AI) mà trong đó máy tính được lập trình để tự động học và cải thiện từ dữ liệu mà nó nhận được. Thay vì chỉ dựa trên các quy tắc cụ thể được lập trình trước, máy học cho phép máy tính "học" thông qua việc phân tích dữ liệu và tìm ra các mẫu, xu hướng hoặc quy luật ẩn trong dữ liệu mà không cần được lập trình trực tiếp. -* **Giao diện chương trình ứng dụng** +- **Giao diện chương trình ứng dụng** Giao diện chương trình là gì? Giao diện chương trình – Application Programming Interface viết tắt là API là một trung gian phần mềm cho phép hai ứng dụng giao tiếp với nhau, có thể sử dụng cho web-based system, operating system, database system, computer hardware, hoặc software library. @@ -43,8 +49,9 @@ Giao diện chương trình là gì? 
Giao diện chương trình – Application #### Dưới đây là hướng dẫn train model cơ bản -```python -# collect_imgs.py +```python {"id":"01HXQM38GYX73YFXXRSQ2KJSEJ"} +# engine/train/collect_imgs.py + import os import cv2 DATA_DIR = './data' @@ -74,16 +81,18 @@ for j in range(number_of_classes): counter += 1 cap.release() cv2.destroyAllWindows() + + ``` -```python -#create_dataset.py +```python {"id":"01HXQM38GYX73YFXXRSQXQ5PWM"} +# engine/train/create_dataset.py + import os import pickle -import mediapipe as mp import cv2 -import matplotlib.pyplot as plt +import mediapipe as mp mp_hands = mp.solutions.hands mp_drawing = mp.solutions.drawing_utils @@ -98,8 +107,9 @@ labels = [] for dir_ in os.listdir(DATA_DIR): for img_path in os.listdir(os.path.join(DATA_DIR, dir_)): data_aux = [] - x_ = [] - y_ = [] + + xs = [] + ys = [] img = cv2.imread(os.path.join(DATA_DIR, dir_, img_path)) img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) @@ -107,65 +117,75 @@ for dir_ in os.listdir(DATA_DIR): results = hands.process(img_rgb) if results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks: - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - - x_.append(x) - y_.append(y) + for _ in range(len(hand_landmarks.landmark)): + xs.append(hand_landmarks.landmark[0].x) + ys.append(hand_landmarks.landmark[0].y) - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - data_aux.append(x - min(x_)) - data_aux.append(y - min(y_)) + for _ in range(len(hand_landmarks.landmark)): + data_aux.append(hand_landmarks.landmark[0].x - min(xs)) + data_aux.append(hand_landmarks.landmark[0].y - min(ys)) data.append(data_aux) labels.append(dir_) -f = open('data.pickle', 'wb') -pickle.dump({'data': data, 'labels': labels}, f) -f.close() +with open('data.pickle', 'wb') as f: + pickle.dump({'data': data, 'labels': labels}, f) + + ``` -```python -#train_classifier.py +```python {"id":"01HXQM38GYX73YFXXRSR0G097A"} +# engine/train/train_classifier.py + import pickle +from os import remove +import numpy as np from sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score -import numpy as np -from os import remove -data_dict = pickle.load(open('./data.pickle', 'rb')) +from sklearn.model_selection import train_test_split + +with open('./data.pickle', 'rb') as f: + data_dict = pickle.load(f) + data = np.asarray(data_dict['data']) labels = np.asarray(data_dict['labels']) -x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=42) + +x_train, x_test, y_train, y_test = train_test_split(data, + labels, + test_size=0.2, + shuffle=True, + stratify=labels, + random_state=42) + model = RandomForestClassifier() + model.fit(x_train, y_train) y_predict = model.predict(x_test) + score = accuracy_score(y_predict, y_test) -print('{}% of samples were classified correctly !'.format(score * 100)) +print(f'{score * 100}% of samples were classified correctly !') -f = open('model.p', 'wb') -pickle.dump({'model': model}, f) -f.close() +with open('../model/model.p', 'wb') as f: + pickle.dump({'model': model}, f) remove('./data.pickle') + ``` -```python -#inference_classifier.py +```python {"id":"01HXQM38GYX73YFXXRSRYS7KBR"} +# engine/train/inference_classifier.py + import pickle import cv2 import mediapipe as mp import numpy as np -import time -model_dict = 
pickle.load(open('./model.p', 'rb')) +with open('../model/model.p', 'rb') as f: + model_dict = pickle.load(f) model = model_dict['model'] cap = cv2.VideoCapture(0) @@ -174,18 +194,17 @@ mp_hands = mp.solutions.hands mp_drawing = mp.solutions.drawing_utils mp_drawing_styles = mp.solutions.drawing_styles -OUTPUT = [] +output: list[str] = [] hands = mp_hands.Hands(static_image_mode=True, min_detection_confidence=0.3) -labels_dict = {0:'ok',1:'xin chao',2:'tam biet'} -CHECK_FRAME = 0 +labels_dict = {0: 'ok', 1: 'xin chao', 2: 'tam biet'} +detected = False predicted_character = '' while True: - data_aux = [] - x_ = [] - y_ = [] + data_aux = [] + xs = [] + ys = [] ret, frame = cap.read() H, W, _ = frame.shape @@ -193,70 +212,64 @@ while True: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = hands.process(frame_rgb) - + if results.multi_hand_landmarks: + detected = True - CHECK_FRAME+=1 + if not detected: + continue - if CHECK_FRAME == 1: - CHECK_FRAME = 0 + detected = False - for hand_landmarks in results.multi_hand_landmarks: - mp_drawing.draw_landmarks( - frame, - hand_landmarks, - mp_hands.HAND_CONNECTIONS, - mp_drawing_styles.get_default_hand_landmarks_style(), - mp_drawing_styles.get_default_hand_connections_style()) + for hand_landmarks in results.multi_hand_landmarks: + mp_drawing.draw_landmarks( + frame, hand_landmarks, mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) - for hand_landmarks in results.multi_hand_landmarks: - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y + for hand_landmarks in results.multi_hand_landmarks: + for _ in range(len(hand_landmarks.landmark)): + xs.append(hand_landmarks.landmark[0].x) + ys.append(hand_landmarks.landmark[0].y) - x_.append(x) - y_.append(y) + for _ in range(len(hand_landmarks.landmark)): + data_aux.append(hand_landmarks.landmark[0].x - min(xs)) + data_aux.append(hand_landmarks.landmark[0].y - min(ys)) - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - data_aux.append(x - min(x_)) - data_aux.append(y - min(y_)) + x1 = int(min(xs) * W) - 10 + y1 = int(min(ys) * H) - 10 - x1 = int(min(x_) * W) - 10 - y1 = int(min(y_) * H) - 10 + x2 = int(max(xs) * W) - 10 + y2 = int(max(ys) * H) - 10 - x2 = int(max(x_) * W) - 10 - y2 = int(max(y_) * H) - 10 + prediction = model.predict([np.asarray(data_aux)]) - prediction = model.predict([np.asarray(data_aux)]) + print(prediction) - print(prediction) + predicted_character = labels_dict[int(prediction[0])] - predicted_character = labels_dict[int(prediction[0])] - - cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), 4) - cv2.putText(frame, predicted_character , (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 0), 3, - cv2.LINE_AA) + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), 4) + cv2.putText(frame, predicted_character, (x1, y1 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 0), 3, cv2.LINE_AA) if predicted_character != '': - if len(OUTPUT) == 0 or OUTPUT[-1] != predicted_character: - OUTPUT.append(predicted_character) - + if len(output) == 0 or output[-1] != predicted_character: + output.append(predicted_character) cv2.imshow('frame', frame) - if cv2.waitKey(25)==ord('q'): + if cv2.waitKey(25) == ord('q'): break -print(OUTPUT) +print(output) cap.release() cv2.destroyAllWindows() + ``` ## Giấy phép -``` +```md {"id":"01HXQM38GZ1HBF5NNQ999QT1D4"} MIT License Copyright (c) 2024 iotran207
@@ -276,9 +289,11 @@ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + ``` Phần cuối xin cảm ơn team MEDIAPIPE của google vì đã phát triển một framework thật tuyệt vời và [computervisioneng](https://github.com/computervisioneng) đã tạo nên một repo thật tuyệt vời để học hỏi. diff --git a/engine/Dockerfile b/engine/Dockerfile index 6448609..c3547a7 100644 --- a/engine/Dockerfile +++ b/engine/Dockerfile @@ -10,6 +10,6 @@ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt COPY ./api /code/api -COPY ./train/model.p /code/api/model.p +COPY ./model/model.p /code/model/model.p CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8345"] \ No newline at end of file diff --git a/engine/api/database.sqlite b/engine/api/database.sqlite index 9e006f3..7aec99c 100644 Binary files a/engine/api/database.sqlite and b/engine/api/database.sqlite differ diff --git a/engine/api/main.py b/engine/api/main.py index 0fda04b..1094573 100644 --- a/engine/api/main.py +++ b/engine/api/main.py @@ -1,16 +1,12 @@ -from os import listdir from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from routers.engine import router as engine_router app = FastAPI() origins = ['*'] -for file in listdir('./routers'): - if file.endswith('.py') and file != '__init__.py': - module = file.replace('.py', '') - exec(f'from routers.{module} import router') - exec(f'app.include_router(router)') +app.include_router(engine_router) app.add_middleware( CORSMiddleware, @@ -20,6 +16,7 @@ allow_headers=["*"], ) + @app.get("/") async def home(): - return {"message": "Hello World!"} \ No newline at end of file + return {"message": "Hello World!"} diff --git a/engine/api/routers/engine.py b/engine/api/routers/engine.py index 77e1150..868a637 100644 --- a/engine/api/routers/engine.py +++ b/engine/api/routers/engine.py @@ -1,19 +1,16 @@ -from fastapi import APIRouter, UploadFile -from asyncio import sleep as sleep_async -from uvicorn import run as run_server -from fastapi.responses import HTMLResponse -from cv2 import VideoCapture, cvtColor, COLOR_BGR2RGB, imencode, CAP_DSHOW +from os import mkdir +from os.path import exists as path_exists from pickle import load as pickle_load -from time import time -from pydantic import BaseModel from sqlite3 import connect as sqlite_connect -from random import SystemRandom -from string import ascii_uppercase, digits -from os.path import exists as path_exists -from os import mkdir +from time import time import mediapipe as mp import numpy as np +from api.utils import randomString +from cv2 import COLOR_BGR2RGB, VideoCapture, cvtColor +from fastapi import APIRouter, UploadFile +from fastapi.responses import JSONResponse +from pydantic import BaseModel DATABASE = sqlite_connect('database.sqlite') DATABASE_CURSOR = DATABASE.cursor() @@ -25,31 +22,34 @@ mp_drawing_styles = mp.solutions.drawing_styles hands = mp_hands.Hands(static_image_mode=True, min_detection_confidence=0.3) -model_dict = pickle_load(open('model.p', 'rb')) +with open('../model/model.p', 'rb') as f: + model_dict = pickle_load(f) model = model_dict['model'] -labels_dict = {0:'hi',1:'toi la',2:'hoang lan'} +labels_dict = {0: 'hi', 1: 'toi la', 
2: 'hoang lan'} + -class GetTokenModel(BaseModel): +class User(BaseModel): username: str password: str -async def VerifyUserToken(token: str): - if not DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE token = '{token}'").fetchone(): - return False - else: - return True -async def HandleVideo(VIDEO_PATH: str): - video_capture = VideoCapture(VIDEO_PATH) +def verifyUserToken(token: str) -> bool: + return not not DATABASE_CURSOR.execute( + "SELECT * FROM DATA_USER WHERE token = ?", (token, )).fetchone() + + +async def handleVideo(video_path: str) -> dict[str, list[str] | str]: + video_capture = VideoCapture(video_path) predicted_character = '' - OUTPUT = [] - CHECK_FRAME = 0 + output: list[str] = [] + detected = True + try: while True: - data_aux = [] - x_ = [] - y_ = [] + data_aux: list[int] = [] + xs: list[int] = [] + ys: list[int] = [] ret, frame = video_capture.read() if not ret: break @@ -59,103 +59,142 @@ async def HandleVideo(VIDEO_PATH: str): frame_rgb = cvtColor(frame, COLOR_BGR2RGB) results = hands.process(frame_rgb) if results.multi_hand_landmarks: - CHECK_FRAME+=1 - if CHECK_FRAME == 1: - CHECK_FRAME = 0 - for hand_landmarks in results.multi_hand_landmarks: - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[0].x - y = hand_landmarks.landmark[0].y - x_.append(x) - y_.append(y) - - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[0].x - y = hand_landmarks.landmark[0].y - data_aux.append(x - min(x_)) - data_aux.append(y - min(y_)) - - x1 = int(min(x_) * W) - 10 - y1 = int(min(y_) * H) - 10 - x2 = int(max(x_) * W) - 10 - y2 = int(max(y_) * H) - 10 - - prediction = model.predict([np.asarray(data_aux)]) - predicted_character = labels_dict[int(prediction[0])] - if OUTPUT == []: - OUTPUT.append(predicted_character) - print(predicted_character) - else: - if OUTPUT[-1] != predicted_character: - OUTPUT.append(predicted_character) - print(predicted_character) - - return {"data": OUTPUT} + detected = True + + if not detected: + continue + + detected = False + for hand_landmarks in results.multi_hand_landmarks: + for _ in range(len(hand_landmarks.landmark)): + xs.append(hand_landmarks.landmark[0].x) + ys.append(hand_landmarks.landmark[0].y) + + for _ in range(len(hand_landmarks.landmark)): + data_aux.append(hand_landmarks.landmark[0].x - min(xs)) + data_aux.append(hand_landmarks.landmark[0].y - min(ys)) + + # unused variables + # x1 = int(min(x_) * W) - 10 + # y1 = int(min(y_) * H) - 10 + # x2 = int(max(x_) * W) - 10 + # y2 = int(max(y_) * H) - 10 + + prediction = model.predict([np.asarray(data_aux)]) + predicted_character = labels_dict[int(prediction[0])] + if output == []: + output.append(predicted_character) + print(predicted_character) + continue + if output[-1] != predicted_character: + output.append(predicted_character) + print(predicted_character) + + return {"data": output} except Exception as e: return {"error": str(e)} + @router.post("/video") -async def UploadVideo(file: UploadFile = UploadFile,token: str = None): - if not DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE token = '{token}'").fetchone(): +async def uploadVideo(file: UploadFile | None = None, + token: str | None = None): + if file is None: + return {"error": "No file provided."} + if token is None: + return {"error": "No token provided."} + + if not DATABASE_CURSOR.execute("SELECT * FROM DATA_USER WHERE token = ?", + (token, )).fetchone(): return {"error": "Invalid token."} - else: - TimeNow = time() - with open(f"data/temp/{TimeNow}.mp4", "wb") as 
buffer: - buffer.write(file.file.read()) - return await HandleVideo(f"data/temp/{TimeNow}.mp4") + timestamp = time() + with open(f"data/temp/{timestamp}.mp4", "wb") as buffer: + buffer.write(file.file.read()) + + return await handleVideo(f"data/temp/{timestamp}.mp4") -def GenToken(): - return ''.join(SystemRandom().choice(ascii_uppercase + digits) for _ in range(32)) @router.post("/regentoken") -async def GetToken(data: GetTokenModel): - token = GenToken() - while DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE token = '{token}'").fetchone(): - token = GenToken() - DATABASE_CURSOR.execute(f"UPDATE DATA_USER SET token = '{token}' WHERE username = '{data.username}' AND password = '{data.password}'") +async def getToken(user: User): + token = randomString(32) + + while DATABASE_CURSOR.execute("SELECT * FROM DATA_USER WHERE token = ?", + (token, )).fetchone(): + token = randomString(32) + + DATABASE_CURSOR.execute( + "UPDATE DATA_USER SET token = ? WHERE username = ? AND password = ?", + (token, user.username, user.password)) DATABASE.commit() - - DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE username = '{data.username}' AND password = '{data.password}'") + + DATABASE_CURSOR.execute( + "SELECT * FROM DATA_USER WHERE username = ? AND password = ?", + (user.username, user.password)) + result = DATABASE_CURSOR.fetchone() - if result: - return {"token": result[2]} - else: - return {"error": "Invalid username or password, you can go to /register to create a new account."} + if not result: + return { + "error": + "Invalid username or password, you can go to /register to create a new account." + } + + return {"token": result[2]} + @router.post("/register") -async def Register(username: str, password: str): +async def register(username: str, password: str): try: - DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE username = '{username}'") - result = DATABASE_CURSOR.fetchone() - if result: + DATABASE_CURSOR.execute("SELECT * FROM DATA_USER WHERE username = ?", + (username, )) + if DATABASE_CURSOR.fetchone(): # If the username already exists return {"error": "Username already exists."} - else: - token = GenToken() - while DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE token = '{token}'").fetchone(): - token = GenToken() - DATABASE_CURSOR.execute(f"INSERT INTO DATA_USER (username, password,token) VALUES ('{username}', '{password}', '{token}')") - DATABASE.commit() - return {"token": f"{token}"} + + token = randomString(32) + while DATABASE_CURSOR.execute( + "SELECT * FROM DATA_USER WHERE token = ?", + (token, )).fetchone(): + # If the token already exists + token = randomString(32) + + DATABASE_CURSOR.execute( + "INSERT INTO DATA_USER (username, password, token) VALUES (?, ?, ?)", + (username, password, token)) + + DATABASE.commit() + + return {"token": token} except Exception as e: - return {"error": str(e)} - + return JSONResponse(status_code=500, content={"error": str(e)}) + + @router.post("/upload") -async def UploadFile(file: UploadFile = UploadFile,token: str = None): - if VerifyUserToken(token)==False: +async def uploadFile(file: UploadFile | None = None, token: str | None = None): + if file is None: + return {"error": "No file provided."} + if token is None: + return {"error": "No token provided."} + + if verifyUserToken(token) == False: return {"error": "Invalid token."} - else: - USERNAME = DATABASE_CURSOR.execute(f"SELECT * FROM DATA_USER WHERE token = '{token}'").fetchone()[0] - if not path_exists(f"data/user/{USERNAME}"): - mkdir(f"data/user/{USERNAME}") - - 
with open(f"data/user/{USERNAME}/{file.filename}", "wb") as buffer: - buffer.write(file.file.read()) - return {"status": "success", "filename": file.filename, "username": USERNAME,"time_upload": time(),"size": file.file.seek(0,2)} - + + username = DATABASE_CURSOR.execute( + "SELECT * FROM DATA_USER WHERE token = ?", (token, )).fetchone()[0] + + if not path_exists(f"data/user/{username}"): + mkdir(f"data/user/{username}") + + with open(f"data/user/{username}/{file.filename}", "wb") as buffer: + buffer.write(file.file.read()) + + return { + "status": "success", + "filename": file.filename, + "username": username, + "time_upload": time(), + "size": file.file.seek(0, 2) + } + + @router.post("/customtrain") -async def train(token:str): - return {"status":"will be updated soon."} - - - +async def train(token: str): + return {"status": "will be updated soon."} diff --git a/engine/api/utils.py b/engine/api/utils.py new file mode 100644 index 0000000..51a09d5 --- /dev/null +++ b/engine/api/utils.py @@ -0,0 +1,8 @@ +import random +import string + + +def randomString(length): + return ''.join( + random.choice(string.ascii_letters + string.digits) + for i in range(length)) diff --git a/engine/api/model.p b/engine/model/model.p similarity index 100% rename from engine/api/model.p rename to engine/model/model.p diff --git a/engine/train/create_dataset.py b/engine/train/create_dataset.py index 5e113d0..3141bdf 100644 --- a/engine/train/create_dataset.py +++ b/engine/train/create_dataset.py @@ -1,9 +1,8 @@ import os import pickle -import mediapipe as mp import cv2 -import matplotlib.pyplot as plt +import mediapipe as mp mp_hands = mp.solutions.hands mp_drawing = mp.solutions.drawing_utils @@ -19,8 +18,8 @@ for img_path in os.listdir(os.path.join(DATA_DIR, dir_)): data_aux = [] - x_ = [] - y_ = [] + xs = [] + ys = [] img = cv2.imread(os.path.join(DATA_DIR, dir_, img_path)) img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) @@ -28,22 +27,16 @@ results = hands.process(img_rgb) if results.multi_hand_landmarks: for hand_landmarks in results.multi_hand_landmarks: - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - - x_.append(x) - y_.append(y) + for _ in range(len(hand_landmarks.landmark)): + xs.append(hand_landmarks.landmark[0].x) + ys.append(hand_landmarks.landmark[0].y) - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - data_aux.append(x - min(x_)) - data_aux.append(y - min(y_)) + for _ in range(len(hand_landmarks.landmark)): + data_aux.append(hand_landmarks.landmark[0].x - min(xs)) + data_aux.append(hand_landmarks.landmark[0].y - min(ys)) data.append(data_aux) labels.append(dir_) -f = open('data.pickle', 'wb') -pickle.dump({'data': data, 'labels': labels}, f) -f.close() +with open('data.pickle', 'wb') as f: + pickle.dump({'data': data, 'labels': labels}, f) diff --git a/engine/train/inference_classifier.py b/engine/train/inference_classifier.py index 8ec4ae3..f94f5af 100644 --- a/engine/train/inference_classifier.py +++ b/engine/train/inference_classifier.py @@ -3,9 +3,9 @@ import cv2 import mediapipe as mp import numpy as np -import time -model_dict = pickle.load(open('./model.p', 'rb')) +with open('../model/model.p', 'rb') as f: + model_dict = pickle.load(f) model = model_dict['model'] cap = cv2.VideoCapture(0) @@ -14,17 +14,17 @@ mp_drawing = mp.solutions.drawing_utils mp_drawing_styles = mp.solutions.drawing_styles -OUTPUT = [] +output: list[str] = [] hands = 
mp_hands.Hands(static_image_mode=True, min_detection_confidence=0.3) -labels_dict = {0:'ok',1:'xin chao',2:'tam biet'} -CHECK_FRAME = 0 +labels_dict = {0: 'ok', 1: 'xin chao', 2: 'tam biet'} +detected = False predicted_character = '' while True: data_aux = [] - x_ = [] - y_ = [] + xs = [] + ys = [] ret, frame = cap.read() H, W, _ = frame.shape @@ -32,62 +32,55 @@ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = hands.process(frame_rgb) - - if results.multi_hand_landmarks: - CHECK_FRAME+=1 + if results.multi_hand_landmarks: + detected = True - if CHECK_FRAME == 1: - CHECK_FRAME = 0 + if not detected: + continue - for hand_landmarks in results.multi_hand_landmarks: - mp_drawing.draw_landmarks( - frame, - hand_landmarks, - mp_hands.HAND_CONNECTIONS, - mp_drawing_styles.get_default_hand_landmarks_style(), - mp_drawing_styles.get_default_hand_connections_style()) + detected = False - for hand_landmarks in results.multi_hand_landmarks: - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y + for hand_landmarks in results.multi_hand_landmarks: + mp_drawing.draw_landmarks( + frame, hand_landmarks, mp_hands.HAND_CONNECTIONS, + mp_drawing_styles.get_default_hand_landmarks_style(), + mp_drawing_styles.get_default_hand_connections_style()) - x_.append(x) - y_.append(y) + for hand_landmarks in results.multi_hand_landmarks: + for _ in range(len(hand_landmarks.landmark)): + xs.append(hand_landmarks.landmark[0].x) + ys.append(hand_landmarks.landmark[0].y) - for i in range(len(hand_landmarks.landmark)): - x = hand_landmarks.landmark[i].x - y = hand_landmarks.landmark[i].y - data_aux.append(x - min(x_)) - data_aux.append(y - min(y_)) + for _ in range(len(hand_landmarks.landmark)): + data_aux.append(hand_landmarks.landmark[0].x - min(xs)) + data_aux.append(hand_landmarks.landmark[0].y - min(ys)) - x1 = int(min(x_) * W) - 10 - y1 = int(min(y_) * H) - 10 + x1 = int(min(xs) * W) - 10 + y1 = int(min(ys) * H) - 10 - x2 = int(max(x_) * W) - 10 - y2 = int(max(y_) * H) - 10 + x2 = int(max(xs) * W) - 10 + y2 = int(max(ys) * H) - 10 - prediction = model.predict([np.asarray(data_aux)]) + prediction = model.predict([np.asarray(data_aux)]) - print(prediction) + print(prediction) - predicted_character = labels_dict[int(prediction[0])] + predicted_character = labels_dict[int(prediction[0])] - cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), 4) - cv2.putText(frame, predicted_character , (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 0), 3, - cv2.LINE_AA) + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), 4) + cv2.putText(frame, predicted_character, (x1, y1 - 10), + cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 0), 3, cv2.LINE_AA) if predicted_character != '': - if len(OUTPUT) == 0 or OUTPUT[-1] != predicted_character: - OUTPUT.append(predicted_character) - + if len(output) == 0 or output[-1] != predicted_character: + output.append(predicted_character) cv2.imshow('frame', frame) - if cv2.waitKey(25)==ord('q'): + if cv2.waitKey(25) == ord('q'): break -print(OUTPUT) +print(output) cap.release() cv2.destroyAllWindows() diff --git a/engine/train/model.p b/engine/train/model.p deleted file mode 100644 index dabb960..0000000 Binary files a/engine/train/model.p and /dev/null differ diff --git a/engine/train/train_classifier.py b/engine/train/train_classifier.py index a5d311b..2f384ac 100644 --- a/engine/train/train_classifier.py +++ b/engine/train/train_classifier.py @@ -1,18 +1,23 @@ import pickle +from os import remove +import numpy as np from 
sklearn.ensemble import RandomForestClassifier -from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score -import numpy as np -from os import remove - +from sklearn.model_selection import train_test_split -data_dict = pickle.load(open('./data.pickle', 'rb')) +with open('./data.pickle', 'rb') as f: + data_dict = pickle.load(f) data = np.asarray(data_dict['data']) labels = np.asarray(data_dict['labels']) -x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=42) +x_train, x_test, y_train, y_test = train_test_split(data, + labels, + test_size=0.2, + shuffle=True, + stratify=labels, + random_state=42) model = RandomForestClassifier() @@ -22,10 +27,9 @@ score = accuracy_score(y_predict, y_test) -print('{}% of samples were classified correctly !'.format(score * 100)) +print(f'{score * 100}% of samples were classified correctly !') -f = open('model.p', 'wb') -pickle.dump({'model': model}, f) -f.close() +with open('../model/model.p', 'wb') as f: + pickle.dump({'model': model}, f) remove('./data.pickle') diff --git a/example/Python/nomodel.py b/example/Python/nomodel.py index e17f867..0799bad 100644 --- a/example/Python/nomodel.py +++ b/example/Python/nomodel.py @@ -1,45 +1,57 @@ import cv2 -import numpy as np -import os -from matplotlib import pyplot as plt import mediapipe as mp -mp_hands = mp.solutions.holistic +mp_hands = mp.solutions.holistic mp_drawing = mp.solutions.drawing_utils cap = cv2.VideoCapture(0) + def mediapipe_detection(image, model): - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - image.flags.writeable = False - results = model.process(image) - image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image.flags.writeable = False + results = model.process(image) + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) return image, results + def draw_landmarks(image, results): - mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_hands.HAND_CONNECTIONS) - mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_hands.HAND_CONNECTIONS) + mp_drawing.draw_landmarks(image, results.left_hand_landmarks, + mp_hands.HAND_CONNECTIONS) + mp_drawing.draw_landmarks(image, results.right_hand_landmarks, + mp_hands.HAND_CONNECTIONS) + def draw_styled_landmarks(image, results): - mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_hands.HAND_CONNECTIONS, - mp_drawing.DrawingSpec(color=(121,22,76), thickness=2, circle_radius=4), - mp_drawing.DrawingSpec(color=(121,44,250), thickness=2, circle_radius=2) - ) - mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_hands.HAND_CONNECTIONS, - mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=4), - mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2) - ) - + mp_drawing.draw_landmarks( + image, results.left_hand_landmarks, mp_hands.HAND_CONNECTIONS, + mp_drawing.DrawingSpec(color=(121, 22, 76), + thickness=2, + circle_radius=4), + mp_drawing.DrawingSpec(color=(121, 44, 250), + thickness=2, + circle_radius=2)) + mp_drawing.draw_landmarks( + image, results.right_hand_landmarks, mp_hands.HAND_CONNECTIONS, + mp_drawing.DrawingSpec(color=(245, 117, 66), + thickness=2, + circle_radius=4), + mp_drawing.DrawingSpec(color=(245, 66, 230), + thickness=2, + circle_radius=2)) + + cap = cv2.VideoCapture(0) -with mp_hands.Holistic(min_detection_confidence=0.5, 
min_tracking_confidence=0.5) as holistic: +with mp_hands.Holistic(min_detection_confidence=0.5, + min_tracking_confidence=0.5) as holistic: while cap.isOpened(): ret, frame = cap.read() image, results = mediapipe_detection(frame, holistic) print(results) - + draw_styled_landmarks(image, results) cv2.imshow('OpenCV Feed', image) @@ -48,4 +60,3 @@ def draw_styled_landmarks(image, results): break cap.release() cv2.destroyAllWindows() - diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..8a94273 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +# Global options: + +[mypy] +disable_error_code = import \ No newline at end of file
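For quick manual testing of the API surface touched by this change, here is a minimal client-side sketch (not part of the repository): it registers a user via `/register`, then submits a clip to `/video` with the returned token. It assumes the service is reachable at `http://localhost:8345` (the port used in the Dockerfile `CMD`), that the engine router is mounted at the application root, and the username, password, and file name below are placeholders.

```python
# Hypothetical client sketch for the Atlantis engine API (not in the repo).
# Assumes the FastAPI app is running locally on port 8345, per the Dockerfile CMD.
import requests

BASE_URL = "http://localhost:8345"

# /register takes username and password as query parameters and returns a token.
resp = requests.post(
    f"{BASE_URL}/register",
    params={"username": "demo_user", "password": "demo_password"},
)
token = resp.json().get("token")

# /video expects the token as a query parameter and the clip as multipart form data.
with open("sample_sign.mp4", "rb") as clip:
    result = requests.post(
        f"{BASE_URL}/video",
        params={"token": token},
        files={"file": clip},
    )

# On success the response holds the recognized labels, e.g. {"data": ["hi", "toi la", ...]}.
print(result.json())
```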