#!/usr/bin/python

# This file is part of vdr-webvideo-plugin.
#
# Copyright 2009 Antti Ajanki <antti.ajanki@iki.fi>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import socket
import cStringIO
import sys
import cmd
import mimetypes
import urllib
import select
import os.path
import subprocess
import time
import libxml2
from optparse import OptionParser

version = '0.1.1'

# Media player command templates used by the STREAM command. Every '%s'
# in a template stands for the stream URL. Players are tried in list
# order until one of them succeeds.
# TODO: this should be configurable
streamplayers = ['vlc "%s"', 'mplayer "%s"', 'xine "%s"']

mimetypes.init()
# Flash video is common but its MIME types are often missing from the
# system database, so register both spellings explicitly.
for _flv_type in ('video/flv', 'video/x-flv'):
    mimetypes.add_type(_flv_type, '.flv')
del _flv_type

def getContentUnicode(node):
    """Return the text content of a libxml2 node as a unicode object.

    libxml2's getContent() yields UTF-8 encoded bytes, which are decoded
    here.
    """
    raw = node.getContent()
    return raw.decode('UTF-8')

class WVClient:
    """Client for the webvid (WVTP) server.

    A request is sent as '<VERB> <ref> WVTP/1.0' followed by an empty
    line. A response is a status line, 'Name: value' header lines, an
    empty line and a body of Content-Length bytes. The client also keeps
    a back/forward history of parsed Menu objects.

    Exception objects are fetched with sys.exc_info() instead of the
    'except E, e' form, so the class body is valid in both Python 2 and
    Python 3 syntax.
    """
    # Priority given to <url> elements without a 'priority' attribute.
    # URLs with larger priorities are tried first.
    DEFAULT_URL_PRIORITY = 100

    def __init__(self, addr, port):
        self.addr = addr
        self.port = port
        # Bytes already received from the socket but not yet consumed.
        self.buffer = ''
        # Visited menus and the index of the one being displayed.
        self.history = []
        self.history_pointer = 0
        # Connected socket, or None when the next request must reconnect.
        self.sock = None

    def _drop_connection(self):
        """Close the socket (ignoring errors) and mark the client as
        disconnected so that the next request opens a new connection."""
        if self.sock is not None:
            try:
                self.sock.close()
            except socket.error:
                pass
        self.sock = None

    def read_headers(self):
        """Read a response status line and its headers from the socket.

        Returns (responsecode, responsetext, headers). headers maps
        lower-cased header names to stripped values. responsecode is -1
        if no parseable status line was received. Bytes that follow the
        empty line ending the headers are the start of the body and are
        kept in self.buffer. If the header block is incomplete, the
        connection is dropped.
        """
        responsecode = -1
        responsetext = ''
        headers = {}
        if self.sock is None:
            self.buffer = ''
            return (responsecode, responsetext, headers)

        data = ''
        try:
            # A blocking recv() returns as soon as any data is
            # available, so no select()/non-blocking mode is needed
            # (and no socket mode can be left changed on errors).
            while '\r\n\r\n' not in data:
                chunk = self.sock.recv(4096)
                if not chunk:
                    break
                data += chunk
        except socket.error:
            print('error while reading headers %s' % sys.exc_info()[1])

        end = data.find('\r\n\r\n')
        if end == -1:
            # Connection closed or failed before the headers ended: use
            # the complete lines received, but reconnect next time.
            self._drop_connection()
            lines = data.split('\r\n')[:-1]
            self.buffer = ''
        else:
            lines = data[:end].split('\r\n')
            self.buffer = data[end + 4:]

        if lines:
            # Status line: '<protocol> <code> <text>'
            parts = lines[0].split(' ', 2)
            try:
                responsecode = int(parts[1])
            except (IndexError, ValueError):
                pass
            else:
                if len(parts) == 3:
                    responsetext = parts[2]
            for line in lines[1:]:
                # Split only on the first ':' so values may contain ':'.
                fields = line.split(':', 1)
                if len(fields) == 2:
                    headers[fields[0].strip().lower()] = fields[1].strip()
        return (responsecode, responsetext, headers)

    def write_body(self, dest, contentlen):
        """Read contentlen bytes of response body and write them to dest.

        dest is a file-like object with a write() method, or None to
        discard the data. Returns the number of bytes actually read. That
        number is smaller than contentlen if the connection closed early,
        in which case the client is marked disconnected.
        """
        remaining = contentlen
        while remaining > 0:
            chunk = self.read(min(remaining, 4096))
            if not chunk:
                # Connection closed: reconnect on the next request.
                self._drop_connection()
                break
            if dest is not None:
                dest.write(chunk)
            remaining -= len(chunk)
        return contentlen - remaining

    def read(self, bytes):
        """Return up to `bytes` bytes, taken from self.buffer first and
        otherwise from the socket. Returns '' on EOF, on a socket error
        or when not connected."""
        if self.buffer:
            ret = self.buffer[:bytes]
            self.buffer = self.buffer[bytes:]
            return ret
        if self.sock is None:
            return ''
        try:
            return self.sock.recv(bytes)
        except socket.error:
            print(sys.exc_info()[1])
            return ''

    def _try_connect(self):
        """Open a TCP connection to (addr, port).

        Returns 0 on success, otherwise the socket error number (-1 if
        the error carries none). On failure self.sock is reset to None,
        so a later request tries to connect again instead of reusing an
        unconnected socket.
        """
        self.sock = socket.socket(socket.AF_INET)
        try:
            self.sock.connect((self.addr, self.port))
        except socket.error:
            e = sys.exc_info()[1]
            self._drop_connection()
            if len(e.args) >= 2:
                print('Error connecting to the server: %s' % e.args[1])
                return e.args[0] or -1
            print('Error connecting to the server: %s' % e)
            return -1
        return 0

    def connect(self):
        """Connect to the server, starting a local server if needed.

        If the connection is refused and the server address is local,
        /usr/bin/webvid is launched and the connection is retried once.
        Returns True when connected.
        """
        import errno
        self.buffer = ''
        err = self._try_connect()
        if err == 0:
            return True
        # Maybe the server is not running. It can only be started when
        # connecting to localhost.
        if err != errno.ECONNREFUSED or \
               self.addr not in ('127.0.0.1', 'localhost'):
            return False
        print('Trying to start the server')
        try:
            # TODO: path and parameters should be configurable
            subprocess.call(['/usr/bin/webvid', '-d', '-l/tmp/webvid.log'])
        except OSError:
            e = sys.exc_info()[1]
            print('Failed to start the server: %s' % (e.strerror or e))
            return False
        time.sleep(0.5)
        return self._try_connect() == 0

    def close(self):
        """Send a CLOSE request (best effort) and close the socket."""
        if self.sock is not None:
            try:
                self.sock.sendall('CLOSE X WVTP/1.0\r\n\r\n')
            except socket.error:
                # ignore errors while closing
                pass
        self._drop_connection()

    def parse_page(self, page):
        """Parse a <wvmenu> XML document into a Menu.

        Returns None if page is None, is not well-formed XML, or its
        root element is not <wvmenu>.
        """
        if page is None:
            return None
        try:
            doc = libxml2.parseDoc(page)
        except libxml2.parserError:
            return None
        try:
            return self._menu_from_root(doc.getRootElement())
        finally:
            # libxml2 documents must be freed explicitly; do it on
            # every path, including the early returns.
            doc.freeDoc()

    def _menu_from_root(self, root):
        """Build a Menu from the root element of a parsed menu page, or
        return None if it is not a <wvmenu> element."""
        if root is None or root.name != 'wvmenu':
            return None
        # Text fields and item lists whose values submit buttons send.
        queryitems = []
        menu = Menu()
        node = root.children
        while node:
            name = node.name
            if name == 'title':
                menu.title = getContentUnicode(node)
            elif name == 'link':
                menu.add(self.parse_link(node))
            elif name == 'textfield':
                mi = self.parse_textfield(node)
                menu.add(mi)
                queryitems.append(mi)
            elif name == 'itemlist':
                mi = self.parse_itemlist(node)
                menu.add(mi)
                queryitems.append(mi)
            elif name == 'textarea':
                menu.add(self.parse_textarea(node))
            elif name == 'button':
                menu.add(self.parse_button(node, queryitems))
            node = node.next
        return menu

    def parse_link(self, node):
        """Build a MenuItemLink from a <link> element (children <label>,
        <ref> and optional <object>)."""
        label = ''
        ref = None
        obj = None
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'ref':
                ref = getContentUnicode(child)
            elif child.name == 'object':
                obj = getContentUnicode(child)
            child = child.next
        return MenuItemLink(label, ref, obj)

    def parse_textfield(self, node):
        """Build a MenuItemTextField from a <textfield id="..."> element."""
        label = ''
        id = node.prop('id')
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            child = child.next
        return MenuItemTextField(label, id)

    def parse_textarea(self, node):
        """Build a MenuItemTextArea from a <textarea> element."""
        label = ''
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            child = child.next
        return MenuItemTextArea(label)

    def parse_itemlist(self, node):
        """Build a MenuItemList from an <itemlist id="..."> element. The
        value of each <item> is its 'value' attribute (None if absent)."""
        label = ''
        id = node.prop('id')
        items = []
        values = []
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'item':
                items.append(getContentUnicode(child))
                values.append(child.prop('value'))
            child = child.next
        return MenuItemList(label, id, items, values)

    def parse_button(self, node, queryitems):
        """Build a MenuItemSubmitButton from a <button> element.

        queryitems is the page's shared list of query items. Items
        appended to it after this button are submitted too.
        """
        label = ''
        submission = None
        child = node.children
        while child:
            if child.name == 'label':
                label = getContentUnicode(child)
            elif child.name == 'submission':
                submission = getContentUnicode(child)
            child = child.next
        return MenuItemSubmitButton(label, submission, queryitems)

    def _url_priority(self, node):
        """Return the integer 'priority' attribute of a <url> node, or
        DEFAULT_URL_PRIORITY if it is missing or not an integer."""
        try:
            return int(node.prop('priority'))
        except (TypeError, ValueError):
            return self.DEFAULT_URL_PRIORITY

    def parse_mediaurl(self, xml):
        """Parse a media description document.

        Returns (title, urls), where urls are ordered from the highest to
        the lowest priority (equal priorities keep document order).
        Returns (None, None) if the document cannot be parsed or has no
        root element.
        """
        try:
            doc = libxml2.parseDoc(xml)
        except libxml2.parserError:
            print('Failed to parse mediaurl')
            return None, None

        title = None
        prioritized = []
        try:
            root = doc.getRootElement()
            if root is None:
                return None, None
            node = root.children
            while node:
                if node.name == 'title':
                    title = getContentUnicode(node)
                elif node.name == 'url':
                    # Priorities are compared as integers, not strings.
                    prioritized.append((self._url_priority(node),
                                        getContentUnicode(node)))
                node = node.next
        finally:
            doc.freeDoc()

        # list.sort is stable, so ties keep their document order.
        prioritized.sort(key=lambda pair: pair[0], reverse=True)
        if not title:
            title = '???'
        return title, [url for priority, url in prioritized]

    def get_headers_raw(self, ref, verb='GET'):
        """Send '<verb> <ref>' and read the response headers.

        Returns (status, statustext, headers). status is -1 and headers
        is None if the request could not be sent.
        """
        if (self.sock is None) and (not self.connect()):
            return (-1, 'Can\'t connect to server', None)
        try:
            self.sock.sendall('%s %s WVTP/1.0\r\n\r\n' % (verb, ref))
        except socket.error:
            e = sys.exc_info()[1]
            # The socket is unusable; reconnect on the next request.
            self._drop_connection()
            return (-1, 'Error in sendall: ' + str(e), None)
        return self.read_headers()

    def _content_length(self, headers):
        """Return the Content-Length in headers as a non-negative int
        (0 if headers is None or the header is missing or malformed)."""
        if not headers:
            return 0
        try:
            return max(0, int(headers.get('content-length', 0)))
        except ValueError:
            return 0

    def _discard_body(self, headers):
        """Read and throw away the body announced by headers so that it
        is not mistaken for the next response."""
        contentlen = self._content_length(headers)
        if contentlen > 0:
            self.write_body(None, contentlen)

    def get_raw(self, ref, verb='GET'):
        """Request ref and read the whole response.

        Returns (status, statustext, headers, body). headers and body are
        None unless status is 200. status becomes -1 if the body is
        shorter than its Content-Length.
        """
        status, statuspharse, headers = self.get_headers_raw(ref, verb)
        if status != 200:
            self._discard_body(headers)
            return (status, statuspharse, None, None)
        contentlen = self._content_length(headers)
        body = cStringIO.StringIO()
        received = self.write_body(body, contentlen)
        if received != contentlen:
            status = -1
            statuspharse = 'Length of body (%d) differs from expected (%d)' % (received, contentlen)
        return (status, statuspharse, headers, body.getvalue())

    def get(self, ref):
        """Request ref and parse the response as a menu.

        Returns (status, statustext, menu); menu is None on failure.
        """
        status, statusmsg, headers, body = self.get_raw(ref)
        return (status, statusmsg, self.parse_page(body))

    def download(self, obj):
        """Save media object obj to a file in the current directory.

        The object's URLs are requested with DOWNLOAD in priority order
        until the server accepts one. Returns True if a file was written.
        """
        status, statusmsg, headers, body = self.get_raw(obj)
        if body is None:
            return False
        title, urls = self.parse_mediaurl(body)
        if urls is None:
            print('No URLs!')
            return False

        for url in urls:
            status, statuspharse, headers = self.get_headers_raw(url, 'DOWNLOAD')
            if status == 200:
                return self._save_body(title, headers)
            self._discard_body(headers)

        print('No valid url found')
        return False

    def _save_body(self, title, headers):
        """Write the pending response body into a new file named after
        title. Returns True if the file was written."""
        ext = mimetypes.guess_extension(headers.get('content-type', ''))
        if ext is None:
            ext = ''
        filename = self.next_available_file_name(self.safe_filename(title), ext)
        contentlen = self._content_length(headers)
        try:
            # Binary mode: the body is raw media data.
            f = open(filename, 'wb')
            try:
                written = self.write_body(f, contentlen)
            finally:
                f.close()
        except IOError:
            e = sys.exc_info()[1]
            print('IOError while writing to %s: %s' % (filename, e.strerror or e))
            # Part of the body may still be unread on the socket;
            # reconnect rather than parse it as the next response.
            self._drop_connection()
            return False
        print('Saved to %s' % filename)
        if written != contentlen:
            print('Warning: the size of the file (%d) differs from expected (%d)' % (written, contentlen))
        return True

    def safe_filename(self, name):
        """Sanitize a filename. No paths (replace '/' -> '!') and no
        names starting with a dot. Names that end up empty become
        'unnamed'."""
        cleaned = name.replace('/', '!').lstrip('.')
        return cleaned or 'unnamed'

    def next_available_file_name(self, basename, ext):
        """Return basename+ext, or basename-N+ext with the smallest N >= 1
        for which no file exists yet."""
        candidate = basename + ext
        counter = 0
        while os.path.exists(candidate):
            counter += 1
            candidate = '%s-%d%s' % (basename, counter, ext)
        return candidate

    def play_stream(self, obj):
        """Play media object obj with the first working external player.

        Each URL of the object is requested with STREAM. The first stream
        URL returned is passed to the players in `streamplayers` until one
        exits with status 0. Returns True on success.
        """
        import shlex
        status, statusmsg, headers, body = self.get_raw(obj)
        if body is None:
            return False
        title, urls = self.parse_mediaurl(body)
        if urls is None:
            print('No URLs!')
            return False

        for url in urls:
            status, statusmsg, headers, body = self.get_raw(url, 'STREAM')
            if body is None:
                continue
            title, streamurls = self.parse_mediaurl(body)
            if not streamurls:
                continue
            streamurl = streamurls[0]
            for pl in streamplayers:
                # The URL comes from the server: build an argv list and
                # run it without a shell so it cannot inject commands.
                argv = [arg.replace('%s', streamurl) for arg in shlex.split(pl)]
                print('Trying player: ' + pl % streamurl)
                try:
                    retcode = subprocess.call(argv)
                except OSError:
                    print('Execution failed: %s' % sys.exc_info()[1])
                    continue
                if retcode == 0:
                    return True
                print('Player failed with returncode %s' % retcode)
        return False

    def get_current_menu(self):
        """Return the menu at the current history position, or None."""
        if 0 <= self.history_pointer < len(self.history):
            return self.history[self.history_pointer]
        return None

    def history_add(self, menu):
        """Make menu the current one, discarding any forward history."""
        if menu is not None:
            self.history = self.history[:(self.history_pointer+1)]
            self.history.append(menu)
            self.history_pointer = len(self.history)-1

    def history_back(self):
        """Step back in history (if possible) and return the current menu."""
        if self.history_pointer > 0:
            self.history_pointer -= 1
        return self.get_current_menu()

    def history_forward(self):
        """Step forward in history (if possible) and return the current menu."""
        if self.history_pointer < len(self.history)-1:
            self.history_pointer += 1
        return self.get_current_menu()


