# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
""" Copyright (c) 2007-2018 David Douard (Paris, FRANCE).
https://bitbucket.org/dddouard/pygpibtoolkit -- mailto:david.douard@sdfa3.org
"""

import re
import numpy

from PyQt5 import QtCore, QtWidgets
from PyQt5.QtCore import Qt

from pyqtgraph import PlotWidget

#from pygpibtoolkit.qt5.mpl import QMplCanvas

from pygpibtoolkit.HP3562A import state_decoder
from pygpibtoolkit.HP3562A import trace_decoder
from pygpibtoolkit.HP3562A import coord_decoder
from pygpibtoolkit.HP3562A import mathtools


# MDI child window classes, probed in order by getChild(); they are
# registered at the bottom of this module.
children = []


def getChild(datablock):
    """
    Return the first registered MDI child class accepting *datablock*.

    Classes are tried in the order of ``children``; None is returned when
    none of them recognises the datablock.
    """
    candidates = (klass for klass in children
                  if klass.isValidDatablock(datablock))
    return next(candidates, None)


class DatablockMDIChild(QtWidgets.QMainWindow):
    """
    Base MDI child window showing the decoded header(s) of a datablock.

    Each entry of ``_header_struct`` (a tuple, defined by subclasses) is
    displayed in its own docked two-column table, filled from the matching
    mapping of ``self.header`` (set by subclasses in ``setDatablock``).
    Subclasses also override ``isValidDatablock`` so that ``getChild`` can
    pick the right window class for a datablock.
    """
    # NOTE(review): never incremented in this module, so default window
    # names all end in "1" unless a caller bumps it -- confirm.
    seqnumber = 1
    _username = "Window"

    # Boolean header keys look like "<before> A/B <after>" (both parts
    # optional): the flag "A/B" gives the text displayed for True (A) or
    # False (B) and is removed from the displayed parameter name.
    _bool_key_re = re.compile(
        r'((?P<before>.*) )?(?P<flag>\w+/\w+)( (?P<after>.*))?')

    @classmethod
    def isValidDatablock(cls, datablock):
        """Return True if this class can display *datablock* (never, here)."""
        return False

    def __init__(self, datablock, name=None):
        super().__init__()
        if name is not None:
            self.username = name
        else:
            self.username = "{} {}".format(
                self.__class__._username, self.seqnumber)
        self.setAttribute(Qt.WA_DeleteOnClose)
        self.isUntitled = True
        self.dataIsModified = False
        # decode first: updateHeaderData() reads the self.header that
        # subclasses set in setDatablock()
        self.setDatablock(datablock)
        self.setWindowTitle(self.username)
        self.setupUI()
        self.updateHeaderData()

    def setDatablock(self, datablock):
        """Store the raw *datablock*; subclasses also decode it here."""
        self.datablock = datablock

    def setupUI(self):
        """Create one docked header table per entry of ``_header_struct``."""
        assert isinstance(self._header_struct, tuple)
        self.headerDocks = []
        self.tables = []
        for i, _struct in enumerate(self._header_struct):
            title = "Header" if i == 0 else "Header %s" % (i + 1)
            dock = QtWidgets.QDockWidget(title, self)
            sarea = QtWidgets.QScrollArea(dock)
            dock.setWidget(sarea)
            self.addDockWidget(QtCore.Qt.RightDockWidgetArea, dock)
            self.headerDocks.append(dock)

            layout = QtWidgets.QVBoxLayout(sarea)
            layout.setContentsMargins(0, 0, 0, 0)
            table = QtWidgets.QTableWidget(sarea)
            table.setShowGrid(False)
            table.setAlternatingRowColors(True)
            table.verticalHeader().hide()
            layout.addWidget(table, 1)
            self.tables.append(table)

    def setupRowsHeight(self, table):
        """Use the header's minimum section size (or 15) as row height."""
        if table.verticalHeader().minimumSectionSize() > 0:
            cellsize = table.verticalHeader().minimumSectionSize()
        else:
            cellsize = 15
        table.verticalHeader().setDefaultSectionSize(cellsize)

    def updateHeaderData(self):
        """Fill every header table with (parameter, value) rows.

        Rows of a header description whose type (third field) is None are
        left empty; values missing from the header are shown as "N/A".
        """
        # template for selectable, enabled but non-editable cells
        template = QtWidgets.QTableWidgetItem()
        template.setFlags(Qt.ItemIsSelectable | Qt.ItemIsEnabled)
        for header, table, header_struct in zip(
                self.header, self.tables, self._header_struct):
            table.clear()
            table.setRowCount(len(header_struct))
            table.setColumnCount(2)
            table.setHorizontalHeaderLabels(['Parameter', 'Value'])
            for i, row in enumerate(header_struct):
                key, typ = row[1], row[2]
                if typ is None:
                    continue
                label, text = self._formatHeaderRow(
                    key, typ, header.get(key, "N/A"))
                for column, content in enumerate((label, text)):
                    cell = QtWidgets.QTableWidgetItem(template)
                    cell.setText(content)
                    table.setItem(i, column, cell)
            table.resizeColumnsToContents()

    @classmethod
    def _formatHeaderRow(cls, key, typ, val):
        """Return the (parameter label, value text) shown for one row."""
        if typ is bool and isinstance(val, bool):
            match = cls._bool_key_re.match(key)
            if match:
                parts = match.groupdict()
                # keep a space between the words around the flag
                label = ' '.join(
                    part for part in (parts['before'], parts['after'])
                    if part).capitalize()
                # first alternative for True, second for False
                text = parts['flag'].split('/')[not val]
                return label, cls._cleanValue(text)
        return key, cls._cleanValue(str(val))

    @staticmethod
    def _cleanValue(text):
        """Strip whitespace, surrounding quotes (optionally u-prefixed) and
        trailing NUL characters from *text*."""
        text = text.strip()
        # length checks: a lone quote character is not a quoted string,
        # and an emptied string must not be indexed again
        if len(text) >= 2 and text[0] == text[-1] and text[0] in '"\'':
            text = text[1:-1]
        if (len(text) >= 3 and text[0] == 'u' and text[1] in '"\''
                and text[-1] == text[1]):
            text = text[2:-1]
        return text.rstrip(chr(0))

    def userFriendlyName(self):
        """Return the window's display name."""
        return self.username

    def closeEvent(self, event):
        """Accept the close only if maybeSave() allows it."""
        if self.maybeSave():
            event.accept()
        else:
            event.ignore()

    def maybeSave(self):
        """Offer to save modified data; return True if closing may proceed.

        Yes -> result of save(); No -> True; Cancel/Escape -> False.
        """
        if not self.dataIsModified:
            return True
        # PyQt5's tr() returns a plain str (no QString.arg), and
        # QMessageBox.warning takes a StandardButtons mask + default button
        ret = QtWidgets.QMessageBox.warning(
            self, self.tr("MDI"),
            self.tr("'{}' has been modified.\n"
                    "Do you want to save your changes?").format(
                        self.userFriendlyName()),
            QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No
            | QtWidgets.QMessageBox.Cancel,
            QtWidgets.QMessageBox.Yes)
        if ret == QtWidgets.QMessageBox.Yes:
            return self.save()
        return ret != QtWidgets.QMessageBox.Cancel

    def save(self):
        """Save the data; return True on success.

        Nothing is persisted here, so this reports failure (keeping the
        window open); subclasses supporting saving should override it.
        """
        return False


