diff --git a/train_steering_model.py b/train_steering_model.py index 421e2c3..d89510a 100755 --- a/train_steering_model.py +++ b/train_steering_model.py @@ -8,6 +8,7 @@ from keras.models import Sequential from keras.layers import Dense, Dropout, Flatten, Lambda, ELU from keras.layers.convolutional import Convolution2D +from keras import backend as K from server import client_generator @@ -47,6 +48,8 @@ def get_model(time_len=1): if __name__ == "__main__": + K.set_image_dim_ordering('th') + parser = argparse.ArgumentParser(description='Steering angle model trainer') parser.add_argument('--host', type=str, default="localhost", help='Data server ip address.') parser.add_argument('--port', type=int, default=5557, help='Port of server.')