class Menu:
    """An ordered collection of menu items with an optional title."""

    def __init__(self):
        self.title = None
        self.items = []

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]

    def add(self, menuitem):
        """Append menuitem to the end of the menu."""
        self.items.append(menuitem)

    def __str__(self):
        # Title underlined with '=' followed by a numbered item list.
        text = u''
        if self.title:
            underline = '=' * len(self.title)
            text = '%s\n%s\n' % (self.title, underline)
        for position, entry in enumerate(self.items):
            text = text + u'%d. %s\n' % (position + 1, unicode(entry))
        return text


class MenuItemLink:
    """A menu entry that leads to another menu (ref) and may also refer
    to a downloadable/streamable media object (obj)."""

    def __init__(self, label, ref, obj):
        self.label = unicode(label)
        self.ref = ref
        self.obj = obj

    def __str__(self):
        # Entries without a media object are shown in brackets.
        if self.obj:
            return self.label
        return '[%s]' % self.label

    def activate(self):
        """Return the reference of the menu this link points to."""
        return self.ref


class MenuItemTextField:
    """A text input whose value is sent along with a submit button's
    query string."""

    def __init__(self, label, id):
        self.label = unicode(label)
        self.id = id
        self.value = u''

    def __str__(self):
        return u'%s: %s' % (self.label, self.value)

    def get_query_string(self):
        """Return 'id=value' with the value UTF-8 encoded and URL-quoted."""
        return '%s=%s' % (self.id, urllib.quote_plus(self.value.encode('utf-8')))

    def activate(self):
        """Prompt the user for a new value. Returns None (no navigation)."""
        # The stream encodings may be None (e.g. when input or output is
        # redirected); fall back to UTF-8 instead of failing.
        inenc = getattr(sys.stdin, 'encoding', None) or 'utf-8'
        outenc = getattr(sys.stdout, 'encoding', None) or 'utf-8'
        # Encode the prompt explicitly so non-ASCII labels can be shown.
        prompt = (u'%s> ' % self.label).encode(outenc, 'replace')
        self.value = unicode(raw_input(prompt), inenc, 'replace')
        return None


