import {isNaN, isNumber, meanBy} from 'lodash'

import type {
  CementStrengthSample,
  PredictionStats,
  StatsRecord,
  StrengthLevel,
  TimeRange
} from '../../declarations'
import {filterSamplesByTimeRange} from '../filter'
import {calcMeanAbsoluteError, calcR2Score, isStatsRecord} from '../math/stats'

import {getActualValue, getPredictedValue, getPredictedValues} from './cementStrengthSample'

/**
 * Pairs the actual and predicted strength of a single sample at the given strength level.
 *
 * Lodash `isNumber` also accepts `NaN` (and `Infinity`). The values are therefore also
 * required to be finite, so one non-finite reading cannot leak into the records and turn
 * the statistics computed from them into `NaN`.
 *
 * @param sample - Sample to read the actual and predicted values from.
 * @param strengthLevel - Strength level whose actual/predicted pair is extracted.
 * @returns `{id, actual, predicted}` when both values are finite numbers, otherwise `undefined`.
 */
export const toStatsRecord = (sample: CementStrengthSample, strengthLevel: StrengthLevel) => {
  const actual = getActualValue(sample, strengthLevel)
  const predicted = getPredictedValue(sample, strengthLevel)
  // `isNumber` narrows the static type; `Number.isFinite` rejects NaN/Infinity at runtime.
  if (
    isNumber(actual) &&
    isNumber(predicted) &&
    Number.isFinite(actual) &&
    Number.isFinite(predicted)
  ) {
    return {
      id: sample.id,
      actual,
      predicted
    }
  }

  return undefined
}

/**
 * Converts every sample into a stats record for the given strength level.
 *
 * Entries rejected by `isStatsRecord` are dropped, e.g. the `undefined` that
 * `toStatsRecord` returns for samples without a usable actual/predicted pair.
 *
 * @param samples - Samples to convert.
 * @param strengthLevel - Strength level whose values are paired.
 * @returns Records for the samples that produced a valid pair, in input order.
 */
const toStatsRecords = (
  samples: CementStrengthSample[],
  strengthLevel: StrengthLevel
): StatsRecord[] => {
  const toRecord = (sample: CementStrengthSample) => toStatsRecord(sample, strengthLevel)
  const candidates = samples.map(toRecord)
  return candidates.filter(isStatsRecord)
}

/**
 * Builds the stats records for the samples that fall inside a time range.
 *
 * @param samples - All candidate samples.
 * @param strengthLevel - Strength level whose actual/predicted values are paired.
 * @param timeRange - Range used by `filterSamplesByTimeRange` to select samples.
 * @returns Stats records for the in-range samples that have a valid pair.
 */
export const extractStatsRecords = (
  samples: CementStrengthSample[],
  strengthLevel: StrengthLevel,
  timeRange: TimeRange
): StatsRecord[] => {
  const inRange = filterSamplesByTimeRange(samples, timeRange)
  return toStatsRecords(inRange, strengthLevel)
}

/**
 * Computes the mean of the predicted values for a strength level.
 *
 * @param samples - Samples whose predicted values are averaged.
 * @param strengthLevel - Strength level whose predictions are read.
 * @returns The mean, or `undefined` when the mean is `NaN`. Lodash `meanBy` returns
 *   `NaN` for an empty input, e.g. when no sample has a prediction.
 */
const calcAverage = (
  samples: CementStrengthSample[],
  strengthLevel: StrengthLevel
): number | undefined => {
  const predictions = getPredictedValues(samples, strengthLevel)
  const mean = meanBy(predictions, ({value}) => value)
  if (isNaN(mean)) {
    return undefined
  }
  return mean
}

/**
 * Computes prediction statistics for one strength level over the samples in a time range.
 *
 * Every derived number is normalised the same way: a `NaN` result (a degenerate R², a
 * `NaN` target, …) is reported as `undefined` rather than leaking into the output.
 *
 * @param samples - All candidate samples.
 * @param timeRange - Range used by `filterSamplesByTimeRange` to select samples.
 * @param strengthLevel - Strength level to evaluate.
 * @param target - Optional target strength; enables `deviationFromTarget`.
 * @returns `average` and `deviationFromTarget` always; `meanAbsoluteError` and `r2` only
 *   when at least one sample has both an actual and a predicted value.
 */
export const calcCementStrengthStats = (
  samples: CementStrengthSample[],
  timeRange: TimeRange,
  strengthLevel: StrengthLevel,
  target?: number
): PredictionStats => {
  // Lodash `isNaN` is true only for NaN, so `undefined` and regular numbers pass through.
  const nanToUndefined = (value: number | undefined): number | undefined =>
    isNaN(value) ? undefined : value

  const samplesInRange = filterSamplesByTimeRange(samples, timeRange)
  const average = calcAverage(samplesInRange, strengthLevel)
  // `isNumber(NaN)` is true, so a NaN target would otherwise yield a NaN deviation.
  const deviationFromTarget =
    isNumber(target) && isNumber(average) ? nanToUndefined(average - target) : undefined

  const records = toStatsRecords(samplesInRange, strengthLevel)
  if (records.length === 0) {
    return {
      average,
      deviationFromTarget
    }
  }

  return {
    average,
    deviationFromTarget,
    meanAbsoluteError: nanToUndefined(calcMeanAbsoluteError(records)),
    r2: nanToUndefined(calcR2Score(records))
  }
}
