#!/usr/bin/env python3
# coding: utf-8

from keras.models import Model
from keras import backend as K
from keras.layers import Input, LSTM, Dense
from keras.layers.core import Dropout
from keras.models import Sequential
import numpy as np
import os

class seq2seq():
    """Character-level LSTM encoder/decoder that rewrites a free-form
    parameter value into its standard form.

    Training pairs are read from ``data/<categorycode>/<param>.txt`` (one
    ``input<TAB>target`` pair per line); weights are saved to / loaded from
    ``model/<categorycode>/<param>.h5``.
    """

    def __init__(self, categorycode, param):
        """Load the corpus, one-hot encode it and build training/inference graphs.

        Args:
            categorycode: Sub-directory name used under ``data/`` and ``model/``.
            param: File stem of the corpus (``.txt``) and the weights (``.h5``).

        Raises:
            FileNotFoundError: if the corpus file does not exist.
            ValueError: if a corpus line is malformed or there are no samples.
        """
        self.batch_size = 32
        self.epochs = 100
        self.latent_dim = 256
        self.num_samples = 2000

        # Always assign both paths. Previously they were only set when the
        # directory already existed, so a brand-new category crashed with an
        # AttributeError on self.data_path right after creating the folder.
        os.makedirs(f'data/{categorycode}', exist_ok=True)
        os.makedirs(f'model/{categorycode}', exist_ok=True)
        self.data_path = f'data/{categorycode}/{param}.txt'
        self.model_path = f'model/{categorycode}/{param}.h5'

        input_texts, target_texts = self._read_pairs(self.data_path, self.num_samples)
        if not input_texts:
            raise ValueError(f'no training samples found in {self.data_path}')

        input_characters = sorted(set(''.join(input_texts)))
        target_characters = sorted(set(''.join(target_texts)))
        self.num_encoder_tokens = len(input_characters)
        self.num_decoder_tokens = len(target_characters)
        self.max_encoder_seq_length = max(len(txt) for txt in input_texts)
        self.max_decoder_seq_length = max(len(txt) for txt in target_texts)

        print('标注样本数:', len(input_texts))
        print('输入独立字符:', self.num_encoder_tokens)
        print('输出独立字符:', self.num_decoder_tokens)
        print('最长输入:', self.max_encoder_seq_length)
        print('最长输出:', self.max_decoder_seq_length)

        self.input_token_index = {char: i for i, char in enumerate(input_characters)}
        self.target_token_index = {char: i for i, char in enumerate(target_characters)}
        self.reverse_input_char_index = dict(enumerate(input_characters))
        self.reverse_target_char_index = dict(enumerate(target_characters))

        num_pairs = len(input_texts)
        self.encoder_input_data = np.zeros(
            (num_pairs, self.max_encoder_seq_length, self.num_encoder_tokens),
            dtype='float32')
        self.decoder_input_data = np.zeros(
            (num_pairs, self.max_decoder_seq_length, self.num_decoder_tokens),
            dtype='float32')
        self.decoder_target_data = np.zeros(
            (num_pairs, self.max_decoder_seq_length, self.num_decoder_tokens),
            dtype='float32')

        for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
            for t, char in enumerate(input_text):
                self.encoder_input_data[i, t, self.input_token_index[char]] = 1.
            for t, char in enumerate(target_text):
                index = self.target_token_index[char]
                self.decoder_input_data[i, t, index] = 1.
                # Teacher forcing: the decoder target is its input shifted one
                # step to the left, so it never contains the '\t' start token.
                if t > 0:
                    self.decoder_target_data[i, t - 1, index] = 1.

        self._build_graph()
        # Both inference models reuse the layer objects of the training graph,
        # so weights trained/loaded there are visible to them as well.
        self.decoder_model = self.Decoder_Model()
        self.encoder_model = self.Encoder_Model()

    @staticmethod
    def _read_pairs(path, num_samples):
        """Read up to ``num_samples`` tab-separated (input, target) pairs.

        Blank/whitespace-only lines are skipped, and the last line is kept
        whether or not the file ends with a newline. Each target is wrapped
        with '\\t' (start token) and '\\n' (end token).

        Returns:
            ``(input_texts, target_texts)`` as two parallel lists of str.

        Raises:
            ValueError: if a non-blank line does not contain exactly one tab.
        """
        input_texts = []
        target_texts = []
        with open(path, 'r', encoding='utf-8') as f:
            for lineno, raw in enumerate(f, start=1):
                if len(input_texts) >= num_samples:
                    break
                line = raw.rstrip('\n')
                if not line.strip():
                    continue
                parts = line.split('\t')
                if len(parts) != 2:
                    raise ValueError(
                        f'{path}:{lineno}: expected "input<TAB>target", got {line!r}')
                input_text, target_text = parts
                input_texts.append(input_text)
                target_texts.append('\t' + target_text + '\n')
        return input_texts, target_texts

    def _build_graph(self):
        """Create the shared layers and wire the teacher-forced training graph."""
        self.encoder_inputs = Input(shape=(None, self.num_encoder_tokens))
        self.decoder_inputs = Input(shape=(None, self.num_decoder_tokens))
        self.encoder_lstm = LSTM(self.latent_dim, return_sequences=False, return_state=True)
        self.decoder_lstm = LSTM(self.latent_dim, return_sequences=True, return_state=True)
        self.decoder_dense = Dense(self.num_decoder_tokens, activation='softmax')
        # Only the encoder's final hidden/cell states are kept; they seed the decoder.
        _, self.state_h, self.state_c = self.encoder_lstm(self.encoder_inputs)
        self.encoder_states = [self.state_h, self.state_c]
        decoder_sequence, _, _ = self.decoder_lstm(
            self.decoder_inputs, initial_state=self.encoder_states)
        self.decoder_outputs = self.decoder_dense(decoder_sequence)

    def Encoder_Decoder_Model(self):
        """Build and compile the end-to-end training model (prints its summary)."""
        model = Model([self.encoder_inputs, self.decoder_inputs], self.decoder_outputs)
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
        model.summary()
        return model

    def Encoder_Model(self):
        """Return the inference encoder: one-hot input sequence -> [h, c]."""
        return Model(self.encoder_inputs, self.encoder_states)

    def Decoder_Model(self):
        """Return the single-step inference decoder.

        Inputs are ``[decoder_input, h, c]``; outputs are
        ``[char_probabilities, h, c]`` so states can be fed back each step.
        """
        decoder_state_input_h = Input(shape=(self.latent_dim,))
        decoder_state_input_c = Input(shape=(self.latent_dim,))
        decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
        decoder_outputs, state_h, state_c = self.decoder_lstm(
            self.decoder_inputs, initial_state=decoder_states_inputs)
        decoder_outputs = self.decoder_dense(decoder_outputs)
        return Model(
            [self.decoder_inputs] + decoder_states_inputs,
            [decoder_outputs, state_h, state_c])

    def _encode_sentence(self, input_sentence):
        """One-hot encode ``input_sentence`` as a batch of one for the encoder.

        The time axis is padded to at least ``max_encoder_seq_length`` (as in
        training) but grows for longer inputs instead of raising IndexError;
        the encoder Input accepts any length. Characters unseen during
        training are left as all-zero steps.
        """
        length = max(self.max_encoder_seq_length, len(input_sentence))
        input_seq = np.zeros((1, length, self.num_encoder_tokens), dtype='float32')
        for t, char in enumerate(input_sentence):
            index = self.input_token_index.get(char)
            if index is not None:
                input_seq[0, t, index] = 1.
        return input_seq

    def decode_sequence(self, input_sentence):
        """Greedily decode ``input_sentence`` one character at a time.

        Returns the generated characters. Decoding stops once the model emits
        the newline end token (which is included in the result) or once the
        output exceeds ``max_decoder_seq_length`` characters.
        """
        input_seq = self._encode_sentence(input_sentence)
        states_value = list(self.encoder_model.predict(input_seq))
        target_seq = np.zeros((1, 1, self.num_decoder_tokens))
        target_seq[0, 0, self.target_token_index['\t']] = 1.
        decoded_sentence = ''
        while True:
            output_tokens, h, c = self.decoder_model.predict([target_seq] + states_value)
            sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
            sampled_char = self.reverse_target_char_index[sampled_token_index]
            decoded_sentence += sampled_char
            if sampled_char == '\n' or len(decoded_sentence) > self.max_decoder_seq_length:
                return decoded_sentence
            # Feed the sampled character and the updated states back in.
            target_seq = np.zeros((1, 1, self.num_decoder_tokens))
            target_seq[0, 0, sampled_token_index] = 1.
            states_value = [h, c]

    def model_train(self):
        """Train on the loaded corpus (20% validation split), save to model_path.

        Returns:
            The History object returned by ``model.fit``.
        """
        model = self.Encoder_Decoder_Model()
        history = model.fit([self.encoder_input_data, self.decoder_input_data],
                            self.decoder_target_data,
                            batch_size=self.batch_size,
                            epochs=self.epochs,
                            validation_split=0.2)
        model.save(self.model_path)
        return history

    def model_load(self):
        """Load weights from model_path into a fresh training model and return it.

        The layers are shared with ``encoder_model`` / ``decoder_model``, so
        this also prepares ``decode_sequence`` for use.
        """
        model = self.Encoder_Decoder_Model()
        model.load_weights(self.model_path)
        return model

if __name__ == '__main__':
    # Interactive demo: load trained weights, then standardise each value typed in.
    model_1 = seq2seq('0101', 'CPU型号')
    # Run model_1.model_train() once first to create the .h5 weights file.
    model_1.model_load()
    while True:
        try:
            a = input('请输入要非标转标的参数值:')
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D / Ctrl-Z) or Ctrl-C ends the session cleanly.
            print()
            break
        print(model_1.decode_sequence(a))