233 lines
11 KiB
C#
233 lines
11 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
|
|
public static class TrainSplit
|
|
{
|
|
/// <summary>
/// Builds a random permutation of the indices <c>0 .. size - 1</c>.
/// </summary>
/// <param name="size">Number of indices to permute.</param>
/// <param name="randomStateSeed">
/// Seed for a reproducible permutation. When <c>null</c>, the shared
/// non-deterministic generator is used instead.
/// </param>
/// <returns>An array of length <paramref name="size"/> holding every index exactly once.</returns>
public static int[] GenerateScrambleTemplate(this int size, int? randomStateSeed = null)
{
    // The former fallback seed cast a nanosecond count since 1970 (~1e18) to int.
    // That value is far outside the int range, so the cast did not produce a
    // time-dependent seed and unseeded shuffles were not actually random.
    Random randomizer = randomStateSeed is int seed ? new Random(seed) : Random.Shared;

    int[] scrambleTemplate = Enumerable.Range(0, size).ToArray();
    randomizer.Shuffle(scrambleTemplate);
    return scrambleTemplate;
}
|
|
/// <summary>
/// Applies a permutation to the columns of every row of <paramref name="data"/>.
/// </summary>
/// <param name="scrambleTemplate">
/// Destination index for each source column: the value at column <c>j</c> is
/// written to column <c>scrambleTemplate[j]</c>.
/// </param>
/// <param name="data">Jagged matrix whose rows are permuted independently with the same template.</param>
/// <returns>A new jagged matrix; the input is not modified.</returns>
public static double[][] Shuffle(this int[] scrambleTemplate, double[][] data)
{
    var shuffled = new double[data.Length][];
    for (int i = 0; i < data.Length; i++)
    {
        double[] source = data[i];
        double[] target = new double[source.Length];

        // Iterate over the columns of the row. The previous bound was the row
        // count, which skipped or overran columns for non-square data.
        for (int j = 0; j < source.Length; j++)
        {
            target[scrambleTemplate[j]] = source[j];
        }

        shuffled[i] = target;
    }

    return shuffled;
}
|
|
/// <summary>
/// Applies a permutation to a vector.
/// </summary>
/// <typeparam name="T">Element type of the vector.</typeparam>
/// <param name="scrambleTemplate">
/// Destination index for each source position: <c>vec[i]</c> is written to
/// position <c>scrambleTemplate[i]</c>.
/// </param>
/// <param name="vec">Vector to permute; it is not modified.</param>
/// <returns>A new array of the same length as <paramref name="vec"/>.</returns>
public static T[] Shuffle<T>(this int[] scrambleTemplate, T[] vec)
{
    // Allocate by length only. Seeding the buffer from vec[0] threw on an empty
    // vector, and a permutation overwrites every slot anyway.
    var shuffled = new T[vec.Length];
    for (int i = 0; i < vec.Length; i++)
    {
        shuffled[scrambleTemplate[i]] = vec[i];
    }

    return shuffled;
}
|
|
/// <summary>
/// Randomly permutes the columns of <paramref name="data"/>, using the same
/// permutation for every row.
/// </summary>
/// <param name="data">Jagged matrix; the permutation size is taken from the length of its first row.</param>
/// <param name="randomStateSeed">Optional seed forwarded to <c>GenerateScrambleTemplate</c>.</param>
/// <returns>A new permuted matrix, or an empty matrix when <paramref name="data"/> has no rows.</returns>
public static double[][] Shuffle(this double[][] data, int? randomStateSeed = null)
{
    // Without rows there is no first row to size the permutation from.
    if (data.Length == 0)
    {
        return [];
    }

    int[] scrambleTemplate = data[0].Length.GenerateScrambleTemplate(randomStateSeed);
    return scrambleTemplate.Shuffle(data);
}
|
|
/// <summary>
/// Randomly permutes ids and data columns with one shared permutation, so each
/// id stays aligned with its column.
/// </summary>
/// <param name="tpl">Ids plus a jagged matrix; the permutation size is the length of the first data row.</param>
/// <param name="randomStateSeed">Optional seed forwarded to <c>GenerateScrambleTemplate</c>.</param>
/// <returns>
/// The permuted ids and data. When there are no ids or no data rows, copies of
/// the outer arrays are returned in their original order.
/// </returns>
public static (string[], double[][]) Shuffle(this (string[] ids, double[][] data) tpl, int? randomStateSeed = null)
{
    // Nothing to permute; also avoids indexing data[0] on an empty matrix.
    if (tpl.ids.Length == 0 || tpl.data.Length == 0)
    {
        return ([.. tpl.ids], [.. tpl.data]);
    }

    int[] scrambleTemplate = tpl.data[0].Length.GenerateScrambleTemplate(randomStateSeed);
    string[] shuffledIds = scrambleTemplate.Shuffle(tpl.ids);
    double[][] shuffledData = scrambleTemplate.Shuffle(tpl.data);
    return (shuffledIds, shuffledData);
}
|
|
/// <summary>
/// Randomly permutes ids, labels and data columns with one shared permutation,
/// so each id and label stays aligned with its column.
/// </summary>
/// <param name="tpl">Ids, labels and a jagged matrix; the permutation size is the length of the first data row.</param>
/// <param name="randomStateSeed">Optional seed forwarded to <c>GenerateScrambleTemplate</c>.</param>
/// <returns>
/// The permuted ids, labels and data. When there are no ids or no data rows,
/// copies of the outer arrays are returned in their original order.
/// </returns>
public static (string[], string?[], double[][]) Shuffle(this (string[] ids, string?[] labels, double[][] data) tpl, int? randomStateSeed = null)
{
    // Nothing to permute; also avoids indexing data[0] on an empty matrix.
    if (tpl.ids.Length == 0 || tpl.data.Length == 0)
    {
        return ([.. tpl.ids], [.. tpl.labels], [.. tpl.data]);
    }

    int[] scrambleTemplate = tpl.data[0].Length.GenerateScrambleTemplate(randomStateSeed);
    string[] shuffledIds = scrambleTemplate.Shuffle(tpl.ids);
    string?[] shuffledLabels = scrambleTemplate.Shuffle(tpl.labels);
    double[][] shuffledData = scrambleTemplate.Shuffle(tpl.data);
    return (shuffledIds, shuffledLabels, shuffledData);
}
|
|
/// <summary>
/// Splits a dataset into consecutive train, test and validation partitions,
/// in input order (no shuffling).
/// </summary>
/// <param name="tpl">
/// Ids, labels and data. <c>data</c> is indexed <c>[feature][sample]</c>; the
/// sample count is the length of the first feature row, or the number of ids
/// when there are no feature rows.
/// </param>
/// <param name="ratioTrain">Relative weight of the train partition.</param>
/// <param name="ratioTest">Relative weight of the test partition.</param>
/// <param name="ratioValidation">Relative weight of the validation partition.</param>
/// <returns>
/// The (train, test, validation) partitions. The ratios are normalised by their
/// sum; validation receives whatever remains after train and test are sized.
/// </returns>
/// <exception cref="ArgumentException">
/// A ratio is negative, or the ratios do not sum to a positive finite value.
/// </exception>
public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) SplitTTVSimple(this (string[] ids, string?[] labels, double[][] data) tpl, double ratioTrain, double ratioTest, double ratioValidation)
{
    if (tpl.ids.Length == 0)
    {
        return (([], [], []), ([], [], []), ([], [], []));
    }

    var (ids, labels, data) = tpl;

    double ratioTotal = ratioTrain + ratioTest + ratioValidation;
    if (ratioTrain < 0 || ratioTest < 0 || ratioValidation < 0 || !(ratioTotal > 0) || double.IsInfinity(ratioTotal))
    {
        throw new ArgumentException("Split ratios must be non-negative and sum to a positive finite value.");
    }

    int nSamples = data.Length > 0 ? data[0].Length : ids.Length;

    // Math.Round rounds midpoints to even, so train and test can both round up
    // (e.g. 1.5 -> 2 twice for 3 samples). Clamp so the partitions never claim
    // more samples than exist and validation never gets a negative size.
    int nTrain = Math.Min(nSamples, (int)Math.Round(nSamples * ratioTrain / ratioTotal));
    int nTest = Math.Min(nSamples - nTrain, (int)Math.Round(nSamples * ratioTest / ratioTotal));
    int testEnd = nTrain + nTest;

    // Copies samples [start, end) of ids, labels and every feature row.
    (string[], string?[], double[][]) Slice(int start, int end) => (
        ids[start..end],
        labels[start..end],
        Array.ConvertAll(data, row => row[start..end]));

    return (Slice(0, nTrain), Slice(nTrain, testEnd), Slice(testEnd, nSamples));
}
|
|
|
|
/// <summary>
/// Maps a point to the linear index of its cell in a regular grid that has the
/// same number of cells along every axis.
/// </summary>
/// <param name="minsmaxs">
/// <c>((mins, maxs), cellsPerAxis)</c>: per-axis bounds and the number of cells
/// per axis. Each coordinate is passed through the external <c>Scale</c> helper
/// (presumably a min-max normalisation to [0, 1]; confirm against its definition).
/// </param>
/// <param name="point">Coordinates of the point, one per axis.</param>
/// <returns>
/// The sum over axes of <c>cell(axis) * cellsPerAxis^axis</c>, where each per-axis
/// cell is clamped to <c>[0, cellsPerAxis - 1]</c>.
/// </returns>
public static long Bucketize(this ((double[], double[]), int) minsmaxs, double[] point)
{
    var ((mins, maxs), cspas) = minsmaxs;

    long b = 0;
    long stride = 1; // cspas^f, kept as an exact integer instead of a Math.Pow round-trip

    for (int f = 0; f < point.Length; f++)
    {
        double position = (mins[f], maxs[f]).Scale(point[f]) * cspas;

        long cell;
        if (double.IsNaN(position) || position <= 0)
        {
            // NaN can arise when an axis has no extent (min == max); a value
            // below the minimum would otherwise give a negative cell index.
            cell = 0;
        }
        else if (position >= cspas)
        {
            // The axis maximum (or anything above it) belongs to the last cell.
            cell = cspas - 1;
        }
        else
        {
            cell = (long)Math.Floor(position);
        }

        b += cell * stride;
        stride *= cspas;
    }

    return b;
}
|
|
|
|
/// <summary>
/// Concatenates per-bucket (train, test, validation) splits into one split,
/// keeping bucket order and the sample order within each bucket.
/// </summary>
/// <param name="buckets">Per-bucket splits; each partition is (ids, labels, data) with data indexed <c>[feature][sample]</c>.</param>
/// <returns>
/// The merged (train, test, validation) partitions. All three are empty when
/// the buckets hold no samples or no feature rows.
/// </returns>
/// <exception cref="ArgumentException">
/// A non-empty partition has a different number of feature rows than the widest partition.
/// </exception>
public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) MergeTTVBuckets(this ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][]))[] buckets)
{
    int nTr = buckets.Sum(x => x.Item1.Item1.Length);
    int nTe = buckets.Sum(x => x.Item2.Item1.Length);
    int nVa = buckets.Sum(x => x.Item3.Item1.Length);

    if (nTr == 0 && nTe == 0 && nVa == 0)
    {
        return (([], [], []), ([], [], []), ([], [], []));
    }

    // Empty partitions may carry zero feature rows, so take the widest one.
    int nF = buckets.Max(b => Math.Max(b.Item1.Item3.Length, Math.Max(b.Item2.Item3.Length, b.Item3.Item3.Length)));

    if (nF == 0)
    {
        return (([], [], []), ([], [], []), ([], [], []));
    }

    return (
        MergeTTVPartition(buckets.Select(b => b.Item1), nTr, nF),
        MergeTTVPartition(buckets.Select(b => b.Item2), nTe, nF),
        MergeTTVPartition(buckets.Select(b => b.Item3), nVa, nF)
    );
}

