#!/usr/bin/env python

# Copyright (C) 2012-2023  The Bscp Authors <https://github.com/bscp-tool/bscp/graphs/contributors>
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
#
# ------------------------------------------------------------------------------
# modified by Lorenzo Canovi <lorenzo@kubiclabs.com>
#
# - add: versioning (following klabs standards)
# - mod: default blocksize from 64Kb to 1Mb
# - add: fancy output (progress, speed stats)
# ------------------------------------------------------------------------------

import hashlib
import struct
import subprocess
import sys
import time

# Program identification, used in the usage text and in error messages.
Cmd = "bscp"
CmdVer = "1.1"
# Built from the two values above so the name and version are spelled once.
CmdStr = "%s v%s (2025-01-03)" % (Cmd, CmdVer)

# On Python 2 rebind range to the lazy xrange.
if sys.version_info[0] < 3:
    range = xrange

# Helper program executed on the remote host. bscp() interpolates this text
# into `python -c "..."` and runs it through ssh, so the script must stay free
# of double-quote characters.
#
# Protocol, as seen from the remote side (all integers are little-endian
# unsigned 64-bit values):
#   1. read the header: size, blocksize, filename_len, hashname_len, followed
#      by the filename (UTF-8) and the hash name (ASCII)
#   2. write the digest of the filename bytes as a sanity check, then wait for
#      the two bytes b'go' (anything else terminates the script)
#   3. if the destination does not exist, create it truncated to `size` bytes
#      and chmod it to 0600
#   4. write the current end offset of the destination, then one digest per
#      block for the first `size` bytes
#   5. read (position, block) records until EOF on stdin and write each block
#      at its position
#   6. read the first `size` bytes back and write the digest of that range
# Progress information is written to stderr.
remote_script = r'''
import hashlib
import os
import os.path
import struct
import sys
import time

if sys.version_info < (3, 0):
    stdin_buffer = sys.stdin
    stdout_buffer = sys.stdout
    range = xrange
else:
    stdin_buffer = sys.stdin.buffer
    stdout_buffer = sys.stdout.buffer

(size, blocksize, filename_len, hashname_len) = struct.unpack('<QQQQ', stdin_buffer.read(8+8+8+8))
filename_bytes = stdin_buffer.read(filename_len)
hashname_bytes = stdin_buffer.read(hashname_len)
filename = filename_bytes.decode('utf-8')
hashname = hashname_bytes.decode('ascii')

sanity_hash = hashlib.new(hashname, filename_bytes).digest()
stdout_buffer.write(sanity_hash)
stdout_buffer.flush()
if stdin_buffer.read(2) != b'go':
    sys.exit()

if not os.path.exists(filename):
    # Create sparse file
    with open(filename, 'wb') as f:
        f.truncate(size)
    os.chmod(filename, 0o600)
    sys.stderr.write(' created file %s\n' % (filename))

with open(filename, 'rb+') as f:
    f.seek(0, 2)
    stdout_buffer.write(struct.pack('<Q', f.tell()))

    cnt = 0
    t0 = time.time()
    tm = t0
    elaps = 0

    readremain = size
    rblocksize = blocksize
    f.seek(0)
    while True:
        if readremain <= blocksize:
            rblocksize = readremain
        block = f.read(rblocksize)
        if len(block) == 0:
            break
        digest = hashlib.new(hashname, block).digest()
        stdout_buffer.write(digest)
        cnt = cnt + 1
        if time.time() - tm > 1:
            tm = time.time()
            elaps = tm - t0
            sys.stderr.write(' [%.1fMb/s]' % (cnt/elaps*blocksize/1024/1024) )
            sys.stderr.write(' digesting blocks on remote: %i\r' % (cnt) )
        readremain -= rblocksize
        if readremain == 0:
            break

    sys.stderr.write( '                                                                                 \r' )

    # write loop, reads changed blocks from parent process and write on output file
    stdout_buffer.flush()
    while True:
        position_s = stdin_buffer.read(8)
        if len(position_s) == 0:
            break
        (position,) = struct.unpack('<Q', position_s)
        block = stdin_buffer.read(blocksize)
        f.seek(position)
        f.write(block)

    cnt = 0
    t0 = time.time()
    tm = t0
    elaps = 0

    readremain = size
    rblocksize = blocksize
    hash_total = hashlib.new(hashname)
    f.seek(0)
    while True:
        if readremain <= blocksize:
            rblocksize = readremain
        block = f.read(rblocksize)
        if len(block) == 0:
            break
        cnt += 1
        if time.time() - tm > 1:
            tm = time.time()
            elaps = tm - t0
            sys.stderr.write(' [%.1fMb/s]' % (cnt/elaps*blocksize/1024/1024) )
            sys.stderr.write(' reading back output and computing checksum: %s\r' % (cnt))
        hash_total.update(block)
        readremain -= rblocksize
        if readremain == 0:
            break
    sys.stderr.write( '                                                                                 \r' )

stdout_buffer.write(hash_total.digest())
'''

