Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions components/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ def update_target_model(self):
def remember(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))

def act(self, state):
if np.random.rand() <= self.epsilon:
return random.randrange(self.action_size)
act_values = self.model.predict(state)
def act(self, state, random=True):
    '''Choose an action for the given state.

    When `random` is True, act epsilon-greedily: with probability
    self.epsilon pick a uniformly random action, otherwise the greedy
    action from the model. When False, always act greedily.

    Returns: int index of the chosen action.
    '''
    # BUG FIX: the `random` parameter shadows the stdlib `random` module
    # inside this method, so the original call random.randrange(...) would
    # fail whenever the exploration branch fired. Use np.random instead.
    if random and np.random.rand() <= self.epsilon:
        return int(np.random.randint(self.action_size))
    act_values = self.model.predict(state)
    return np.argmax(act_values[0])  # returns action

def replay(self, batch_size):
Expand All @@ -98,24 +101,30 @@ def save(self, name):

class TabularAgent:
'''RL agent as described in the DSRL paper'''
def __init__(self, action_size, alpha, epsilon_decay, neighbor_radius=25):
    '''Set up a tabular DSRL agent.

    action_size: number of discrete actions available.
    alpha: Q-learning step size.
    epsilon_decay: multiplicative decay applied to epsilon per step.
    neighbor_radius: half-size of the neighborhood used for loc differences.
    '''
    self.action_size = action_size
    self.alpha = alpha
    # Epsilon-greedy exploration: start fully random, decay to a floor.
    self.epsilon = 1
    self.epsilon_decay = epsilon_decay
    self.epsilon_min = 0.1
    # Discount factor for future rewards.
    self.gamma = 0.95
    self.neighbor_radius = neighbor_radius
    # Location differences are shifted by this offset before indexing the
    # Q-tables so that negative displacements map to valid cells.
    self.offset = neighbor_radius * 2
    # tables[type_1][type_2] -> per-type-pair Q-table ndarray.
    self.tables = {}

def act(self, state):
def act(self, state,random_act=True):
'''
Determines action to take based on given state
State: Array of interactions
(entities in each interaction are presorted by type for consistency)
Returns: action to take, chosen e-greedily
'''
if not random_act:
return np.argmax(self._total_rewards(state))
if np.random.rand() <= self.epsilon:
print('random action, e:', self.epsilon)
#print('random action, e:', self.epsilon)
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
return random.randrange(self.action_size)

if self.epsilon > self.epsilon_min:
Expand All @@ -125,20 +134,37 @@ def act(self, state):

def update(self, state, action, reward, next_state, done):
    '''Update Q-tables based on the reward received and the action taken.

    state / next_state: arrays of interaction dicts. Each dict is assumed
    to carry 'types_after' (pair of entity types), 'interaction' (pair of
    entity ids), and 'loc_difference' (x, y displacement) — TODO confirm
    against the interaction builder.

    Performs one tabular Q-learning step per interaction, matching each
    interaction in `state` to the same entity pair in `next_state`.
    '''
    for interaction in state:
        type_1, type_2 = interaction['types_after']  # TODO resolve: should this too be types_before?
        table = self.tables.setdefault(type_1, {}).setdefault(type_2, self._make_table())

        # Find this entity pair's interaction in the next state.
        id1, id2 = interaction['interaction']
        matches = [inter for inter in next_state if inter['interaction'] == (id1, id2)]
        if not matches:
            # Pair no longer interacting; nothing to bootstrap from.
            continue
        if len(matches) > 1:
            raise ValueError(
                'Entity pair ({}, {}) appears more than once in next_state'.format(id1, id2))
        interaction_next_state = matches[0]

        # Shift displacements by the table offset so negative values index
        # valid (non-negative) cells. BUG FIX: the original mutated the
        # interaction dicts in place, so reusing the same dicts across calls
        # would accumulate the offset; use local tuples instead.
        loc = (interaction['loc_difference'][0] + self.offset,
               interaction['loc_difference'][1] + self.offset)
        next_loc = (interaction_next_state['loc_difference'][0] + self.offset,
                    interaction_next_state['loc_difference'][1] + self.offset)

        next_action_value = table[next_loc]
        if done:
            # Terminal transition: no bootstrapped future value.
            table[loc][action] = reward
        else:
            # Standard Q-learning update:
            # Q <- Q + alpha * (r + gamma * max_a' Q(s', a') - Q)
            table[loc][action] = table[loc][action] + self.alpha * (
                reward + self.gamma * np.max(next_action_value) - table[loc][action])

def _total_rewards(self, interactions):
action_rewards = np.zeros(self.action_size)
Expand All @@ -154,8 +180,8 @@ def _make_table(self):
3-D table: rows = loc_difference_x, cols = loc_difference_y, z = q-values for actions
Rows and cols added to as needed
'''
return np.zeros((self.neighbor_radius * 2, self.neighbor_radius * 2, self.action_size),
dtype=int)
return np.zeros((self.neighbor_radius * 8, self.neighbor_radius * 8, self.action_size),
dtype=float)

def save(self, filename):
'''Save agent's tables'''
Expand Down
58 changes: 37 additions & 21 deletions components/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

class SymbolAutoencoder():
'''Implements the DSRL paper section 3.1. Extract entities from raw image'''
def __init__(self, input_shape, neighbor_radius=25):
def __init__(self, input_shape,filter_size,neighbor_radius=25):
self.neighbor_radius = neighbor_radius