/// <summary>
/// Concatenates one partition (ids, labels, data) across buckets into arrays of
/// <paramref name="total"/> samples and <paramref name="featureCount"/> feature rows.
/// </summary>
private static (string[], string?[], double[][]) MergeTTVPartition(IEnumerable<(string[], string?[], double[][])> parts, int total, int featureCount)
{
    var ids = new string[total];
    var labels = new string?[total];
    var data = new double[featureCount][];
    for (int f = 0; f < featureCount; f++)
    {
        data[f] = new double[total];
    }

    int offset = 0;
    foreach (var (partIds, partLabels, partData) in parts)
    {
        int count = partIds.Length;
        if (count == 0)
        {
            continue; // empty partitions may have no feature rows at all
        }

        if (partData.Length != featureCount)
        {
            throw new ArgumentException("All non-empty partitions must have the same number of feature rows.", nameof(parts));
        }

        Array.Copy(partIds, 0, ids, offset, count);
        Array.Copy(partLabels, 0, labels, offset, count);
        for (int f = 0; f < featureCount; f++)
        {
            Array.Copy(partData[f], 0, data[f], offset, count);
        }

        offset += count;
    }

    return (ids, labels, data);
}
|
|
|
|
/// <summary>
/// Splits a dataset into train, test and validation partitions, optionally
/// stratified by grid cell and/or by label.
/// </summary>
/// <param name="tpl">Ids, labels and data; <c>data</c> is indexed <c>[feature][sample]</c>.</param>
/// <param name="ratioTrain">Relative weight of the train partition.</param>
/// <param name="ratioTest">Relative weight of the test partition.</param>
/// <param name="ratioValidation">Relative weight of the validation partition.</param>
/// <param name="cartesianStatificationPerAxis">
/// When greater than 1, samples are grouped by the grid cell returned by
/// <c>Bucketize</c> (this many cells per feature axis, bounds taken from the
/// data) and each cell is split separately.
/// </param>
/// <param name="labelStatification">
/// When <c>true</c>, samples are grouped by label and each group is split
/// separately. Combined with cell stratification, labels are stratified within each cell.
/// </param>
/// <returns>
/// The (train, test, validation) partitions. Stratified results are merged in
/// ascending cell index, or in order of first label occurrence.
/// </returns>
public static ((string[], string?[], double[][]), (string[], string?[], double[][]), (string[], string?[], double[][])) SplitTTV(this (string[] ids, string?[] labels, double[][] data) tpl, double ratioTrain, double ratioTest, double ratioValidation, int cartesianStatificationPerAxis = 1, bool labelStatification = false)
{
    if (tpl.ids.Length == 0)
    {
        return (([], [], []), ([], [], []), ([], [], []));
    }

    bool byCell = cartesianStatificationPerAxis > 1;
    if (!byCell && !labelStatification)
    {
        return tpl.SplitTTVSimple(ratioTrain, ratioTest, ratioValidation);
    }

    var (ids, labels, data) = tpl;
    int nSamples = data.Length == 0 ? 0 : data[0].Length;
    if (nSamples == 0)
    {
        // No samples to stratify (also avoids Min/Max over empty feature rows).
        return (([], [], []), ([], [], []), ([], [], []));
    }

    // Each stratum is the list of sample indices it contains, in input order.
    IEnumerable<List<int>> strata;
    if (byCell)
    {
        var minsmaxs = (data.Select(row => row.Min()).ToArray(), data.Select(row => row.Max()).ToArray());
        var bucketDef = (minsmaxs, cartesianStatificationPerAxis);

        // Only occupied cells are stored. A dense table would need
        // perAxis^features entries, which explodes with the feature count.
        // The sorted map keeps the ascending-cell merge order of a dense scan.
        var cells = new SortedDictionary<long, List<int>>();
        var point = new double[data.Length];
        for (int s = 0; s < nSamples; s++)
        {
            for (int f = 0; f < data.Length; f++)
            {
                point[f] = data[f][s];
            }

            long cell = bucketDef.Bucketize(point);
            if (!cells.TryGetValue(cell, out var members))
            {
                members = [];
                cells.Add(cell, members);
            }

            members.Add(s);
        }

        strata = cells.Values;
    }
    else
    {
        // GroupBy yields groups in order of first occurrence and accepts null labels.
        strata = Enumerable.Range(0, nSamples)
            .GroupBy(s => labels[s])
            .Select(group => group.ToList());
    }

    // Cell strata may still be stratified by label; label strata go straight
    // to the plain split (the per-axis count is <= 1 on that path).
    int nextPerAxis = byCell ? 1 : cartesianStatificationPerAxis;
    bool nextByLabel = byCell && labelStatification;

    var splits = strata
        .Select(members => (
            members.Select(s => ids[s]).ToArray(),
            members.Select(s => labels[s]).ToArray(),
            Array.ConvertAll(data, row => members.Select(s => row[s]).ToArray())
        ).SplitTTV(ratioTrain, ratioTest, ratioValidation, nextPerAxis, nextByLabel))
        .ToArray();

    return splits.MergeTTVBuckets();
}
|
|
} |