class IOCounter:
    """Wrap an input and an output stream and count the bytes moved.

    in_total is the number of bytes returned by read(); out_total is the
    number of bytes passed to write(). Every write is flushed immediately.
    """

    def __init__(self, in_stream, out_stream):
        self.in_total = 0
        self.out_total = 0
        self.in_stream = in_stream
        self.out_stream = out_stream

    def read(self, size=None):
        """Read size bytes (everything up to EOF when size is None)."""
        fetch = self.in_stream.read
        data = fetch() if size is None else fetch(size)
        self.in_total += len(data)
        return data

    def write(self, s):
        """Write s to the output stream, count it, then flush."""
        sink = self.out_stream
        sink.write(s)
        # Counted only after the write succeeded.
        self.out_total += len(s)
        sink.flush()

def bscp(local_filename, remote_host, remote_filename, blocksize, hashname):
    """Copy local_filename to remote_host:remote_filename, sending only changed blocks.

    The remote helper (remote_script) is started through ssh. It reports one
    digest per block of the destination; every local block whose digest
    differs is sent together with its offset. Finally the digest of the whole
    local data is compared with the digest the remote side computed after
    writing.

    Parameters:
        local_filename:  path of the local source file
        remote_host:     host argument passed to ssh
        remote_filename: destination path on the remote host
        blocksize:       block size in bytes, must be positive
        hashname:        hash algorithm name accepted by hashlib.new()

    Returns:
        (bytes read from the remote process, bytes written to it, local size)

    Raises:
        ValueError:   blocksize is not positive
        RuntimeError: the remote sanity check failed, the destination is
                      smaller than the source, or the final checksum differs
    """
    if blocksize <= 0:
        raise ValueError('blocksize must be a positive integer, got %r' % (blocksize,))

    remote_filename_bytes = remote_filename.encode('utf-8')
    hashname_bytes = hashname.encode('ascii')
    hash_total = hashlib.new(hashname)
    with open(local_filename, 'rb') as f:
        # The size is the offset of the end of the file
        f.seek(0, 2)
        size = f.tell()
        f.seek(0)

        # Number of blocks, including a possibly shorter last block. Floor
        # division keeps this exact; float division loses precision for very
        # large sizes on Python 3.
        blockcount = (size + blocksize - 1) // blocksize

        remote_command = 'python -c "%s"' % (remote_script,)
        command = ('ssh', '--', remote_host, remote_command)
        p = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=None)
        channel = IOCounter(p.stdout, p.stdin)
        try:
            # Header: four little-endian uint64 values, then the two strings
            channel.write(struct.pack('<QQQQ', size, blocksize, len(remote_filename_bytes), len(hashname_bytes)))
            channel.write(remote_filename_bytes)
            channel.write(hashname_bytes)

            # The remote side echoes the digest of the filename; anything else
            # means the helper did not start correctly
            sanity_digest = hashlib.new(hashname, remote_filename_bytes).digest()
            remote_digest = channel.read(len(sanity_digest))
            if remote_digest != sanity_digest:
                raise RuntimeError('%s error! remote script failed to execute properly' % (Cmd))
            channel.write(b'go')

            (remote_size,) = struct.unpack('<Q', channel.read(8))
            if remote_size < size:
                raise RuntimeError('%s error! remote size less than local (local: %i, remote: %i)' % (Cmd, size, remote_size))
            remote_digest_list = [channel.read(hash_total.digest_size) for i in range(blockcount)]

            cnt = 0
            cnt_w = 0
            t0 = time.time()
            tm = t0
            elaps = 0

            # reads input, compare digests and writing changed block to remote process
            for remote_digest in remote_digest_list:
                position = f.tell()
                block = f.read(blocksize)
                hash_total.update(block)
                digest = hashlib.new(hashname, block).digest()
                cnt += 1
                if digest != remote_digest:
                    try:
                        channel.write(struct.pack('<Q', position))
                        channel.write(block)
                        cnt_w += 1
                    except IOError:
                        # The remote side went away; stop sending. The final
                        # checksum comparison below reports the failure.
                        break
                # Progress line, refreshed about once per second
                if time.time() - tm > 1:
                    tm = time.time()
                    elaps = tm - t0
                    sys.stderr.write(' [%.1fMb/s]' % (cnt/elaps*blocksize/1024/1024) )
                    sys.stderr.write(' copying ... %i / %i  blocks read/write     \r' % (cnt, cnt_w))

            sys.stderr.write(' total blocks read/write: %i / %i                             \n' % (cnt, cnt_w))

            # EOF on the remote stdin ends its write loop
            try:
                p.stdin.close()
            except IOError:
                # Flushing into a broken pipe fails again here; handled the
                # same way as in the loop above.
                pass

            remote_digest_total = channel.read()
            p.wait()
            if remote_digest_total != hash_total.digest():
                raise RuntimeError('%s error! checksum mismatch after transfer' % (Cmd))
        finally:
            # On an early failure ssh is still running: stop it so that no
            # child process is left behind, then release the pipes.
            if p.poll() is None:
                try:
                    p.kill()
                except OSError:
                    pass
                p.wait()
            for pipe in (p.stdin, p.stdout):
                try:
                    pipe.close()
                except IOError:
                    pass

    return (channel.in_total, channel.out_total, size)