self.filter_size = filter_size
input_img = Input(shape=input_shape)
encoded = Conv2D(16, (5, 5), activation='relu', padding='same')(input_img)
encoded = MaxPooling2D((POOL_SIZE, POOL_SIZE), padding='same')(encoded)
Expand All @@ -30,6 +30,8 @@ def __init__(self, input_shape, neighbor_radius=25):
self.autoencoder = Model(input_img, decoded)
self.autoencoder.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

self.repr_entity_activations = []

def train(self, train_data, epochs=50, batch_size=128, shuffle=True,
validation=None, tensorboard=False):
'''Train the autoencoder on provided images'''
Expand Down Expand Up @@ -62,21 +64,22 @@ def _extract_positions(self, encoded_image):
features -= background_value
#apply the local maximum filter; all pixel of maximal value
#in their neighborhood are set to 1
filtered = maximum_filter(features, size=(4, 4)) #TODO: Abstract size
filtered = np.asarray(filtered == features, dtype=int) - np.asarray(filtered == 0,
dtype=int)
filtered = maximum_filter(features, size=(self.filter_size, self.filter_size)) #TODO: Abstract size
filtered = np.asarray(filtered == features, dtype=int) - np.asarray(filtered == 0,dtype=int)
filtered.reshape(encoded_image.shape[:-1])
filtered *= POOL_SIZE # Pooling = downsampling = everything is scaled down by POOL_SIZE
#2d image of the positions, and just the indices
return filtered, np.transpose(np.nonzero(filtered))

def visualize(self, images):
def visualize(self, images,show=False):
'''Visualize autoencoder processing steps'''
if len(images) > 20:
raise Exception('Too many visualization images, please provide <= 20')
logger.info('Visualizing...')


encoded_imgs = self.encode(images)
print(f'Encoded Image {encoded_imgs.shape}')
position_maps = [self._extract_positions(x)[0] for x in encoded_imgs]
decoded_imgs = self.predict(images)

Expand All @@ -90,14 +93,14 @@ def flatten_to_img(array):
plt_i = i+1
# display original
axis = plt.subplot(4, n_plots, plt_i)
plt.imshow(flatten_to_img(images[i]))
plt.imshow(images[i])
plt.gray()
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)

# display reconstruction
axis = plt.subplot(4, n_plots, plt_i + n_plots)
plt.imshow(flatten_to_img(decoded_imgs[i]))
plt.imshow(decoded_imgs[i])
plt.gray()
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)
Expand All @@ -117,8 +120,10 @@ def flatten_to_img(array):
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)

print('\nPlot visible, close it to proceed')
plt.show()
if show:
plt.show()
#print('\nPlot visible, close it to proceed')
return plt.gcf()

def get_entities(self, image):
'''
Expand All @@ -128,31 +133,35 @@ def get_entities(self, image):
etc.
}
'''

#print('Inside the get entities function')
encoded = self.encode(image.reshape((1,) + image.shape))[0]
pos_map, entities = self._extract_positions(encoded)

repr_entity_activations = [] # Representative depth slice for a certain type
#print(f'Number of identified entities: {len(entities)}')
#print(f'Number of identified entities: {entities.shape}')
#print(entities)


typed_entities = [] # Actual Entity() array
found_types = []
# TODO: Enhancements: knn classifier instead of this caveman shit
for entity_coords in entities:
activations = encoded[entity_coords[0], entity_coords[1], :]
if not repr_entity_activations:
repr_entity_activations.append(activations)
if not self.repr_entity_activations:
self.repr_entity_activations.append(activations)
e_type = 'type0'

else:
for i, e_activations in enumerate(repr_entity_activations):
for i, e_activations in enumerate(self.repr_entity_activations):
dist = sqeuclidean(activations, e_activations)
if dist < ENTITY_DIST_THRESHOLD: # Same type
repr_entity_activations[i] = (e_activations + activations) / 2
self.repr_entity_activations[i] = (e_activations + activations) / 2
e_type = 'type' + str(i)
break
else:
# No type match, make new type
repr_entity_activations.append(activations)
new_type_idx = len(repr_entity_activations) - 1
self.repr_entity_activations.append(activations)
new_type_idx = len(self.repr_entity_activations) - 1
e_type = 'type' + str(new_type_idx)

min_coords = entity_coords-self.neighbor_radius
Expand All @@ -170,12 +179,12 @@ def get_entities(self, image):
return typed_entities, found_types

@staticmethod
def from_saved(filename, input_shape, filter_size, neighbor_radius=None):
    '''Load autoencoder weights from filename, given input shape.

    filter_size is forwarded to the constructor; neighbor_radius is only
    passed along when explicitly provided, otherwise the constructor's
    default applies.
    '''
    if neighbor_radius is None:
        ret = SymbolAutoencoder(input_shape, filter_size)
    else:
        ret = SymbolAutoencoder(input_shape, filter_size, neighbor_radius=neighbor_radius)
    ret.autoencoder.load_weights(filename)
    return ret

Expand Down Expand Up @@ -214,3 +223,10 @@ def disappeared(self):
def _transition(self, from_type, to_type):
    '''Record a type change and apply it to this entity.'''
    # Apply the new type and remember where it came from.
    self.entity_type = to_type
    self.last_transition = [from_type, to_type]

def __repr__(self):
    '''Human-readable summary of the entity for debugging.'''
    # FIX: the original output contained the typo 'Entitiy Type'.
    text = ''
    text += f'Entity ID {self.id} \n'
    text += f'Entity Type {self.entity_type} \n'
    text += f'Position {self.position} \n'
    return text
Loading