class MenuItemTextArea:
    """A block of read-only text shown in a menu."""

    def __init__(self, label):
        self.label = unicode(label)

    def activate(self):
        """Selecting plain text does nothing; there is no reference."""
        return None

    def __str__(self):
        return self.label


class MenuItemList:
    """A single-choice list. The value of the selected item is sent with
    a submit button's query string."""

    def __init__(self, label, id, items, values):
        self.label = unicode(label)
        self.id = id
        if len(items) != len(values):
            raise ValueError('items and values must have the same length')
        self.items = items
        # Query values parallel to items; an entry may be None when the
        # item had no value.
        self.values = values
        # Index of the selected item.
        self.current = 0

    def __str__(self):
        # The selected item is marked with angle brackets.
        itemstrings = [self.label + ':']
        for i, k in enumerate(self.items):
            if i == self.current:
                itemstrings.append('<' + k + '>')
            else:
                itemstrings.append(k)
        return u' '.join(itemstrings)

    def get_query_string(self):
        """Return 'id=value' for the selected item, or None when the
        selection is out of range or the item has no (or an empty)
        value."""
        if not (0 <= self.current < len(self.items)):
            return None
        value = self.values[self.current]
        if not value:
            return None
        return '%s=%s' % (self.id, value)

    def activate(self):
        """Ask the user to select an item. Returns None (no navigation)."""
        if not self.items:
            return None
        tmp = raw_input('Select item (1-%d)> ' % len(self.items))
        try:
            i = int(tmp)
            if (i < 1) or (i > len(self.items)):
                raise ValueError
            self.current = i-1
        except ValueError:
            print('Must be an integer in the range 1 - %d' % len(self.items))
        return None