if __name__ == '__main__':
    # Command line: srcfile host:destfile [blocksize] [hashname]
    # The defaults are kept in their own names so that the usage text always
    # shows them, even after a bad value was given on the command line.
    default_blocksize = 1024 * 1024
    default_hashname = 'sha256'
    blocksize = default_blocksize
    hashname = default_hashname
    try:
        # Explicit check instead of assert, which is stripped under -O
        if not 3 <= len(sys.argv) <= 5:
            raise ValueError('wrong number of arguments')
        local_filename = sys.argv[1]
        # Split on the first colon only, so the destination path itself may
        # contain colons
        (remote_host, remote_filename) = sys.argv[2].split(':', 1)
        if not remote_host or not remote_filename:
            raise ValueError('destination must be given as host:destfile')
        if len(sys.argv) >= 4:
            blocksize = int(sys.argv[3])
            if blocksize <= 0:
                raise ValueError('blocksize must be positive')
        if len(sys.argv) >= 5:
            hashname = sys.argv[4]
    except ValueError:
        # Only argument errors end up here; KeyboardInterrupt and SystemExit
        # are no longer swallowed
        sys.stderr.write('\n')
        sys.stderr.write('== %s - copy block devices over network using rsync like algo ==\n' % (CmdStr))
        sys.stderr.write('\nusage: %s srcfile host:destfile [blocksize] [hashname]\n' % (Cmd))
        sys.stderr.write('\n')
        sys.stderr.write('default blocksize: %i\n' % (default_blocksize))
        sys.stderr.write('default hashname:  %s\n' % (default_hashname))
        sys.stderr.write('\n')
        sys.exit(1)

    sys.stderr.write('\n')
    # Float literals keep the megabyte figure fractional on Python 2 as well
    sys.stderr.write( ' block size: %ib (%.1fMb)\n' % (blocksize, blocksize / 1024.0 / 1024.0) )

    (in_total, out_total, size) = bscp(local_filename, remote_host, remote_filename, blocksize, hashname)

    # Ratio of the file size to the bytes actually exchanged with the remote
    speedup = size * 1.0 / (in_total + out_total)
    sys.stderr.write(' %s: in=%i out=%i size=%i speedup=%.2f\n' % (Cmd, in_total, out_total, size, speedup))
    sys.exit(0)