class StateBinaryDatablockMDIChild(DatablockMDIChild):
    """MDI child window displaying a decoded HP3562A state datablock."""
    _username = "State"
    _header_struct = (state_decoder.HEADER, )

    @classmethod
    def isValidDatablock(cls, datablock):
        """Return True if *datablock* decodes to a non-empty state header."""
        try:
            header = state_decoder.decode_state(datablock)
            # explicit test rather than an assert, which -O would strip
            return len(header) > 0
        except Exception:
            return False

    def setDatablock(self, datablock):
        """Store *datablock* and decode its state header."""
        super().setDatablock(datablock)
        self.header = [state_decoder.decode_state(self.datablock)]


class TraceBinaryDatablockMDIChild(DatablockMDIChild):
    """
    MDI child window plotting a decoded HP3562A trace datablock.

    Besides the header table, the trace is drawn with pyqtgraph (a "Y Log"
    toolbar action switches to a 10*log10, dB scale) and the THD and THD+N
    computed by ``mathtools`` are shown in the status bar when available.
    """
    _username = "Trace"
    _header_struct = (trace_decoder.HEADER, )

    @classmethod
    def isValidDatablock(cls, datablock):
        """Return True if *datablock* decodes to a non-empty trace."""
        try:
            header, trace = trace_decoder.decode_trace(datablock)
            # explicit checks rather than asserts, which -O would strip
            return len(header) > 0 and len(trace) > 0
        except Exception:
            return False

    def __init__(self, datablock, name=None):
        super().__init__(datablock, name)
        self.updateTraceData()

    def setDatablock(self, datablock):
        """Store *datablock* and decode its header and trace data."""
        super().setDatablock(datablock)
        header, self.trace = trace_decoder.decode_trace(self.datablock)
        self.header = [header]

    def setupToolBar(self):
        """Add a toolbar holding the checkable "Y Log" action."""
        toolbar = QtWidgets.QToolBar(self)
        self.addToolBar(toolbar)
        self.ylogaction = QtWidgets.QAction(self.tr("Y Log"), self)
        self.ylogaction.setCheckable(True)
        self.ylogaction.toggled.connect(self.updateTraceData)
        toolbar.addAction(self.ylogaction)

    def setupUI(self):
        """Build the toolbar, the header docks and the central plot."""
        self.setupToolBar()
        super().setupUI()
        mainw = QtWidgets.QWidget(self)
        layout = QtWidgets.QVBoxLayout(mainw)
        layout.setContentsMargins(0, 0, 0, 0)
        self.canvas = PlotWidget(self)
        self.plot = self.canvas.plot()
        self.canvas.showGrid(x=True, y=True)
        layout.addWidget(self.canvas, 1)
        self.setCentralWidget(mainw)

    def updateTraceData(self):
        """Redraw the trace and refresh the THD / THD+N status message."""
        header = self.header[0]
        f0 = header['Start freq value']
        dx = header['Delta X-axis']
        n = header['Number of elements']
        # Assuming 'Delta X-axis' is the step between consecutive elements
        # (the THD bin -> frequency conversion assumes the same), the n
        # elements cover [f0, f0 + n*dx): excluding the end point keeps the
        # last element from landing one step too far.
        x = numpy.linspace(f0, f0 + dx * n, len(self.trace), endpoint=False)
        y = self.trace.copy()
        ylog = self.ylogaction.isChecked()
        if ylog:
            positive = y[y > 0]
            if positive.size:
                # log10 is undefined for y <= 0: clamp those samples to the
                # smallest positive value rather than plotting -inf/NaN
                y[y <= 0] = positive.min()
                y = 10 * numpy.log10(y)
            else:
                # no positive sample to take the log of: stay linear
                ylog = False
        self.plot.setData(x=x, y=y)
        self.canvas.setLabel('bottom', header['Domain'],
                             units=header['X axis units'])
        self.canvas.setLabel('left', header['Amplitude units'],
                             units='dB' if ylog else '')
        self.canvas.setTitle(header['Display function'])
        self.statusBar().showMessage(self._distortionMessage(f0, dx))

    def _distortionMessage(self, f0, dx):
        """Return the status-bar text giving the THD and THD+N of the trace.

        This is best effort: a figure mathtools cannot compute for this
        trace is simply left out of the message.
        """
        y = self.trace.copy()
        if f0 > 0 and dx > 0:
            # prepend int(f0/dx) zero bins for the range below the start
            # frequency (presumably mathtools indexes bins from 0 Hz)
            y = numpy.concatenate((numpy.zeros(int(f0 / dx)), y))
        msg = ""
        try:
            fund, thd = mathtools.thd(y, db=True)
            fund_freq = fund * dx  # fundamental bin index -> frequency
            show_thd = bool(thd)
        except Exception:
            show_thd = False
        if show_thd:
            msg += 'THD:%.2g db  Freq:%.2f Hz  ' % (thd, fund_freq)
        try:
            thdn = mathtools.thd_n(y, db=True)
        except Exception:
            pass
        else:
            msg += 'THD+N:%.2g db  ' % thdn
        return msg


