21 #include "llvm/Support/MathExtras.h"
// Winograd F(2x2, 3x3) transform matrices (Lavin & Gray): G/GT transform the
// filter, BT/B transform the input, AT/A transform the output.
// NOTE(review): the extraction appears to have dropped the array element rows
// and closing braces for each of these constants — confirm against upstream.
21 constexpr
float G_2x2_3x3[] = {
56 constexpr
float GT_2x2_3x3[] = {
62 constexpr
float BT_2x2_3x3[] = {
69 constexpr
float B_2x2_3x3[] = {
76 constexpr
float AT_2x2_3x3[] = {
81 constexpr
float A_2x2_3x3[] = {
// Winograd F(4x4, 3x3) transform matrices. G (rows dropped by extraction) and
// GT transform the filter; BT and B transform the input; AT and A transform
// the output. NOTE(review): several rows and the closing braces are missing
// from this view — confirm the full constant tables against upstream.
88 constexpr
float G_4x4_3x3[] = {
97 constexpr
float GT_4x4_3x3[] = {
98 1, -1./3, -1./3, 1./12, 1./12, 0,
99 0, 1./3, -1./3, -1./6, 1./6, 0,
100 0, -1./3, -1./3, 1./3, 1./3, 1
103 constexpr
float BT_4x4_3x3[] = {
104 1./4, 0, -5./16, 0, 1./16, 0,
105 0, 1./4, -1./4, -1./16, 1./16, 0,
106 0, -1./4, -1./4, 1./16, 1./16, 0,
107 0, 1./4, -1./8, -1./4, 1./8, 0,
108 0, -1./4, -1./8, 1./4, 1./8, 0,
109 0, 1./4, 0, -5./16, 0, 1./16
112 constexpr
float B_4x4_3x3[] = {
114 0, 1./4, -1./4, 1./4, -1./4, 1./4,
115 -5./16, -1./4, -1./4, -1./8, -1./8, 0,
116 0, -1./16, 1./16, -1./4, 1./4, -5./16,
117 1./16, 1./16, 1./16, 1./8, 1./8, 0,
121 constexpr
float AT_4x4_3x3[] = {
122 1./8, 1./4, 1./4, 1./8, 1./8, 0,
123 0, -1./4, 1./4, -1./4, 1./4, 0,
124 0, 1./4, 1./4, 1./2, 1./2, 0,
125 0, -1./4, 1./4, -1, 1, 1./2
128 constexpr
float A_4x4_3x3[] = {
130 1./4, -1./4, 1./4, -1./4,
131 1./4, 1./4, 1./4, 1./4,
132 1./8, -1./4, 1./2, -1,
// Winograd F(2x2, 5x5) transform matrices: G/GT (filter), BT/B (input),
// AT/A (output). NOTE(review): some rows (e.g. the body of AT_2x2_5x5 and all
// of A_2x2_5x5) and closing braces are missing from this view — confirm
// against upstream before relying on these tables.
137 constexpr
float G_2x2_5x5[] = {
139 1./6, -1./6, 1./6, -1./6, 1./6,
140 -1./6, -1./6, -1./6, -1./6, -1./6,
141 -4./15, 2./15, -1./15, 1./30, -1./60,
142 1./60, 1./30, 1./15, 2./15, 4./15,
146 constexpr
float GT_2x2_5x5[] = {
147 1, 1./6, -1./6, -4./15, 1./60, 0,
148 0, -1./6, -1./6, 2./15, 1./30, 0,
149 0, 1./6, -1./6, -1./15, 1./15, 0,
150 0, -1./6, -1./6, 1./30, 2./15, 0,
151 0, 1./6, -1./6, -1./60, 4./15, 1
154 constexpr
float BT_2x2_5x5[] = {
155 1./8, 3./16, -1./4, -3./16, 1./8, 0,
156 0, 1./8, 1./16, -5./16, 1./8, 0,
157 0, -1./8, -5./16, -1./16, 1./8, 0,
158 0, 1./4, -1./8, -1./4, 1./8, 0,
159 0, -1./8, -1./4, 1./8, 1./4, 0,
160 0, 1./8, 3./16, -1./4, -3./16, 1./8
163 constexpr
float B_2x2_5x5[] = {
165 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
166 -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
167 -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
168 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
172 constexpr
float AT_2x2_5x5[] = {
174 0, -1, 1, -1, 2, 1./2
177 constexpr
float A_2x2_5x5[] = {
// A non-owning view of one Winograd transform matrix: a flat row-major float
// table plus its dimensions. NOTE(review): the constructor body, member
// declarations, and an apparent scalarFactor member (referenced later via
// AMatrix.scalarFactor) are missing from this view — confirm upstream.
188 struct TransformMatrix {
189 TransformMatrix(
const float *
table, int64_t
rows, int64_t
cols,
// Materializes a TransformMatrix as an arith.constant holding a
// rows x cols dense tensor of the given element type.
// NOTE(review): the RankedTensorType/DenseFPElementsAttr construction lines
// between 204 and 208 were dropped by extraction.
200 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
201 TransformMatrix transform, Type type) {
202 ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
204 return arith::ConstantOp::create(
208 SmallVector<int64_t>{transform.rows, transform.cols}, type),
// Extracts a 2-D (extractHeight x extractWidth) slice from a 4-D tensor.
// The loopNorFIdx/loopCorFIdx dimensions are sliced to size 1 at the given
// loop indices; the height/width dimensions are sliced at the given offsets.
213 Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
214 Value loopNorFIndex, Value loopCorFIndex,
215 Value heightOffset, Value widthOffset,
216 int64_t extractHeight, int64_t extractWidth,
217 int64_t loopNorFIdx, int64_t loopCorFIdx,
218 int64_t heightIdx, int64_t widthIdx) {
219 auto sourceType = cast<ShapedType>(source.getType());
220 Type elementType = sourceType.getElementType();
221 int64_t srcSize = sourceType.getRank();
// All strides are 1; every dimension except height/width has size 1.
223 auto oneIndex = builder.getIndexAttr(1);
224 SmallVector<OpFoldResult> offsets;
225 offsets.resize(srcSize);
226 offsets[loopNorFIdx] = loopNorFIndex;
227 offsets[loopCorFIdx] = loopCorFIndex;
228 offsets[heightIdx] = heightOffset;
229 offsets[widthIdx] = widthOffset;
230 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
231 sizes[heightIdx] = builder.getIndexAttr(extractHeight);
232 sizes[widthIdx] = builder.getIndexAttr(extractWidth);
233 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
// NOTE(review): the RHS building extractFilterType (rank-reducing 2-D
// result type) was dropped by extraction between lines 235 and 237.
235 auto extractFilterType =
237 auto extractFilterOp = tensor::ExtractSliceOp::create(
238 builder, loc, extractFilterType, source, offsets, sizes, strides);
240 return extractFilterOp;
// Extracts a 2-D (height x width) slice from a 6-D tensor at the given tile
// and loop indices. Height/width sizes are taken from the source shape.
244 Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
245 Value tileHIndex, Value tileWIndex,
246 Value loopNorFIndex, Value loopCorFIndex,
247 int64_t tileHIdx, int64_t tileWIdx,
248 int64_t loopNorFIdx, int64_t loopCorFIdx,
249 int64_t heightIdx, int64_t widthIdx) {
250 auto sourceType = cast<ShapedType>(source.getType());
251 Type elementType = sourceType.getElementType();
252 auto sourceShape = sourceType.getShape();
253 int64_t srcSize = sourceType.getRank();
254 int64_t height = sourceShape[heightIdx];
255 int64_t width = sourceShape[widthIdx];
257 auto zeroIndex = builder.getIndexAttr(0);
258 auto oneIndex = builder.getIndexAttr(1);
259 SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
// NOTE(review): this resize is redundant — offsets is already srcSize long.
260 offsets.resize(srcSize);
261 offsets[tileHIdx] = tileHIndex;
262 offsets[tileWIdx] = tileWIndex;
263 offsets[loopNorFIdx] = loopNorFIndex;
264 offsets[loopCorFIdx] = loopCorFIndex;
265 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
266 sizes[heightIdx] = builder.getIndexAttr(height);
267 sizes[widthIdx] = builder.getIndexAttr(width);
268 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
// NOTE(review): the declaration of extractFilterType (lines 269-270) was
// dropped by extraction; it is used below.
271 auto extractFilterOp = tensor::ExtractSliceOp::create(
272 builder, loc, extractFilterType, source, offsets, sizes, strides);
274 return extractFilterOp;
// Inserts a 2-D (height x width) block into a 4-D destination tensor at the
// given loop indices and height/width offsets; returns the updated tensor.
279 Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
280 Value dest, Value loopNorFIndex, Value loopCorFIndex,
281 Value heightOffset, Value widthOffset, int64_t height,
282 int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
283 int64_t heightIdx, int64_t widthIdx) {
284 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
// Unit strides; all dimensions except height/width have size 1.
285 auto oneIndex = builder.getIndexAttr(1);
286 SmallVector<OpFoldResult> retOffsets;
287 retOffsets.resize(destSize);
288 retOffsets[loopNorFIdx] = loopNorFIndex;
289 retOffsets[loopCorFIdx] = loopCorFIndex;
290 retOffsets[heightIdx] = heightOffset;
291 retOffsets[widthIdx] = widthOffset;
292 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
293 retSizes[heightIdx] = builder.getIndexAttr(height);
294 retSizes[widthIdx] = builder.getIndexAttr(width);
295 SmallVector<OpFoldResult> strides(destSize, oneIndex);
297 auto insertSliceOp = tensor::InsertSliceOp::create(
298 builder, loc, source, dest, retOffsets, retSizes, strides);
300 return insertSliceOp;
// Inserts a 2-D (height x width) block into a 6-D destination tensor at the
// given tile and loop indices; returns the updated tensor.
305 Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
306 Value dest, Value tileHIndex, Value tileWIndex,
307 Value loopNorFIndex, Value loopCorFIndex, int64_t height,
308 int64_t width, int64_t tileHIdx, int64_t tileWIdx,
309 int64_t loopNorFIdx, int64_t loopCorFIdx,
310 int64_t heightIdx, int64_t widthIdx) {
311 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
312 auto zeroIndex = builder.getIndexAttr(0);
313 auto oneIndex = builder.getIndexAttr(1);
314 SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
// NOTE(review): redundant resize — retOffsets is already destSize long.
315 retOffsets.resize(destSize);
316 retOffsets[tileHIdx] = tileHIndex;
317 retOffsets[tileWIdx] = tileWIndex;
318 retOffsets[loopNorFIdx] = loopNorFIndex;
319 retOffsets[loopCorFIdx] = loopCorFIndex;
320 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
321 retSizes[heightIdx] = builder.getIndexAttr(height);
322 retSizes[widthIdx] = builder.getIndexAttr(width);
323 SmallVector<OpFoldResult> strides(destSize, oneIndex);
325 auto insertSliceOp = tensor::InsertSliceOp::create(
326 builder, loc, source, dest, retOffsets, retSizes, strides);
328 return insertSliceOp;
// Winograd filter transform: computes G x g x GT for each (F, C) slice of the
// filter inside a loop nest and inserts the result into retValue. The left
// (G) and right (GT) multiplications are applied only when the corresponding
// filter dimension is not 1 (leftTransform/rightTransform).
// FIX: line 420 read `const TransformMatrix >Matrix = it->second;` — an HTML
// entity mangling of `&GTMatrix` (`&GT` was consumed as `&gt;`). The intended
// identifier is proven by the use of GTMatrix at line 430 below.
342 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
343 Value retValue, WinogradConv2DFmr fmr,
344 bool leftTransform =
true,
bool rightTransform =
true) {
// G matrices (left multiplicand) keyed by the F(m, r) variant.
346 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
348 {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
349 {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
350 {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
// GT matrices (right multiplicand) keyed by the F(m, r) variant.
354 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
356 {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
357 {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
358 {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
361 auto filterType = cast<ShapedType>(filter.getType());
362 Type elementType = filterType.getElementType();
363 auto filterShape = filterType.getShape();
364 int64_t filterF = filterShape[0];
365 int64_t filterH = filterShape[1];
366 int64_t filterW = filterShape[2];
367 int64_t filterC = filterShape[3];
// Each filter spatial dim must be either r or 1 (1-D Winograd cases).
371 if (filterH != r && filterH != 1)
373 if (filterW != r && filterW != 1)
// Loop body over (F, C): extract one filterH x filterW slice and transform it.
377 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
379 Value FIter = ivs[0];
380 Value CIter = ivs[1];
384 extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
385 zeroIdx, filterH, filterW, 0,
389 Value matmulRetValue = extractFilter;
390 Value zero = arith::ConstantOp::create(builder, loc,
391 rewriter.getZeroAttr(elementType));
// Left transform: matmulRetValue = G x g.
394 auto it = GMatrices.find(fmr);
395 if (it == GMatrices.end())
397 const TransformMatrix &GMatrix = it->second;
399 retRows = GMatrix.rows;
401 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
405 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
407 Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
409 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
410 ValueRange{G, extractFilter},
412 matmulRetValue = matmulOp.getResult(0);
// Right transform: matmulRetValue = (G x g) x GT.
415 if (rightTransform) {
417 auto it = GTMatrices.find(fmr);
418 if (it == GTMatrices.end())
420 const TransformMatrix &GTMatrix = it->second;
424 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
428 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
430 Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
432 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
433 ValueRange{matmulRetValue,
GT},
435 matmulRetValue = matmulOp.getResult(0);
// Result tile is (m + r - 1) along each transformed dimension, 1 otherwise.
439 int64_t retHeight = leftTransform ? m + r - 1 : 1;
440 int64_t retWidth = rightTransform ? m + r - 1 : 1;
443 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
444 zeroIdx, zeroIdx, retHeight, retWidth,
448 return {insertSliceOp};
// 2-D loop nest over the F and C dimensions carrying retValue.
455 rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
456 {oneStep, oneStep}, {retValue}, buildBody);
// Winograd input transform: computes BT x d x B for every tile of the input
// inside a 4-D loop nest over (tileH, tileW, N, C), writing each transformed
// tile into the 6-D retValue tensor.
477 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
478 Value retValue, WinogradConv2DFmr fmr,
479 bool leftTransform =
true,
bool rightTransform =
true) {
// BT matrices (left multiplicand) keyed by the F(m, r) variant.
481 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
483 {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
484 {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
485 {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
// B matrices (right multiplicand) keyed by the F(m, r) variant.
489 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
491 {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
492 {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
493 {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
498 auto inputType = cast<ShapedType>(input.getType());
499 Type elementType = inputType.getElementType();
500 auto inputShape = inputType.getShape();
501 int64_t inputN = inputShape[0];
502 int64_t inputC = inputShape[3];
503 auto valueType = cast<ShapedType>(retValue.getType());
504 auto valueShape = valueType.getShape();
505 int64_t tileH = valueShape[2];
506 int64_t tileW = valueShape[3];
// alpha = m + r - 1 is the transformed tile size per transformed dimension.
507 int64_t alphaH = leftTransform ? m + r - 1 : 1;
508 int64_t alphaW = rightTransform ? m + r - 1 : 1;
510 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
512 Value tileHIter = ivs[0];
513 Value tileWIter = ivs[1];
514 Value NIter = ivs[2];
515 Value CIter = ivs[3];
517 auto context = builder.getContext();
// Map tile index -> element offset only along transformed dimensions;
// untransformed dimensions use the identity map.
519 auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
522 Value heightOffset = affine::AffineApplyOp::create(
523 builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
524 Value widthOffset = affine::AffineApplyOp::create(
525 builder, loc, rightTransform ? affineMap : identityAffineMap,
530 extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
531 widthOffset, alphaH, alphaW, 0,
536 Value matmulRetValue = extractInput;
537 Value zero = arith::ConstantOp::create(builder, loc,
538 rewriter.getZeroAttr(elementType));
// Left transform: matmulRetValue = BT x d.
541 auto it = BTMatrices.find(fmr);
542 if (it == BTMatrices.end())
544 const TransformMatrix &BTMatrix = it->second;
546 retRows = BTMatrix.rows;
548 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
552 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
// NOTE(review): the BT/B constants are materialized as f32 here
// (builder.getF32Type()) rather than elementType — confirm intent.
555 create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
557 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
558 ValueRange{BT, matmulRetValue},
560 matmulRetValue = matmulOp.getResult(0);
// Right transform: matmulRetValue = (BT x d) x B.
563 if (rightTransform) {
565 auto it = BMatrices.find(fmr);
566 if (it == BMatrices.end())
568 const TransformMatrix &BMatrix = it->second;
570 retCols = BMatrix.cols;
572 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
576 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
578 create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
580 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
581 ValueRange{matmulRetValue,
B},
583 matmulRetValue = matmulOp.getResult(0);
// Store the transformed tile into the 6-D output at dims (2,3) = tile
// indices, (4,5) = (N, C).
587 auto combinedVal = insert2DDataTo6D(
588 builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
589 CIter, retRows, retCols, 2, 3, 4, 5,
592 return {combinedVal};
// 4-D loop nest over (tileH, tileW, N, C) carrying retValue.
602 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
603 {tileHBound, tileWBound, nUpperBound, cUpperBound},
604 {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
// Batched multiply of the transformed input against the transformed filter:
// both 4-D/6-D operands are collapsed to 3-D, multiplied with
// linalg.batch_matmul, and the result expanded back to 6-D.
626 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
627 Value transformedFilter, Value transformedInput,
628 Type outputElementType) {
// Collapse (h, w, c, f) filter to (h*w, c, f).
630 auto filterType = cast<ShapedType>(transformedFilter.getType());
631 assert(filterType.hasStaticShape() &&
"only support static shapes.");
632 ArrayRef<int64_t> filterShape = filterType.getShape();
633 Type filterElementType = filterType.getElementType();
635 {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
637 SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
638 Value collapseFilter = tensor::CollapseShapeOp::create(
639 rewriter, loc, filterReassocType, transformedFilter, filterReassoc);
// Collapse the 6-D input to (h*w, tileH*tileW*N, c).
643 auto inputType = cast<ShapedType>(transformedInput.getType());
644 assert(inputType.hasStaticShape() &&
"only support static shapes.");
645 ArrayRef<int64_t> inputShape = inputType.getShape();
646 Type inputElementType = inputType.getElementType();
648 {inputShape[0] * inputShape[1],
649 inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
651 SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
652 Value collapseInput = tensor::CollapseShapeOp::create(
653 rewriter, loc, inputReassocType, transformedInput, inputReassoc);
// Batch matmul result: (h*w, tileH*tileW*N, f), zero-initialized.
657 {inputShape[0] * inputShape[1],
658 inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
660 Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
663 Value zero = arith::ConstantOp::create(
664 rewriter, loc, rewriter.getZeroAttr(outputElementType));
665 Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);
667 auto matmulOp = linalg::BatchMatmulOp::create(
668 rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
// Expand the batch dimension group back out to the original 6-D layout.
673 SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
674 auto outputReassocType =
676 inputShape[3], inputShape[4], filterShape[3]},
678 auto expandOutput = tensor::ExpandShapeOp::create(
679 rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
// Winograd output transform: computes AT x y x A for each tile of the
// batch-matmul result inside a 4-D loop nest over (tileH, tileW, N, F), then
// scales by 1/scalarFactor while accumulating into the output tensor.
700 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
701 Value output, WinogradConv2DFmr fmr,
702 bool leftTransform =
true,
bool rightTransform =
true) {
// AT matrices (left multiplicand); F_4_3 and F_2_5 variants carry a
// scalarFactor (32 and 16) applied after both multiplications.
704 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
706 {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
707 {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
708 {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
// A matrices (right multiplicand).
712 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
714 {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
715 {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
716 {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
721 auto valueType = cast<ShapedType>(value.getType());
722 Type elementType = valueType.getElementType();
723 auto valueShape = valueType.getShape();
724 int64_t valueH = valueShape[0];
725 int64_t valueW = valueShape[1];
726 int64_t valueN = valueShape[4];
727 int64_t valueF = valueShape[5];
728 int64_t alphaH = leftTransform ? m + r - 1 : 1;
729 int64_t alphaW = rightTransform ? m + r - 1 : 1;
// Each spatial dim of the value must be alpha or 1 (1-D Winograd cases).
731 if (valueH != alphaH && valueH != 1)
733 if (valueW != alphaW && valueW != 1)
736 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
738 auto context = builder.getContext();
739 Value tileHIter = ivs[0];
740 Value tileWIter = ivs[1];
741 Value NIter = ivs[2];
742 Value FIter = ivs[3];
746 extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
// Combined scale is the product of the factors of the transforms applied.
750 const TransformMatrix &AMatrix = AMatrices.at(fmr);
751 const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
752 int64_t
scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
753 (leftTransform ? ATMatrix.scalarFactor : 1);
754 int64_t retCols = rightTransform ? AMatrix.cols : 1;
755 int64_t retRows = leftTransform ? ATMatrix.rows : 1;
757 Value matmulRetValue = extractValue;
758 Value zero = arith::ConstantOp::create(builder, loc,
759 rewriter.getZeroAttr(elementType));
// Tile index -> output element offset along transformed dimensions.
761 auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
764 Value heightOffset = affine::AffineApplyOp::create(
765 builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
766 Value widthOffset = affine::AffineApplyOp::create(
767 builder, loc, rightTransform ? affineMap : identityAffineMap,
// Read the current output tile so results can accumulate into it.
771 extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
772 widthOffset, retRows, retCols,
// Left transform: matmulRetValue = AT x y. When no scalar factor is
// involved the output tile itself seeds the accumulation (init).
778 Value init = outInitVal;
780 auto empty = tensor::EmptyOp::create(builder, loc,
781 matmulType.getShape(), elementType)
783 init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
786 Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
788 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
789 ValueRange{AT, matmulRetValue},
791 matmulRetValue = matmulOp.getResult(0);
// Right transform: matmulRetValue = (AT x y) x A.
794 if (rightTransform) {
797 Value init = outInitVal;
799 auto empty = tensor::EmptyOp::create(builder, loc,
800 matmulType.getShape(), elementType)
802 init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
805 Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
807 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
808 ValueRange{matmulRetValue,
A},
810 matmulRetValue = matmulOp.getResult(0);
// Apply the scalar factor with a generic op: out = factor * val + out.
815 Value scalarFactorValue = arith::ConstantOp::create(
818 auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
819 SmallVector<AffineMap> affineMaps = {
820 AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
823 linalg::GenericOp::create(
824 rewriter, loc, matmulType,
825 ValueRange{scalarFactorValue, matmulRetValue},
826 ValueRange{outInitVal}, affineMaps,
828 utils::IteratorType::parallel, utils::IteratorType::parallel},
829 [&](OpBuilder &nestedBuilder, Location nestedLoc,
831 auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
833 auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
834 mulf.getResult(), args[2]);
835 linalg::YieldOp::create(nestedBuilder, nestedLoc,
// Write the finished m x m output tile back into the 4-D output.
843 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
844 heightOffset, widthOffset, retRows, retCols,
849 return {combinedVal};
// NOTE(review): "tilwH" looks like a typo for "tileH" — verify its uses in
// the missing lines before renaming.
852 int64_t tilwH = valueShape[2];
853 int64_t tileW = valueShape[3];
// 4-D loop nest over (tileH, tileW, N, F) carrying the output tensor.
861 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
862 {tileHBound, tileWBound, nUpperBound, fUpperBound},
863 {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
// Zero-pads a tensor up to alignedShape (high padding) so its spatial dims
// become multiples of the tile size. NOTE(review): the padded-type
// construction and the makeComposedPadHighOp call/return were dropped by
// extraction.
869 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
870 Value value, ArrayRef<int64_t> alignedShape) {
871 auto valueType = cast<ShapedType>(value.getType());
872 Type elementType = valueType.getElementType();
874 Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
875 rewriter.getZeroAttr(elementType));
// Extracts the original-sized region (offset 0, unit strides) from an
// aligned/padded 4-D tensor, undoing padToAlignedTensor.
882 static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
884 RankedTensorType extractedType) {
885 OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
886 OpFoldResult oneIndex = rewriter.getIndexAttr(1);
887 SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
888 SmallVector<OpFoldResult, 4> strides(4, oneIndex);
890 ArrayRef<int64_t> extractedShape = extractedType.getShape();
// NOTE(review): the RHS initializing sizes (presumably
// getAsOpFoldResult over extractedShape) was dropped by extraction.
891 SmallVector<OpFoldResult> sizes =
894 return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
895 offsets, sizes, strides);
// Fragment of hasAllOneValues(DenseIntElementsAttr): true iff every element
// equals 1 (used to require unit strides/dilations). The enclosing signature
// and the all_of wrapper were dropped by extraction.
attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
// Rewrites a linalg.conv_2d_nhwc_fhwc into the Winograd pipeline:
// filter transform -> input transform (with padding to tile-aligned sizes)
// -> batched matmul -> output transform (with un-padding). Returns the
// defining op of the final value, or a match failure.
906 static FailureOr<Operation *>
907 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
908 WinogradConv2DFmr fmr) {
909 if (!convOp.hasPureTensorSemantics())
910 return rewriter.notifyMatchFailure(
911 convOp,
"expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
913 Value input = convOp.getInputs()[0];
914 Value filter = convOp.getInputs()[1];
915 Value output = convOp.getOutputs()[0];
916 auto inputType = cast<ShapedType>(input.getType());
917 auto filterType = cast<ShapedType>(filter.getType());
918 auto outputType = cast<ShapedType>(output.getType());
// Preconditions: static shapes, unit dilations, unit strides.
920 if (!inputType.hasStaticShape())
921 return rewriter.notifyMatchFailure(convOp,
922 "expected a static shape for the input");
924 if (!filterType.hasStaticShape())
925 return rewriter.notifyMatchFailure(
926 convOp,
"expected a static shape for the filter");
929 return rewriter.notifyMatchFailure(convOp,
930 "expected all ones for dilations");
933 return rewriter.notifyMatchFailure(convOp,
"expected all ones for strides");
935 ArrayRef<int64_t> filterShape = filterType.getShape();
936 int64_t filterF = filterShape[0];
937 int64_t filterH = filterShape[1];
938 int64_t filterW = filterShape[2];
939 int64_t filterC = filterShape[3];
940 ArrayRef<int64_t> inputShape = inputType.getShape();
941 int64_t inputN = inputShape[0];
942 int64_t inputH = inputShape[1];
943 int64_t inputW = inputShape[2];
944 int64_t inputC = inputShape[3];
945 ArrayRef<int64_t> outputShape = outputType.getShape();
946 int64_t outputN = outputShape[0];
947 int64_t outputH = outputShape[1];
948 int64_t outputW = outputShape[2];
949 int64_t outputF = outputShape[3];
// Only (r x r), (r x 1), and (1 x r) filters are supported.
954 bool isSupportedFilter =
false;
955 if (filterH == filterW && filterH == r)
956 isSupportedFilter =
true;
957 if (filterH == r && filterW == 1)
958 isSupportedFilter =
true;
959 if (filterH == 1 && filterW == r)
960 isSupportedFilter =
true;
962 if (!isSupportedFilter)
963 return rewriter.notifyMatchFailure(
964 convOp,
"only support filter (r x r), (r x 1) or (1 x r)");
967 Location loc = convOp.getLoc();
// A filter dim of 1 means the transform is skipped along that axis.
970 bool leftTransform = filterH != 1;
972 bool rightTransform = filterW != 1;
973 int64_t heightM = leftTransform ? m : 1;
974 int64_t widthM = rightTransform ? m : 1;
975 int64_t heightR = leftTransform ? r : 1;
976 int64_t widthR = rightTransform ? r : 1;
// Stage 1: filter transform into an (alphaH, alphaW, C, F)-shaped tensor.
979 Type filterElementType = filterType.getElementType();
980 int64_t alphaH = heightM + heightR - 1;
981 int64_t alphaW = widthM + widthR - 1;
982 int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
983 int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
986 Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
988 auto transformedFilter = linalg::WinogradFilterTransformOp::create(
989 rewriter, loc, retType, filter, retValue, fmr);
// Stage 2: pad the input to tile-aligned sizes if needed, then transform.
995 Type inputElementType = inputType.getElementType();
996 int64_t alignedInputH = tileH * heightM + (heightR - 1);
997 int64_t alignedInputW = tileW * widthM + (widthR - 1);
998 if (alignedInputH != inputH || alignedInputW != inputW) {
999 input = padToAlignedTensor(rewriter, loc, input,
1000 {inputN, alignedInputH, alignedInputW, inputC});
1004 {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
1005 retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
1007 auto transformedInput = linalg::WinogradInputTransformOp::create(
1008 rewriter, loc, retType, input, retValue, fmr);
// Stage 3: element-wise-batched matmul of the two transformed tensors.
1010 Type outputElementType = outputType.getElementType();
1011 Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
1012 transformedInput, outputElementType);
// Stage 4: pad the output accumulator if tiles overhang, transform back,
// then slice out the original output extent.
1018 int64_t alignedOutputH = tileH * heightM;
1019 int64_t alignedOutputW = tileW * widthM;
1020 bool isOutputUnaligned =
1021 ((alignedOutputH != outputH) || (alignedOutputW != outputW));
1022 if (isOutputUnaligned) {
1024 {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
1026 padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
1027 outputType = alignedOutputType;
1030 Value transformedOutput = linalg::WinogradOutputTransformOp::create(
1031 rewriter, loc, outputType, matmulRet, output, fmr);
1035 if (isOutputUnaligned) {
1036 transformedOutput = extractFromAlignedTensor(
1037 rewriter, loc, transformedOutput,
1039 outputElementType));
1042 rewriter.replaceOp(convOp, transformedOutput);
1044 return transformedOutput.getDefiningOp();
// Lowers linalg.winograd_filter_transform into the explicit loop nest built
// by filterTransform, then replaces the op with the result.
1048 FailureOr<Operation *>
1049 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1050 linalg::WinogradFilterTransformOp op) {
1051 Location loc = op.getLoc();
1052 Value filter = op.getFilter();
1053 auto filterType = cast<ShapedType>(filter.getType());
1054 auto filterShape = filterType.getShape();
1055 int64_t filterH = filterShape[1];
1056 int64_t filterW = filterShape[2];
// A unit dimension means that side of the 2-D transform is skipped.
1059 bool leftTransform = filterH != 1;
1061 bool rightTransform = filterW != 1;
1062 Value transformedFilter =
1063 filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
1064 leftTransform, rightTransform);
1065 if (!transformedFilter)
1068 rewriter.replaceOp(op, transformedFilter);
1070 return transformedFilter.getDefiningOp();
// Lowers linalg.winograd_input_transform into the explicit loop nest built
// by inputTransform, then replaces the op with the result.
1074 FailureOr<Operation *>
1075 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1076 linalg::WinogradInputTransformOp op) {
1077 Location loc = op.getLoc();
1078 Value output = op.getOutput();
1079 auto outputType = cast<ShapedType>(output.getType());
1080 auto outputShape = outputType.getShape();
// The leading output dims are the transformed (alpha) sizes; 1 means that
// side of the transform is skipped.
1082 int64_t outputH = outputShape[0];
1083 int64_t outputW = outputShape[1];
1086 bool leftTransform = outputH != 1;
1088 bool rightTransform = outputW != 1;
1089 Value transformedInput =
1090 inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
1091 leftTransform, rightTransform);
1092 if (!transformedInput)
1095 rewriter.replaceOp(op, transformedInput);
1097 return transformedInput.getDefiningOp();
// Lowers linalg.winograd_output_transform into the explicit loop nest built
// by outputTransform, then replaces the op with the result.
1101 FailureOr<Operation *>
1102 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1103 linalg::WinogradOutputTransformOp op) {
1104 Location loc = op.getLoc();
1105 Value value = op.getValue();
1106 auto valueType = cast<ShapedType>(value.getType());
1107 auto valueShape = valueType.getShape();
1108 int64_t valueH = valueShape[0];
1109 int64_t valueW = valueShape[1];
// A unit dimension means that side of the 2-D transform is skipped.
1112 bool leftTransform = valueH != 1;
1114 bool rightTransform = valueW != 1;
1115 Value transformedOutput =
1116 outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
1117 leftTransform, rightTransform);
1118 if (!transformedOutput)
1121 rewriter.replaceOp(op, transformedOutput);
1123 return transformedOutput.getDefiningOp();
// Rewrite pattern: delegates linalg.winograd_filter_transform decomposition
// to decomposeWinogradFilterTransformHelper.
1127 class DecomposeWinogradFilterTransform final
1128 :
public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1132 LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1133 PatternRewriter &rewriter)
const override {
1134 return decomposeWinogradFilterTransformHelper(rewriter, op);
// Rewrite pattern: delegates linalg.winograd_input_transform decomposition
// to decomposeWinogradInputTransformHelper.
1139 class DecomposeWinogradInputTransform final
1140 :
public OpRewritePattern<linalg::WinogradInputTransformOp> {
1144 LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1145 PatternRewriter &rewriter)
const override {
1146 return decomposeWinogradInputTransformHelper(rewriter, op);
// Rewrite pattern: delegates linalg.winograd_output_transform decomposition
// to decomposeWinogradOutputTransformHelper.
1151 class DecomposeWinogradOutputTransform final
1152 :
public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1156 LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1157 PatternRewriter &rewriter)
const override {
1158 return decomposeWinogradOutputTransformHelper(rewriter, op);
// Rewrite pattern: converts linalg.conv_2d_nhwc_fhwc to the Winograd op
// pipeline for a fixed F(m, r) variant captured at construction time.
1163 class WinogradConv2DNhwcFhwc final
1164 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1168 : OpRewritePattern(context), fmr(fmr) {}
1170 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1171 PatternRewriter &rewriter)
const override {
1172 if (
failed(winogradConv2DHelper(rewriter, convOp, fmr)))
// The F(m, r) variant this pattern instance applies.
1179 WinogradConv2DFmr fmr;
// Public entry points: thin wrappers forwarding to the static helpers above,
// plus pattern-population functions registering the rewrite patterns.
// NOTE(review): the wrapper signatures' opening lines were partially dropped
// by extraction.
linalg::Conv2DNhwcFhwcOp op,
1187 linalg::WinogradConv2DFmr fmr) {
1188 return winogradConv2DHelper(rewriter, op, fmr);
1191 FailureOr<Operation *>
1193 linalg::WinogradFilterTransformOp op) {
1194 return decomposeWinogradFilterTransformHelper(rewriter, op);
1197 FailureOr<Operation *>
1199 linalg::WinogradInputTransformOp op) {
1200 return decomposeWinogradInputTransformHelper(rewriter, op);
1203 FailureOr<Operation *>
1205 linalg::WinogradOutputTransformOp op) {
1206 return decomposeWinogradOutputTransformHelper(rewriter, op);
// Registers the conv-to-Winograd pattern for the given F(m, r) variant.
1210 WinogradConv2DFmr fmr) {
1213 patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
// Registers all three decompose patterns.
1219 .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1220 DecomposeWinogradOutputTransform>(context);
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, WinogradConv2DFmr fmr)
Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns)
Patterns to decompose Winograd operators.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
static bool hasAllOneValues(DenseIntElementsAttr attr)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
@ Type
An inlay hint that for a type annotation.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...