class MenuItemSubmitButton:
    """A button that requests baseurl with the values of the page's
    query items appended as a query string."""

    def __init__(self, label, baseurl, subitems):
        self.label = unicode(label)
        # Target reference; None if the page did not provide one.
        self.baseurl = baseurl
        # Items providing get_query_string() (text fields, item lists).
        self.subitems = subitems

    def __str__(self):
        return '[' + self.label + ']'

    def activate(self):
        """Return the reference to request, or None if the button has no
        target.

        Query strings that are None are skipped; each item is asked only
        once.
        """
        if self.baseurl is None:
            return None
        substrings = []
        for item in self.subitems:
            query = item.get_query_string()
            if query is not None:
                substrings.append(query)
        # Extend an existing query string rather than starting a second one.
        if '?' in self.baseurl:
            separator = '&'
        else:
            separator = '?'
        return self.baseurl + separator + '&'.join(substrings)


class WVShell(cmd.Cmd):
    """Interactive command interpreter for browsing webvid menus.

    A line consisting of just a number is treated as 'select <number>'.
    All output goes to self.stdout.
    """
    def __init__(self, client, completekey='tab', stdin=None, stdout=None):
        cmd.Cmd.__init__(self, completekey, stdin, stdout)
        self.prompt = '> '
        self.client = client

    def preloop(self):
        self.stdout.write('wvclient %s starting\n' % version)
        self.do_menu(None)

    def precmd(self, arg):
        # A bare number is shorthand for 'select <number>'.
        try:
            int(arg)
            return 'select ' + arg
        except ValueError:
            return arg

    def onecmd(self, c):
        # Report unexpected errors instead of leaving the command loop.
        try:
            return cmd.Cmd.onecmd(self, c)
        except Exception:
            import traceback
            self.stdout.write('Exception occured while handling command "' + c + '"\n')
            self.stdout.write(traceback.format_exc() + '\n')
            return False

    def display_menu(self, menu):
        """Write menu to stdout, encoded for the output stream."""
        if menu is None:
            return
        # stdout.encoding may be None (e.g. output redirected to a pipe).
        encoding = getattr(self.stdout, 'encoding', None) or 'utf-8'
        self.stdout.write(unicode(menu).encode(encoding, 'replace'))

    def _get_numbered_item(self, arg):
        """Return the item of the current menu with 1-based index arg,
        or None (after printing a message) if there is no such item."""
        menu = self.client.get_current_menu()
        if menu is None:
            self.stdout.write('No menu loaded\n')
            return None
        try:
            v = int(arg)-1
        except ValueError:
            v = -1
        if (v < 0) or (v >= len(menu)):
            self.stdout.write('Invalid selection: %s\n' % arg)
            return None
        return menu[v]

    def do_select(self, arg):
        """select x
Select the link whose index is x.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        ref = menuitem.activate()
        if ref is not None:
            status, statusmsg, menu = self.client.get(ref)
            if menu is not None:
                self.client.history_add(menu)
            else:
                self.stdout.write('Error: %d %s\n' % (status, statusmsg))
        else:
            menu = self.client.get_current_menu()
        self.display_menu(menu)
        return False

    def do_download(self, arg):
        """download x