class CoordBinaryDatablockMDIChild(TraceBinaryDatablockMDIChild):
    """
    MDI child window plotting a decoded HP3562A coordinate datablock.

    Two header tables are shown (``coord_decoder.TRACE_HEADER`` and
    ``coord_decoder.HEADER``); the data is clipped to the header's min/max
    values and multiplied by its Y scale factor before plotting.
    """
    _username = "Coord"
    _header_struct = (coord_decoder.TRACE_HEADER, coord_decoder.HEADER, )

    @classmethod
    def isValidDatablock(cls, datablock):
        """Return True if *datablock* decodes to non-empty coord data."""
        try:
            h1, h2, trace = coord_decoder.decode_coord(datablock)
            # explicit checks rather than asserts, which -O would strip
            return len(h1) > 0 and len(h2) > 0 and len(trace) > 0
        except Exception:
            return False

    def setupToolBar(self):
        """No toolbar (hence no "Y Log" action) for coordinate blocks."""
        pass

    def setDatablock(self, datablock):
        """Store *datablock* and decode its two headers and trace data."""
        # Skip TraceBinaryDatablockMDIChild.setDatablock: it would decode
        # this block as a plain trace (possibly raising) only for the
        # result to be overwritten just below.
        super(TraceBinaryDatablockMDIChild, self).setDatablock(datablock)
        h1, h2, self.trace = coord_decoder.decode_coord(self.datablock)
        # header[0] (h2) is shown with TRACE_HEADER, header[1] (h1) with
        # HEADER, following the order of _header_struct
        self.header = [h2, h1]

    def updateTraceData(self):
        """Redraw the coordinate data, clipped and scaled per its header."""
        trace_header, coord_header = self.header[0], self.header[1]
        f0 = trace_header['Start freq value']
        dx = trace_header['Delta X-axis']
        n = trace_header['Number of elements']
        # Assuming 'Delta X-axis' is the step between consecutive elements,
        # the n elements cover [f0, f0 + n*dx): exclude the end point.
        x = numpy.linspace(f0, f0 + dx * n, len(self.trace), endpoint=False)
        # clip() returns a new array, so self.trace is left untouched
        y = self.trace.clip(min=coord_header['Min value of data'],
                            max=coord_header['Max value of data'])
        scale = coord_header['Y scale factor']
        if scale:
            # out of place, so integer data is promoted instead of raising
            y = y * scale
        self.plot.setData(x=x, y=y)
        self.canvas.setLabel(
            'bottom', trace_header['Domain'],
            units=trace_header['X axis units'])
        self.canvas.setLabel(
            'left', coord_header['Y coordinates'],
            units=trace_header['Amplitude units'])
        self.canvas.setTitle(trace_header['Display function'])


# Register the MDI child classes. Order matters: getChild() returns the first
# class whose isValidDatablock() accepts a datablock, so coordinate blocks are
# probed before plain traces (presumably because a coord block may also pass
# the trace check -- confirm), and state blocks last.
children.extend([
    CoordBinaryDatablockMDIChild,
    TraceBinaryDatablockMDIChild,
    StateBinaryDatablockMDIChild,
])
