#cython: language_level=3

cimport cython

from cython.parallel import parallel, prange

import numpy as np
import os
import csv
import urllib.request


def load_and_parse_data(percent=100):
    """Download (if needed), parse and return the genotype data.

       The data is read from "AppleGenotypes.csv" in the current
       directory. If that file does not exist it is first downloaded.

       Rows whose number of fields does not match the header are
       reported with a warning and skipped; the returned array only
       contains the rows that were parsed successfully.

       percent: float
           The percentage of data rows to load. Load less data if
           processing is too slow. Values above 100 load everything,
           values below 0 load nothing.

       Returns
       =======

       (ids,       : list of IDs of the different patterns - one pattern per row
        varieties, : list of the varieties that can be distinguished - one per column
        data)      : 2D numpy int8 array of integers, -1, 0, 1, 2, which show whether
                     or not this pattern can distinguish the variety. Cells whose
                     text is not one of 0/A, 1/AB, 2/B are left as -1

       Raises
       ======

       ValueError : if the data file is empty (there is no header row)
    """
    filename = "AppleGenotypes.csv"

    if not os.path.exists(filename):
        # The file does not exist - download it!
        url = "https://raw.githubusercontent.com/chryswoods/minimalmarkers/main/example"
        urllib.request.urlretrieve(f"{url}/{filename}", filename)

    # Read everything up front; the context manager makes sure the
    # file handle is closed again
    with open(filename, "r") as handle:
        lines = handle.readlines()

    if not lines:
        raise ValueError(f"'{filename}' is empty - no header row to parse")

    # try to discover the separator used for the file (should be a comma)
    dialect = csv.Sniffer().sniff(lines[0], delimiters=[" ", ",", "\t"])

    # the varieties are the column headers (minus the first column
    # which is the ID code for the pattern)
    header = next(csv.reader([lines[0]], dialect=dialect), [])
    varieties = [variety.strip() for variety in header[1:]]

    nrows = len(lines) - 1
    ncols = len(varieties)

    if percent != 100:
        # clamp so that out-of-range percentages cannot produce a
        # negative (or too large) number of rows
        nrows = max(0, min(nrows, int(nrows * percent / 100.0)))

    # Pre-allocate for the maximum number of rows; -1 means "no data"
    data = np.full((nrows, ncols), -1, np.int8)
    ids = []

    for line in lines[1:nrows + 1]:
        parts = next(csv.reader([line], dialect=dialect), [])

        if len(parts) != ncols+1:
            print("WARNING - invalid row! "
                  f"'{parts}' : {len(parts)} vs {ncols}")
            continue

        # np.bytes_ is the portable spelling (np.string_ was removed
        # in NumPy 2.0)
        row = np.asarray(parts[1:], np.bytes_)

        # len(ids) is the index of the next unfilled row of 'data'
        pattern = data[len(ids)]
        pattern[(row == b'0') | (row == b'A')] = 0
        pattern[(row == b'1') | (row == b'AB')] = 1
        pattern[(row == b'2') | (row == b'B')] = 2

        ids.append(parts[0])

    # Drop the pre-allocated rows that were never filled because
    # invalid rows were skipped, so 'ids' and 'data' stay aligned
    data = data[:len(ids)]

    return (ids, varieties, data)


##############
############## We will be speeding up the code below
##############

@cython.boundscheck(False)
def calculate_scores(data):
    """Score every row of 'data' by counting its distinguishing pairs.

       For each row, the score is the number of column pairs whose
       two values differ from each other while neither value is -1.
       Rows are processed in parallel via prange with the GIL released.

       data : 2D array that can be viewed as 'signed char[:,:]'
              (one row per pattern)

       Returns
       =======

            scores : numpy array of doubles, one score per row
    """
    cdef int n_rows = data.shape[0]
    cdef int n_cols = data.shape[1]

    # One accumulator per row, starting from zero
    scores = np.zeros(n_rows)

    cdef signed char[:,:] cells = data
    cdef double[::1] totals = scores

    cdef int row = 0
    cdef int left = 0
    cdef int right = 0
    cdef signed char left_val = 0
    cdef signed char right_val = 0

    with nogil, parallel():
        # Each row is independent, so rows are shared out between threads
        for row in prange(0, n_rows):
            for left in range(0, n_cols):
                left_val = cells[row, left]

                # A missing left value can never form a counting pair
                if left_val != -1:
                    # The (left, left) pair holds equal values and so
                    # never counts - start from the next column
                    for right in range(left + 1, n_cols):
                        right_val = cells[row, right]

                        if right_val != -1 and right_val != left_val:
                            totals[row] += 1

    return scores


