Vorrei creare un custom keras strato (un codebook per un VQVAE modello). Durante l'allenamento mi piacerebbe avere un tf.Variable
che traccia l'utilizzo di ogni codice, così posso riavviare codici non utilizzati. Così ho creato il mio Codebook livello...
class Codebook(layers.Layer):
def __init__(self, num_codes, code_reset_limit = None, **kwargs):
super().__init__(**kwargs)
self.num_codes = num_codes
self.code_reset_limit = code_reset_limit
if self.code_reset_limit:
self.code_counter = tf.Variable(tf.zeros(num_codes, dtype = tf.int32), trainable = False)
def build(self, input_shape):
self.codes = self.add_weight(name = 'codes',
shape = (self.num_codes, input_shape[-1]),
initializer = 'random_uniform',
trainable = True)
super().build(input_shape)
Il problema che ho è che la Layer
classe trova la variabile membro self.code_counter
e lo aggiunge alla lista dei pesi che vengono salvati con il livello. Si aspetta anche il self.code_counter
per essere presente quando i pesi sono caricati che non è il caso quando ho eseguito in modalità di inferenza. Come posso fare così keras non traccia una variabile nel mio livello. Non voglio che si è prolungata o da parte del layers.weights
.