"""
Experimental code to take a JPG image containing a grid, straighten it up and produce a small PNG
suitable for AYAB knitting software

Licence: Do with this as you wish.
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt

def find_outermost_corners(image):
    """Locate the four outermost Harris corners of a single-channel image.

    Harris corner detection is run over ``image`` and every pixel whose
    response exceeds 1% of the strongest response is kept as a candidate.
    The four extreme candidates are then selected:

    * top-left:     smallest ``x + y``
    * top-right:    largest ``x - y``
    * bottom-left:  smallest ``x - y``
    * bottom-right: largest ``x + y``

    Args:
        image: Single-channel image array; it is converted to float32
            before being handed to the corner detector.

    Returns:
        ``(top_left, top_right, bottom_left, bottom_right)`` where each
        entry is an ``(x, y)`` tuple of plain Python ints.

    Raises:
        ValueError: If no pixel passes the corner threshold (for example a
            completely uniform image), because no corner can be chosen.
    """
    # https://docs.opencv.org/4.x/dc/d0d/tutorial_py_features_harris.html
    gray = np.float32(image)
    dst = cv2.cornerHarris(gray, 2, 3, 0.04)
    # Threshold for an optimal value, it may vary depending on the image.
    corners = np.argwhere(dst > 0.01 * dst.max())
    if corners.shape[0] == 0:
        # Previously this case returned tuples of +/-infinity, which only
        # failed later on with an unrelated-looking error.
        raise ValueError("No corners were detected in the image")

    # np.argwhere yields (row, column) pairs, i.e. (y, x).
    ys = corners[:, 0]
    xs = corners[:, 1]
    sums = xs + ys
    diffs = xs - ys

    def corner_at(index):
        return (int(xs[index]), int(ys[index]))

    # argmin/argmax return the first occurrence on ties, which matches a
    # scan that only replaces the current best on a strictly better value.
    top_left = corner_at(sums.argmin())
    top_right = corner_at(diffs.argmax())
    bottom_left = corner_at(diffs.argmin())
    bottom_right = corner_at(sums.argmax())

    return (top_left, top_right, bottom_left, bottom_right)

def find_scaled_corners(corners):
    """Build the axis-aligned target rectangle for a perspective correction.

    The width of the rectangle is the longer of the top and bottom edges of
    the input quadrilateral, and the height is the longer of its left and
    right edges (both truncated to whole pixels).

    Args:
        corners: ``(top_left, top_right, bottom_left, bottom_right)``, each
            an ``(x, y)`` pair.

    Returns:
        A tuple ``(rectangle, width, height)`` where ``rectangle`` is
        ``(top_left, top_right, bottom_left, bottom_right)`` in the same
        order as the input, spanning pixel ``(0, 0)`` to pixel
        ``(width - 1, height - 1)``.
    """

    def distance(point_a, point_b):
        # Straight-line (Euclidean) distance between two (x, y) points.
        dx = point_a[0] - point_b[0]
        dy = point_a[1] - point_b[1]
        return np.sqrt(dx ** 2 + dy ** 2)

    top_left, top_right, bottom_left, bottom_right = corners

    top_width = distance(top_right, top_left)
    bottom_width = distance(bottom_right, bottom_left)
    largest_width = max(int(top_width), int(bottom_width))

    left_height = distance(top_left, bottom_left)
    right_height = distance(top_right, bottom_right)
    largest_height = max(int(left_height), int(right_height))

    # Convert these into co-ordinates of a rectangle of largest size.
    # Every corner uses the last valid pixel index (size - 1); the bottom
    # right corner used to be (width, height), which skewed the rectangle
    # and pointed one pixel outside the output image.
    last_x = largest_width - 1
    last_y = largest_height - 1
    rectangle = (
        (0, 0),
        (last_x, 0),
        (0, last_y),
        (last_x, last_y),
    )

    return rectangle, largest_width, largest_height

def show_corners(image,corners):
    """Optional debugging aid: plot the image with each corner marked.

    Args:
        image: Greyscale image; it is expanded to three channels so the
            markers can be drawn in colour.
        corners: Iterable of ``(x, y)`` points to mark.
    """
    marker_radius = 20
    marker_colour = (255, 0, 0)
    filled = -1  # A negative thickness asks OpenCV for a filled circle

    canvas = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    for point in corners:
        centre_x, centre_y = point
        # cv2.circle draws directly onto the canvas
        cv2.circle(canvas, (centre_x, centre_y), marker_radius, marker_colour, filled)

    plt.imshow(canvas)
    plt.title("Found corners")
    plt.show()

def convert(input_filename,output_filename,look_for_read_border,stitches_w,stitches_h,filter_size):
    """Straighten a photographed grid and save it as a black/white stitch image.

    Steps: load the image, find the four outer corners of the pattern, warp
    the pattern to a rectangle, blur away thin grid lines, shrink to one
    pixel per stitch, then threshold to pure black/white and save. Several
    matplotlib windows are shown along the way for inspection.

    Args:
        input_filename: Path of the image to read.
        output_filename: Path of the image to write.
        look_for_read_border: When true, corners are searched for in a mask
            of strongly red pixels (the image is expected to have a red box
            drawn around the grid) rather than in the greyscale image.
        stitches_w: Number of horizontal stitches (output width in pixels).
        stitches_h: Number of vertical stitches (output height in pixels).
        filter_size: Side length of the box blur used to suppress the grid
            lines; needs experimenting to find a good value.

    Raises:
        OSError: If the input image cannot be read or the output image
            cannot be written.
    """
    # Load the image
    colour = cv2.imread(input_filename)
    if colour is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("Could not read image file: {!r}".format(input_filename))

    # cv2.imread delivers channels in B, G, R order, so the BGR conversion
    # code is required for the red/blue weights to land on the right channel.
    grey = cv2.cvtColor(colour, cv2.COLOR_BGR2GRAY)

    if look_for_read_border:
        # To assist the corner detection assume we've drawn a red line around the edge
        # Bounds are in B, G, R order: any blue/green, red channel >= 254.
        # NOTE(review): fully white pixels satisfy this range as well.
        lower_red = np.array([0,0,254])
        upper_red = np.array([255,255,255])
        red_only_image = cv2.inRange(colour, lower_red, upper_red)

        plt.imshow(red_only_image)
        plt.title("Looking for the corners here..")
        plt.show()

        found_corners = find_outermost_corners(red_only_image)
    else:
        found_corners = find_outermost_corners(grey)

    show_corners(grey, found_corners)

    scaled_corners, width, height = find_scaled_corners(found_corners)

    #https://docs.opencv.org/4.x/da/d6e/tutorial_py_geometric_transformations.html
    #Pass two vectors of n (x,y) co-ordinates.
    transform_matrix = cv2.getPerspectiveTransform(np.float32(found_corners), np.float32(scaled_corners))
    # The output size is (width, height), so the warped image already holds
    # only the pattern area and needs no further cropping.
    pattern_only = cv2.warpPerspective(grey, transform_matrix, (width, height))

    # Attempt to filter out the thin grid lines...
    pattern_only = cv2.blur(pattern_only, (filter_size, filter_size))
    plt.imshow(pattern_only)
    plt.title("The straightened image")
    plt.show()

    # Now resize the image to match the stitch size in the pattern
    resized_as_stitches = cv2.resize(pattern_only, (stitches_w, stitches_h))
    # ... and now turn into black OR white.
    # With THRESH_OTSU the threshold is chosen automatically, so the 127
    # passed here is only a placeholder.
    _, as_black_white = cv2.threshold(resized_as_stitches, 127, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)

    plt.imshow(as_black_white,cmap='gray')
    plt.title("The knitting pattern")
    plt.show()

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_filename, as_black_white):
        raise OSError("Could not write image file: {!r}".format(output_filename))

"""
Use as follows:

- Input filename
- The filename to be written as output
- Whether the image has been annotated with a red box showing the grid position
- The number of horizontal stitches in the output
- The number of vertical stitches in the output
- The size of the blur used to remove the grid lines (need to experiment to find a good value)
"""
# Example run; the arguments follow the order described in the usage notes.
# NOTE(review): this call executes whenever the file is run or imported.
convert(
    input_filename="ladybirds.jpg",
    output_filename="ladybirds.png",
    look_for_read_border=False,
    stitches_w=36,
    stitches_h=98,
    filter_size=1,
)
#convert("owls.jpg","owls.png",True,62,87,5)
#convert("fox.jpg","fox.png",True,61,102,4)