def get_index_of_best_score(scores):
    """Return the index of the highest score in the passed scores.

       Only scores strictly greater than 0 are considered. On a tie
       the earliest index wins. If no score is greater than 0 (which
       includes an empty input) then None is returned.
    """
    winner = None
    highest = 0

    for index, value in enumerate(scores):
        # strict comparison keeps the first of any equal best scores
        if value > highest:
            winner, highest = index, value

    return winner
# Build script for the "cyslow" Cython extension module.
#
# Take both 'setup' and 'Extension' from setuptools when it is
# installed, so that the deprecated distutils package is only
# touched as a fallback on installations without setuptools.
try:
    from setuptools import Extension, setup
except ImportError:
    from distutils.core import setup
    from distutils.extension import Extension

from Cython.Build import cythonize

# The OpenMP flag is passed to both the compiler and the linker;
# the module source uses cython.parallel (prange)
cyslow = Extension(
    "cyslow",
    sources=["cyslow.pyx"],
    extra_compile_args=['-O3', '-fopenmp'],
    extra_link_args=['-O3', '-fopenmp']
)

setup(
    ext_modules = cythonize(cyslow)
)
python setup.py build_ext --inplace
# NOTE(review): the bare `timeit(...)` call below, together with the timing
# line that follows it, looks like an IPython session (the %timeit magic)
# rather than plain Python, where `timeit` is not defined by these lines -
# confirm how this snippet is meant to be run.
import cyslow

(ids, varieties, data) = cyslow.load_and_parse_data(100)
timeit(cyslow.calculate_scores(data))
17.9 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
time python cyslow_main.py
The best score 21931.0 comes from pattern MDC054122.001_13602
python cyslow_main.py  0.38s user 0.09s system 197% cpu 0.240 total

Bonus

import tqdm

@cython.boundscheck(False)
def calculate_scores(data):
    """Score every row of 'data', showing progress with a tqdm bar.

       For each row, the score is the number of column pairs whose
       two values differ from each other while neither value is -1.

       The rows are split into 10 consecutive blocks. The blocks are
       visited one after another (advancing the progress bar), and the
       rows inside each block are processed in parallel via prange
       with the GIL released.

       data : 2D array that can be viewed as 'signed char[:,:]'
              (one row per pattern)

       Returns
       =======

            scores : numpy array of doubles, one score per row
    """
    cdef int n_rows = data.shape[0]
    cdef int n_cols = data.shape[1]

    # One accumulator per row, starting from zero
    scores = np.zeros(n_rows)

    cdef signed char[:,:] cells = data
    cdef double[::1] totals = scores

    cdef int row = 0
    cdef int left = 0
    cdef int right = 0
    cdef signed char left_val = 0
    cdef signed char right_val = 0
    cdef int first = 0
    cdef int last = 0

    block_count = 10

    # Smallest block size for which block_count blocks cover all rows
    # (ceiling division)
    block_size = (n_rows + block_count - 1) // block_count

    for block in tqdm.tqdm(range(0, block_count), unit_scale=block_size):
        first = block * block_size
        # The final blocks may be short (or empty) - never run past n_rows
        last = min(n_rows, (block + 1) * block_size)

        with nogil, parallel():
            # Rows within a block are shared out between threads
            for row in prange(first, last):
                for left in range(0, n_cols):
                    left_val = cells[row, left]

                    # A missing left value can never form a counting pair
                    if left_val != -1:
                        # The (left, left) pair holds equal values and
                        # so never counts - start from the next column
                        for right in range(left + 1, n_cols):
                            right_val = cells[row, right]

                            if right_val != -1 and right_val != left_val:
                                totals[row] += 1

    return scores
time python cyslow_main.py
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1290/1290 [00:00<00:00, 42239.70it/s]
The best score 21931.0 comes from pattern MDC054122.001_13602
python cyslow_main.py  0.36s user 0.11s system 241% cpu 0.196 total

Back