Download media object whose index is x to a file. Downloadable items
are the ones without brackets.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        elif hasattr(menuitem, 'obj') and menuitem.obj is not None:
            self.client.download(menuitem.obj)
        else:
            self.stdout.write('Not a media object\n')
        return False

    def do_stream(self, arg):
        """stream x
Play the media file whose index is x. Media objects are the ones
without brackets.
        """
        menuitem = self._get_numbered_item(arg)
        if menuitem is None:
            return False
        elif hasattr(menuitem, 'obj') and menuitem.obj is not None:
            self.client.play_stream(menuitem.obj)
        else:
            self.stdout.write('Not a media object\n')
        return False

    def do_display(self, arg):
        """Redisplay the current menu."""
        if not arg:
            self.display_menu(self.client.get_current_menu())
        else:
            self.stdout.write('Unknown parameter %s\n' % arg)
        return False

    def do_menu(self, arg):
        """Get back to the main menu."""
        status, statusmsg, menu = self.client.get('/mainmenu')
        if menu is not None:
            self.client.history_add(menu)
            self.display_menu(menu)
        else:
            self.stdout.write('Error: %d %s\n' % (status, statusmsg))
        return False

    def do_back(self, arg):
        """Go to the previous menu in the history."""
        menu = self.client.history_back()
        self.display_menu(menu)
        return False

    def do_forward(self, arg):
        """Go to the next menu in the history."""
        menu = self.client.history_forward()
        self.display_menu(menu)
        return False

    def do_quit(self, arg):
        """Quit the program."""
        self.client.close()
        return True

    def do_EOF(self, arg):
        """Quit the program."""
        self.client.close()
        return True


DEFAULT_SERVER = '127.0.0.1'
DEFAULT_PORT = 2357


def _make_option_parser():
    """Build the command line parser (--server and --port options)."""
    parser = OptionParser()
    parser.add_option('-s', '--server', dest='server', type='string',
                      metavar='SERVER', default=DEFAULT_SERVER,
                      help='connect to SERVER')
    parser.add_option('-p', '--port', dest='port', type='int',
                      metavar='PORT', default=DEFAULT_PORT,
                      help='connect to PORT')
    return parser


def main():
    """Parse the command line and run the interactive shell."""
    options = _make_option_parser().parse_args()[0]
    client = WVClient(options.server, options.port)
    WVShell(client).cmdloop()

if __name__ == '__main__':
    main()
