"""
Open-source replacement Python toolbox for the Slope, Aspect, Hillshade, and Contour tools
Tested with ArcGIS Pro 2.5

Written by Michael Bean (w0701505@apps.losrios.edu)
For Introduction to GIS Programming (GEOG 375) Class
American River College, Sacramento, CA

Adapted mostly from sample code provided in:
    Lawhead, Joel. Learning Geospatial Analysis with Python: Understand GIS fundamentals and perform remote sensing
    data analysis using Python 3.7, 3rd Edition. Packt Publishing. Kindle Edition.
    http://git.io/vYwUX

This code is made available under The MIT License

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

import arcpy
import os
import sys
import traceback
import numpy as np
from numpy import pi
from numpy import arctan
from numpy import arctan2
from numpy import sin
from numpy import cos
from numpy import sqrt
import gdal
import ogr
import osr

# Have GDAL raise Python exceptions on errors instead of returning error codes
gdal.UseExceptions()

# Degree <-> radian conversion factors used with the NumPy trigonometric functions
deg2rad, rad2deg = pi / 180., 180. / pi

# Stand-in value for NoData cells inside NumPy arrays (read and write)
NODATA = -9999


def _write_raster(array, raster_path, lower_left, cell_width, cell_height, spatial_ref, value_to_nodata=None):
    """
    Save a NumPy array as a raster dataset, replacing any dataset already at ``raster_path``.

    :param array: 2-D NumPy array of cell values.
    :param raster_path: output dataset path.
    :param lower_left: arcpy Point of the output's lower-left corner.
    :param cell_width: output cell width in map units.
    :param cell_height: output cell height in map units.
    :param spatial_ref: spatial reference assigned with DefineProjection.
    :param value_to_nodata: array value to write as NoData; only forwarded to arcpy.NumPyArrayToRaster when
        given, so the default call is identical to calling it without that argument.
    """
    if arcpy.Exists(raster_path):
        arcpy.Delete_management(raster_path)
    if value_to_nodata is None:
        out_raster = arcpy.NumPyArrayToRaster(array, lower_left, cell_width, cell_height)
    else:
        out_raster = arcpy.NumPyArrayToRaster(array, lower_left, cell_width, cell_height,
                                              value_to_nodata=value_to_nodata)
    arcpy.DefineProjection_management(out_raster, spatial_ref)
    out_raster.save(raster_path)


def _surface_gradients(elevation_values, cell_width, cell_height, use_gradient):
    """
    Return ``(dz_dx, dz_dy, trimmed)`` for a 2-D elevation array.

    ``dz_dx`` is the elevation change per map unit along increasing column index and ``dz_dy`` the change per map
    unit along increasing row index, matching dz/dx and dz/dy of Esri's 3x3 formulas. ``trimmed`` is True when the
    arrays are one cell smaller on every edge (3x3 window method) and False when they keep the input shape
    (numpy.gradient method).
    """
    if use_gradient:
        # Not exactly the same output as the 3x3 window: central differences instead of Esri's weighted window.
        # np.gradient returns one derivative per axis in axis order (rows first, then columns) and takes the
        # spacings in that same order.
        dz_dy, dz_dx = np.gradient(elevation_values, cell_height, cell_width)
        return dz_dx, dz_dy, False

    rows, columns = elevation_values.shape
    # window[3 * r + c] is a view of the neighbour at row offset r - 1 and column offset c - 1 of every interior
    # cell, so Esri's "a b c / d e f / g h i" labelling is window[0] .. window[8]
    window = [elevation_values[r:(r + rows - 2), c:(c + columns - 2)] for r in range(3) for c in range(3)]

    # Esri: dz/dx = ((c + 2f + i) - (a + 2d + g)) / (8 * cell width)
    dz_dx = ((window[2] + window[5] + window[5] + window[8]) -
             (window[0] + window[3] + window[3] + window[6])) / (8. * cell_width)

    # Esri: dz/dy = ((g + 2h + i) - (a + 2b + c)) / (8 * cell height)
    dz_dy = ((window[6] + window[7] + window[7] + window[8]) -
             (window[0] + window[1] + window[1] + window[2])) / (8. * cell_height)

    return dz_dx, dz_dy, True


def doSlopeAspectHillshade(elevation_input, slope_output, aspect_output, hillshade_output, z_factor=1.0, azimuth=315.0,
                           altitude=45.0, use_gradient=False):
    """
    Write slope, aspect, and/or hillshade rasters derived from an elevation raster.

    Adapted from:

    Lawhead, Joel. Learning Geospatial Analysis with Python: Understand GIS fundamentals and perform remote sensing
    data analysis using Python 3.7, 3rd Edition. Packt Publishing. Kindle Edition.
    http://git.io/vYwUX

    Shaded relief images using GDAL python:
    http://geoexamples.blogspot.com/2014/03/shaded-relief-images-using-gdal-python.html

    How Hillshade Works:
    https://desktop.arcgis.com/en/arcmap/10.7/tools/spatial-analyst-toolbox/how-hillshade-works.htm

    :param elevation_input: path of the input elevation raster (anything arcpy.Raster accepts).
    :param slope_output: output path for slope in degrees; falsy to skip.
    :param aspect_output: output path for aspect in compass degrees (-1 where the slope is 0); falsy to skip.
    :param hillshade_output: output path for an 8-bit (0-255) hillshade; falsy to skip.
    :param z_factor: multiplier applied to the elevation gradient before computing slope, to correct for z units
        that differ from the x/y units. See
        https://pro.arcgis.com/en/pro-app/tool-reference/3d-analyst/applying-a-z-factor.htm
    :param azimuth: sun direction in compass degrees (clockwise from north).
    :param altitude: sun angle above the horizon in degrees.
    :param use_gradient: use numpy.gradient (same-size output) instead of Esri's 3x3 window (output one cell
        smaller on every edge).

    Errors are not raised; they are reported with arcpy.AddError.
    """

    # Esri's hillshade formula uses the zenith angle (measured from vertical) rather than the altitude
    zenith_rad = (90. - altitude) * deg2rad

    # Sun direction as a compass bearing in radians
    azimuth_rad = azimuth * deg2rad

    try:
        arcpy.AddMessage("Opening elevation raster...")
        in_raster = arcpy.Raster(elevation_input)

        arcpy.AddMessage("Reading elevation values...")
        elevation_values = arcpy.RasterToNumPyArray(in_raster, "", "", "", NODATA)
        if not np.issubdtype(elevation_values.dtype, np.floating):
            # Integer (especially unsigned) arrays would wrap around in the window sums and differences
            elevation_values = elevation_values.astype(np.float64)

        # Spatial reference of the input raster, reused for the output rasters
        # maybe get smart about x, y, and z units, Esri still uses z-Factor to normalize units
        sr = in_raster.spatialReference

        cell_width = in_raster.meanCellWidth
        cell_height = in_raster.meanCellHeight

        lower_left = in_raster.extent.lowerLeft

        dz_dx, dz_dy, trimmed = _surface_gradients(elevation_values, cell_width, cell_height, use_gradient)
        if trimmed:
            # The 3x3 window has no result for edge cells, so the outputs start one cell further in
            lower_left.X += cell_width
            lower_left.Y += cell_height

        arcpy.AddMessage("Calculating slope raster...")
        slope_rad = arctan(z_factor * sqrt(dz_dx * dz_dx + dz_dy * dz_dy))

        if aspect_output or hillshade_output:
            arcpy.AddMessage("Calculating aspect raster...")
            # Compass bearing (radians, clockwise from north, in [0, 2*pi]) of the downslope direction.
            # Row index increases southward, so the northward elevation change is -dz_dy.
            facing_rad = arctan2(dz_dx, -dz_dy) + pi

            if hillshade_output:
                arcpy.AddMessage("Calculating hillshade raster...")
                # Esri hillshade; the azimuth and the slope direction must share one angle convention, so the
                # illumination term peaks when the slope faces the sun, for every azimuth
                illumination = ((cos(zenith_rad) * cos(slope_rad)) +
                                (sin(zenith_rad) * sin(slope_rad) * cos(azimuth_rad - facing_rad)))
                hillshade = np.clip(255 * illumination, 0, 255).astype(np.uint8)

                arcpy.AddMessage("Writing hillshade raster to: " + hillshade_output)
                _write_raster(hillshade, hillshade_output, lower_left, cell_width, cell_height, sr)
                arcpy.AddMessage("Saved hillshade raster")

            if aspect_output:
                arcpy.AddMessage("Writing aspect raster to: " + aspect_output)

                # If terrain is flat, Esri outputs -1 as aspect, see:
                # https://desktop.arcgis.com/en/arcmap/10.7/tools/spatial-analyst-toolbox/how-aspect-works.htm
                aspect = np.where(slope_rad == 0., -1, facing_rad * rad2deg)

                _write_raster(aspect, aspect_output, lower_left, cell_width, cell_height, sr)
                arcpy.AddMessage("Saved aspect raster")

        if slope_output:
            slope = slope_rad * rad2deg

            # TODO: mask windows that touch NODATA instead. A NODATA (-9999) cell inside the raster makes its
            # neighbours look almost vertical, so slopes steeper than 89 degrees are treated as NODATA.
            slope[slope > 89] = NODATA

            arcpy.AddMessage("Writing slope raster to: " + slope_output)
            _write_raster(slope, slope_output, lower_left, cell_width, cell_height, sr, value_to_nodata=NODATA)
            arcpy.AddMessage('Saved slope raster')

        arcpy.AddMessage("Finished saving rasters")

    except Exception:
        # Report the failure (full traceback) through geoprocessing messages instead of raising
        tb = sys.exc_info()[2]
        tbinfo = "".join(traceback.format_tb(tb))
        pymsg = "PYTHON ERRORS:\nTraceback Info:\n" + tbinfo + "\nError Info:\n" + str(sys.exc_info()[1])
        msgs = "ARCPY ERRORS:\n" + arcpy.GetMessages(2) + "\n"

        arcpy.AddError(msgs)
        arcpy.AddError(pymsg)

        arcpy.AddMessage(arcpy.GetMessages(1))


def doContours(elevation_input, contour_feature_output, contour_interval=10., contour_base=0.):
    """
    Generate contour lines from band 1 of an elevation raster with gdal.ContourGenerate.

    Adapted from:

    Lawhead, Joel. Learning Geospatial Analysis with Python: Understand GIS fundamentals and perform remote sensing
    data analysis using Python 3.7, 3rd Edition. Packt Publishing. Kindle Edition.
    http://git.io/vYwUX

    :param elevation_input: path of the elevation raster, opened with gdal.Open.
    :param contour_feature_output: output feature class. A ``.shp`` path is written directly (replacing an existing
        shapefile). Any other output is first written to a temporary shapefile in arcpy.env.scratchFolder and then
        copied with CopyFeatures.
    :param contour_interval: elevation difference between contour lines.
    :param contour_base: base elevation relative to which the contour interval is applied.

    The output has an integer ID field and a double ELEV field. Errors are not raised; they are reported with
    arcpy.AddError. The temporary shapefile is removed whether or not contouring succeeds.
    """
    scratch_name = None
    src_ds = band = ogr_ds = ogr_lyr = None

    try:
        # Get elevation values from elevation raster
        arcpy.AddMessage("Opening elevation raster...")
        src_ds = gdal.Open(elevation_input)
        band = src_ds.GetRasterBand(1)
        nodata = band.GetNoDataValue()
        srs = osr.SpatialReference(wkt=src_ds.GetProjection())

        if contour_feature_output.lower().endswith(".shp"):
            arcpy.AddMessage("Creating shapefile for output...")
            contour_shapefile = contour_feature_output
            # The OGR shapefile driver refuses to create a data source over an existing file
            if arcpy.Exists(contour_shapefile):
                arcpy.Delete_management(contour_shapefile)
        else:  # Create a temporary shapefile
            arcpy.AddMessage("Creating temporary shapefile for output...")
            scratch_name = arcpy.CreateScratchName("temp", data_type="Shapefile", workspace=arcpy.env.scratchFolder)
            contour_shapefile = scratch_name

        ogr_driver = ogr.GetDriverByName('ESRI Shapefile')
        ogr_ds = ogr_driver.CreateDataSource(contour_shapefile)
        if ogr_ds is None:
            raise RuntimeError("Could not create shapefile: " + contour_shapefile)

        layer_name = os.path.splitext(os.path.basename(contour_shapefile))[0]
        ogr_lyr = ogr_ds.CreateLayer(layer_name, srs=srs, geom_type=ogr.wkbLineString25D)

        # Field order matters: ContourGenerate below refers to fields by index (ID = 0, ELEV = 1)
        ogr_lyr.CreateField(ogr.FieldDefn('ID', ogr.OFTInteger))
        ogr_lyr.CreateField(ogr.FieldDefn('ELEV', ogr.OFTReal))

        # gdal.ContourGenerate(srcBand, contourInterval, contourBase, fixedLevels, useNoData, noDataValue,
        #                      dstLayer, idField, elevField)
        # A band without a NoData value reports None, which cannot be passed as the noDataValue double
        use_nodata = 0 if nodata is None else 1
        nodata_value = 0. if nodata is None else nodata

        arcpy.AddMessage("Generating contours...")
        gdal.ContourGenerate(band, contour_interval, contour_base, [], use_nodata, nodata_value, ogr_lyr, 0, 1)
        arcpy.AddMessage("Contours generated.")

        # Dereference the OGR objects to flush and close the shapefile before it is copied
        ogr_lyr = ogr_ds = None

        if scratch_name:
            arcpy.AddMessage("Copying contours...")
            arcpy.CopyFeatures_management(contour_shapefile, contour_feature_output)
    except Exception:
        # Report the failure (full traceback) through geoprocessing messages instead of raising
        tb = sys.exc_info()[2]
        tbinfo = "".join(traceback.format_tb(tb))
        pymsg = "PYTHON ERRORS:\nTraceback Info:\n" + tbinfo + "\nError Info:\n" + str(sys.exc_info()[1])
        msgs = "ARCPY ERRORS:\n" + arcpy.GetMessages(2) + "\n"

        arcpy.AddError(msgs)
        arcpy.AddError(pymsg)

        arcpy.AddMessage(arcpy.GetMessages(1))
    finally:
        # Release the GDAL/OGR handles (closing the files) before removing the temporary shapefile
        ogr_lyr = ogr_ds = band = src_ds = None
        if scratch_name and arcpy.Exists(scratch_name):
            arcpy.AddMessage("Deleting temporary shapefile...")
            try:
                arcpy.Delete_management(scratch_name)
            except Exception:
                arcpy.AddWarning("Could not delete temporary shapefile: " + scratch_name)


class Toolbox(object):
    """ArcGIS Python toolbox holding the open-source elevation tools.

    The name of the toolbox is the name of the .pyt file.
    """

    def __init__(self):
        # Display label and (empty) alias of the toolbox
        self.label, self.alias = "Elevation Tools", ""

        # Tool classes that belong to this toolbox
        self.tools = [SlopeAspectHillshade, Contours]


class SlopeAspectHillshade(object):
    """ArcGIS Python toolbox tool that runs doSlopeAspectHillshade on an elevation raster."""

    # Defaults used when an optional numeric parameter is left empty (same as doSlopeAspectHillshade's defaults)
    DEFAULT_Z_FACTOR = 1.0
    DEFAULT_SUN_AZIMUTH = 315.0
    DEFAULT_SUN_ANGLE = 45.0

    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "Slope Aspect Hillshade"
        self.description = "Generate Slope, Aspect, and/or Hillshade rasters using NumPy arrays"
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions"""

        elevation_input = arcpy.Parameter(
            displayName="Elevation Input",
            name="elevation_input",
            datatype="DERasterDataset",
            parameterType="Required",
            direction="Input")

        z_factor = arcpy.Parameter(
            displayName="Z-Factor",
            name="z_factor",
            datatype="GPDouble",
            parameterType="Optional",
            direction="Input")

        z_factor.value = self.DEFAULT_Z_FACTOR

        sun_azimuth = arcpy.Parameter(
            displayName="Sun Azimuth",
            name="sun_azimuth",
            datatype="GPDouble",
            parameterType="Optional",
            direction="Input")

        sun_azimuth.value = self.DEFAULT_SUN_AZIMUTH

        sun_angle = arcpy.Parameter(
            displayName="Sun Angle",
            name="sun_angle",
            datatype="GPDouble",
            parameterType="Optional",
            direction="Input")

        sun_angle.value = self.DEFAULT_SUN_ANGLE

        slope_output = arcpy.Parameter(
            displayName="Slope Output",
            name="slope_output",
            datatype="DERasterDataset",
            parameterType="Optional",
            direction="Output")

        # Use the __file__ attribute to find the .lyrx file (assuming the
        #  .pyt and .lyrx files exist in the same folder).
        slope_symbology_layer = os.path.join(os.path.dirname(__file__), 'slopeSymbology.lyrx')
        if arcpy.Exists(slope_symbology_layer):
            slope_output.symbology = slope_symbology_layer

        aspect_output = arcpy.Parameter(
            displayName="Aspect Output",
            name="aspect_output",
            datatype="DERasterDataset",
            parameterType="Optional",
            direction="Output")

        # Use the __file__ attribute to find the .lyrx file (assuming the
        #  .pyt and .lyrx files exist in the same folder).
        # TODO: This crashes ArcGIS Pro 2.5; the same lyrx file can be imported without error
        aspect_symbology_layer = os.path.join(os.path.dirname(__file__), 'aspectSymbology.lyrx')
        if arcpy.Exists(aspect_symbology_layer):
            aspect_output.symbology = aspect_symbology_layer

        hillshade_output = arcpy.Parameter(
            displayName="Hillshade Output",
            name="hillshade_output",
            datatype="DERasterDataset",
            parameterType="Optional",
            direction="Output")

        params = [elevation_input, z_factor, sun_azimuth, sun_angle, slope_output, aspect_output, hillshade_output]

        return params

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed.  This method is called whenever a parameter
        has been changed."""
        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter.  This method is called after internal validation."""
        return

    @staticmethod
    def _as_float(value, default):
        """Return ``value`` as a float, or ``default`` when an optional parameter was left empty."""
        if value is None or (isinstance(value, str) and value.strip() in ("", "#")):
            return default
        return float(value)

    @classmethod
    def _read_arguments(cls, parameters):
        """
        Unpack the tool parameters into doSlopeAspectHillshade's positional arguments.

        Numeric parameters are read from ``.value`` rather than ``.valueAsText`` so they do not depend on the text
        formatting of the number, and empty optional numbers fall back to the defaults instead of crashing float().
        """
        return (parameters[0].valueAsText,
                parameters[4].valueAsText,
                parameters[5].valueAsText,
                parameters[6].valueAsText,
                cls._as_float(parameters[1].value, cls.DEFAULT_Z_FACTOR),
                cls._as_float(parameters[2].value, cls.DEFAULT_SUN_AZIMUTH),
                cls._as_float(parameters[3].value, cls.DEFAULT_SUN_ANGLE))

    def execute(self, parameters, messages):
        """Run doSlopeAspectHillshade with the tool's parameter values."""
        arguments = self._read_arguments(parameters)

        # set overwrite output to True to overwrite intermediate and final output
        # if the files exist
        arcpy.env.overwriteOutput = True

        doSlopeAspectHillshade(*arguments)

        return


class Contours(object):
    """ArcGIS Python toolbox tool that runs doContours on an elevation raster."""

    # Defaults used when an optional numeric parameter is left empty (same as doContours' defaults)
    DEFAULT_CONTOUR_INTERVAL = 10.0
    DEFAULT_CONTOUR_BASE = 0.0

    def __init__(self):
        """Define the tool (tool name is the name of the class)."""
        self.label = "GDAL Contours"
        self.description = "Generate contours using open source OGR and GDAL libraries"
        self.canRunInBackground = False

    def getParameterInfo(self):
        """Define parameter definitions"""

        elevation_input = arcpy.Parameter(
            displayName="Elevation Input",
            name="elevation_raster",
            datatype="DERasterDataset",
            parameterType="Required",
            direction="Input")

        contour_feature_output = arcpy.Parameter(
            displayName="Contour Feature Output",
            name="contour_feature_output",
            datatype="DEFeatureClass",
            parameterType="Required",
            direction="Output")

        contour_interval = arcpy.Parameter(
            displayName="Contour Interval",
            name="contour_interval",
            datatype="GPDouble",
            parameterType="Optional",
            direction="Input")
        contour_interval.value = self.DEFAULT_CONTOUR_INTERVAL

        contour_base = arcpy.Parameter(
            displayName="Contour Base",
            name="contour_base",
            datatype="GPDouble",
            parameterType="Optional",
            direction="Input")
        contour_base.value = self.DEFAULT_CONTOUR_BASE

        params = [elevation_input, contour_feature_output, contour_interval, contour_base]

        return params

    def isLicensed(self):
        """Set whether tool is licensed to execute."""
        return True

    def updateParameters(self, parameters):
        """Modify the values and properties of parameters before internal
        validation is performed.  This method is called whenever a parameter
        has been changed."""
        return

    def updateMessages(self, parameters):
        """Modify the messages created by internal validation for each tool
        parameter.  This method is called after internal validation."""
        return

    @staticmethod
    def _as_float(value, default):
        """Return ``value`` as a float, or ``default`` when an optional parameter was left empty."""
        if value is None or (isinstance(value, str) and value.strip() in ("", "#")):
            return default
        return float(value)

    @classmethod
    def _read_arguments(cls, parameters):
        """
        Unpack the tool parameters into doContours' positional arguments.

        The contour base comes from the fourth parameter (index 3). Numeric parameters are read from ``.value`` so
        empty optional numbers fall back to the defaults instead of crashing float().
        """
        return (parameters[0].valueAsText,
                parameters[1].valueAsText,
                cls._as_float(parameters[2].value, cls.DEFAULT_CONTOUR_INTERVAL),
                cls._as_float(parameters[3].value, cls.DEFAULT_CONTOUR_BASE))

    def execute(self, parameters, messages):
        """Run doContours with the tool's parameter values."""
        elevation_input, contour_feature_output, contour_interval, contour_base = self._read_arguments(parameters)

        # set overwrite output to True to overwrite intermediate and final output
        # if the files exist
        arcpy.env.overwriteOutput = True

        doContours(elevation_input, contour_feature_output, contour_interval, contour_base)

        return


if __name__ == "__main__":
    # Manual test run: inputs are read from, and outputs written to, the folder containing this script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    arcpy.env.workspace = script_dir

    def _local(file_name):
        """Return the path of ``file_name`` inside the script folder."""
        return os.path.join(script_dir, file_name)

    test_dem = _local("Clip_grdn39w121_13.tif")
    # NOTE(review): 0.00001171 looks like Esri's z-factor for degree x/y units with metre elevations near
    # 40 degrees latitude -- confirm against the test DEM
    test_z_factor = 0.00001171

    # Test TIFF: slope, aspect, and hillshade with the default 3x3 window method
    doSlopeAspectHillshade(test_dem, _local("slope.tif"), _local("aspect.tif"), _local("hillshade.tif"),
                           test_z_factor)

    # Test just slope, using numpy.gradient
    doSlopeAspectHillshade(test_dem, _local("justSlope.tif"), "", "", test_z_factor, use_gradient=True)
