#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace linalg {

namespace {
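/// Winograd Conv2D uses a minimal 2-D filtering algorithm F(m x m, r x r),
/// where m is the output tile size and r is the filter size:
///
///   Y = A^T x [ (G x g x G^T) (*) (B^T x d x B) ] x A
///
/// g is the filter, d is the input tile, and (*) is element-wise
/// multiplication. The tables below are the constant transform matrices
/// G, G^T, B^T, B, A^T, and A for F(2x2, 3x3), F(4x4, 3x3), and F(2x2, 5x5).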
constexpr float G_2x2_3x3[] = {
    /* ... */
};

constexpr float GT_2x2_3x3[] = {
    /* ... */
};

constexpr float BT_2x2_3x3[] = {
    /* ... */
};

constexpr float B_2x2_3x3[] = {
    /* ... */
};

constexpr float AT_2x2_3x3[] = {
    /* ... */
};

constexpr float A_2x2_3x3[] = {
    /* ... */
};
constexpr float G_4x4_3x3[] = {
     1,     0,     0,
 -1./3,  1./3, -1./3,
 -1./3, -1./3, -1./3,
 1./12, -1./6,  1./3,
 1./12,  1./6,  1./3,
     0,     0,     1
};

constexpr float GT_4x4_3x3[] = {
 1,  -1./3, -1./3, 1./12, 1./12, 0,
 0,   1./3, -1./3, -1./6,  1./6, 0,
 0,  -1./3, -1./3,  1./3,  1./3, 1
};
constexpr float BT_4x4_3x3[] = {
 1./4,     0, -5./16,      0,  1./16,     0,
    0,  1./4,  -1./4, -1./16,  1./16,     0,
    0, -1./4,  -1./4,  1./16,  1./16,     0,
    0,  1./4,  -1./8,  -1./4,   1./8,     0,
    0, -1./4,  -1./8,   1./4,   1./8,     0,
    0,  1./4,      0, -5./16,      0, 1./16
};
constexpr float B_4x4_3x3[] = {
   1./4,      0,      0,     0,     0,      0,
      0,   1./4,  -1./4,  1./4, -1./4,   1./4,
 -5./16,  -1./4,  -1./4, -1./8, -1./8,      0,
      0, -1./16,  1./16, -1./4,  1./4, -5./16,
  1./16,  1./16,  1./16,  1./8,  1./8,      0,
      0,      0,      0,     0,     0,  1./16
};
constexpr float AT_4x4_3x3[] = {
 1./8,  1./4, 1./4,  1./8, 1./8,    0,
    0, -1./4, 1./4, -1./4, 1./4,    0,
    0,  1./4, 1./4,  1./2, 1./2,    0,
    0, -1./4, 1./4,    -1,    1, 1./2
};
constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,    0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,   -1,
  1./8,  1./4, 1./2,    1,
     0,     0,    0, 1./2
};
constexpr float G_2x2_5x5[] = {
      1,     0,      0,     0,      0,
   1./6, -1./6,   1./6, -1./6,   1./6,
  -1./6, -1./6,  -1./6, -1./6,  -1./6,
 -4./15, 2./15, -1./15, 1./30, -1./60,
  1./60, 1./30,  1./15, 2./15,  4./15,
      0,     0,      0,     0,      1
};
constexpr float GT_2x2_5x5[] = {
 1,  1./6, -1./6, -4./15, 1./60, 0,
 0, -1./6, -1./6,  2./15, 1./30, 0,
 0,  1./6, -1./6, -1./15, 1./15, 0,
 0, -1./6, -1./6,  1./30, 2./15, 0,
 0,  1./6, -1./6, -1./60, 4./15, 1
};
constexpr float BT_2x2_5x5[] = {
 1./8, 3./16,  -1./4, -3./16,   1./8,    0,
    0,  1./8,  1./16, -5./16,   1./8,    0,
    0, -1./8, -5./16, -1./16,   1./8,    0,
    0,  1./4,  -1./8,  -1./4,   1./8,    0,
    0, -1./8,  -1./4,   1./8,   1./4,    0,
    0,  1./8,  3./16,  -1./4, -3./16, 1./8
};
constexpr float B_2x2_5x5[] = {
   1./8,      0,      0,     0,     0,      0,
  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
 -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
      0,      0,      0,     0,     0,   1./8
};
constexpr float AT_2x2_5x5[] = {
 /* ... */
    0,   -1,    1,   -1,    2, 1./2
};

constexpr float A_2x2_5x5[] = {
    /* ... */
};
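/// F(m, r) names the minimal filtering algorithm: m is the output tile size
/// and r is the filter size. The input tile size is alpha = m + r - 1, e.g.
/// F(2, 3) reads a 4x4 input tile to produce a 2x2 output tile.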
using TransformMapKeyTy = std::pair<int, int>;

constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
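/// Bundles a constant transform table with its dimensions and an optional
/// scalar factor that is applied after the transform.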
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor;
};
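/// Materialize a transform table as a rows x cols arith.constant tensor.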
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
                              TransformMatrix transform, Type type) {
  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);

  return builder.create<arith::ConstantOp>(
      loc, DenseFPElementsAttr::get(
               RankedTensorType::get(
                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
               constVec));
}
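/// Extract a 2-D (height x width) slice from a 4-D tensor. The two loop
/// indices and the height/width offsets select the slice; the *Idx arguments
/// say which of the four dimensions plays which role.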
Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
                          Value loopNorFIndex, Value loopCorFIndex,
                          Value heightOffset, Value widthOffset,
                          int64_t extractHeight, int64_t extractWidth,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  int64_t srcSize = sourceType.getRank();

  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets;
  offsets.resize(srcSize);
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  offsets[heightIdx] = heightOffset;
  offsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(extractHeight);
  sizes[widthIdx] = builder.getIndexAttr(extractWidth);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType =
      RankedTensorType::get({extractHeight, extractWidth}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
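/// Extract a 2-D (height x width) slice from a 6-D tensor at the given tile
/// and loop indices. The full extent of the height and width dimensions is
/// extracted.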
Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
                          Value tileHIndex, Value tileWIndex,
                          Value loopNorFIndex, Value loopCorFIndex,
                          int64_t tileHIdx, int64_t tileWIdx,
                          int64_t loopNorFIdx, int64_t loopCorFIdx,
                          int64_t heightIdx, int64_t widthIdx) {
  auto sourceType = cast<ShapedType>(source.getType());
  Type elementType = sourceType.getElementType();
  auto sourceShape = sourceType.getShape();
  int64_t srcSize = sourceType.getRank();
  int64_t height = sourceShape[heightIdx];
  int64_t width = sourceShape[widthIdx];

  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
  offsets[tileHIdx] = tileHIndex;
  offsets[tileWIdx] = tileWIndex;
  offsets[loopNorFIdx] = loopNorFIndex;
  offsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
  sizes[heightIdx] = builder.getIndexAttr(height);
  sizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(srcSize, oneIndex);

  auto extractFilterType = RankedTensorType::get({height, width}, elementType);
  auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
      loc, extractFilterType, source, offsets, sizes, strides);

  return extractFilterOp;
}
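/// Insert a 2-D (height x width) slice into a 4-D tensor at the given loop
/// indices and offsets.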
Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value loopNorFIndex, Value loopCorFIndex,
                       Value heightOffset, Value widthOffset, int64_t height,
                       int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets;
  retOffsets.resize(destSize);
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  retOffsets[heightIdx] = heightOffset;
  retOffsets[widthIdx] = widthOffset;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
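/// Insert a 2-D (height x width) slice into a 6-D tensor at the given tile
/// and loop indices.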
Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
                       Value dest, Value tileHIndex, Value tileWIndex,
                       Value loopNorFIndex, Value loopCorFIndex, int64_t height,
                       int64_t width, int64_t tileHIdx, int64_t tileWIdx,
                       int64_t loopNorFIdx, int64_t loopCorFIdx,
                       int64_t heightIdx, int64_t widthIdx) {
  int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
  auto zeroIndex = builder.getIndexAttr(0);
  auto oneIndex = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
  retOffsets[tileHIdx] = tileHIndex;
  retOffsets[tileWIdx] = tileWIndex;
  retOffsets[loopNorFIdx] = loopNorFIndex;
  retOffsets[loopCorFIdx] = loopCorFIndex;
  SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
  retSizes[heightIdx] = builder.getIndexAttr(height);
  retSizes[widthIdx] = builder.getIndexAttr(width);
  SmallVector<OpFoldResult> strides(destSize, oneIndex);

  auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
      loc, source, dest, retOffsets, retSizes, strides);

  return insertSliceOp;
}
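/// Transform the filter: for every (F, C) pair, extract the H x W filter g
/// from the FHWC-layout filter and compute G x g x G^T, writing the result
/// into the (alphaH, alphaW, C, F) output. Returns a null Value on failure.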
Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                      Value retValue, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the G transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GMatrices = {
          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
      };

  // Map from (m, r) to the G^T transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      GTMatrices = {
          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
      };

  auto filterType = cast<ShapedType>(filter.getType());
  Type elementType = filterType.getElementType();
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];

  if (filterH != r && filterH != 1)
    return Value();
  if (filterW != r && filterW != 1)
    return Value();

  Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value FIter = ivs[0];
    Value CIter = ivs[1];

    // Extract the (H, W) slice for this (F, C) pair.
    auto extractFilter =
        extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
                            zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    Value matmulRetValue = extractFilter;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get the constant transform matrix G.
      auto it = GMatrices.find(key);
      if (it == GMatrices.end())
        return {};
      const TransformMatrix &GMatrix = it->second;

      retRows = GMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
      // Multiply G x g.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get the constant transform matrix G^T.
      auto it = GTMatrices.find(key);
      if (it == GTMatrices.end())
        return {};
      const TransformMatrix &GTMatrix = it->second;

      auto matmulType =
          RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
      // Multiply (G x g) x G^T.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert the transformed (H, W) slice into the (H, W, C, F) result.
    int64_t retHeight = leftTransform ? m + r - 1 : 1;
    int64_t retWidth = rightTransform ? m + r - 1 : 1;
    auto insertSliceOp =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
                         zeroIdx, zeroIdx, retHeight, retWidth,
                         /*loopNorFIdx=*/3, /*loopCorFIdx=*/2,
                         /*heightIdx=*/0, /*widthIdx=*/1);

    return {insertSliceOp};
  };

  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
      {oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
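/// Transform the input: for every tile (tileH, tileW) and every (N, C) pair,
/// extract the alphaH x alphaW input tile d (tiles overlap by r - 1) and
/// compute B^T x d x B, writing the result into the 6-D
/// (alphaH, alphaW, tileH, tileW, N, C) output.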
Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                     Value retValue, int64_t m, int64_t r,
                     bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the B^T transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BTMatrices = {
          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
      };

  // Map from (m, r) to the B transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      BMatrices = {
          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
      };

  auto inputType = cast<ShapedType>(input.getType());
  Type elementType = inputType.getElementType();
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];
  auto valueType = cast<ShapedType>(retValue.getType());
  auto valueShape = valueType.getShape(); // alphaH, alphaW, tileH, tileW, N, C
  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value CIter = ivs[3];

    auto context = builder.getContext();
    // Tiles are read at offsets that are multiples of m.
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // Extract the alphaH x alphaW input tile for this (N, C) pair.
    auto extractInput =
        extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
                            widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);

    TransformMapKeyTy key = {m, r};
    int64_t retRows = 1;
    int64_t retCols = 1;
    Value matmulRetValue = extractInput;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));
    if (leftTransform) {
      // Get the constant transform matrix B^T.
      auto it = BTMatrices.find(key);
      if (it == BTMatrices.end())
        return {};
      const TransformMatrix &BTMatrix = it->second;

      retRows = BTMatrix.rows;
      auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value BT =
          create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
      // Multiply B^T x d.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      // Get the constant transform matrix B.
      auto it = BMatrices.find(key);
      if (it == BMatrices.end())
        return {};
      const TransformMatrix &BMatrix = it->second;

      retCols = BMatrix.cols;
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto empty =
          builder
              .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
              .getResult();
      auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

      Value B =
          create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
      // Multiply (B^T x d) x B.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    // Insert the transformed tile into the 6-D result.
    auto combinedVal = insert2DDataTo6D(
        builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
        CIter, retRows, retCols, 2, 3, 4, 5,
        /*heightIdx=*/0, /*widthIdx=*/1);

    return {combinedVal};
  };

  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, cUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
  return loops.results[0];
}
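/// Multiply the transformed filter and input. Both are collapsed so that the
/// alphaH x alphaW positions become the batch dimension of a
/// linalg.batch_matmul, then the result is expanded back to
/// (alphaH, alphaW, tileH, tileW, N, F).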
static Value matrixMultiply(RewriterBase &rewriter, Location loc,
                            Value transformedFilter, Value transformedInput,
                            Type outputElementType) {
  // Collapse the filter: (alphaH, alphaW, C, F) -> (alphaH * alphaW, C, F).
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  assert(filterType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> filterShape = filterType.getShape();
  Type filterElementType = filterType.getElementType();
  auto filterReassocType = RankedTensorType::get(
      {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
      filterElementType);
  SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
  Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, filterReassocType, transformedFilter, filterReassoc);

  // Collapse the input: (alphaH, alphaW, tileH, tileW, N, C) ->
  // (alphaH * alphaW, tileH * tileW * N, C).
  auto inputType = cast<ShapedType>(transformedInput.getType());
  assert(inputType.hasStaticShape() && "only support static shapes.");
  ArrayRef<int64_t> inputShape = inputType.getShape();
  Type inputElementType = inputType.getElementType();
  auto inputReassocType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
      inputElementType);
  SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
      loc, inputReassocType, transformedInput, inputReassoc);

  // Batched matrix multiply over the alphaH * alphaW batch dimension.
  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1],
       inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
      outputElementType);
  Value empty = rewriter
                    .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                             outputElementType)
                    .getResult();
  Value zero = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(outputElementType));
  Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // Expand the result back to (alphaH, alphaW, tileH, tileW, N, F).
  SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
  auto outputReassocType =
      RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2],
                             inputShape[3], inputShape[4], filterShape[3]},
                            outputElementType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
  return expandOutput;
}
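/// Transform the output: for every tile and every (N, F) pair, extract the
/// alphaH x alphaW intermediate tile y, compute A^T x y x A, apply the scalar
/// factor, and accumulate the m x m result into the NHWF output.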
Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
                      Value output, int64_t m, int64_t r,
                      bool leftTransform = true, bool rightTransform = true) {
  // Map from (m, r) to the A^T transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      ATMatrices = {
          {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
          {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
          {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
      };

  // Map from (m, r) to the A transform matrix.
  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
      AMatrices = {
          {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
          {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
          {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
      };

  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto valueShape = valueType.getShape(); // alphaH, alphaW, tileH, tileW, N, F
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueN = valueShape[4];
  int64_t valueF = valueShape[5];
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;

  if (valueH != alphaH && valueH != 1)
    return Value();
  if (valueW != alphaW && valueW != 1)
    return Value();

  auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
                       ValueRange args) -> scf::ValueVector {
    auto context = builder.getContext();
    Value tileHIter = ivs[0];
    Value tileWIter = ivs[1];
    Value NIter = ivs[2];
    Value FIter = ivs[3];

    // Extract the alphaH x alphaW tile for this (N, F) pair.
    auto extractValue =
        extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
                            FIter, /*tileHIdx=*/2, /*tileWIdx=*/3,
                            /*loopNorFIdx=*/4, /*loopCorFIdx=*/5,
                            /*heightIdx=*/0, /*widthIdx=*/1);

    const TransformMapKeyTy key = {m, r};
    const TransformMatrix &AMatrix = AMatrices.at(key);
    const TransformMatrix &ATMatrix = ATMatrices.at(key);
    int64_t scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
                           (leftTransform ? ATMatrix.scalarFactor : 1);
    int64_t retCols = rightTransform ? AMatrix.cols : 1;
    int64_t retRows = leftTransform ? ATMatrix.rows : 1;

    Value matmulRetValue = extractValue;
    Value zero = builder.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementType));

    // Output tiles are written at offsets that are multiples of m.
    auto affineMap =
        AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
    Value heightOffset = builder.create<affine::AffineApplyOp>(
        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
    Value widthOffset = builder.create<affine::AffineApplyOp>(
        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);

    // The corresponding slice of the output, used as the accumulator.
    Value outInitVal =
        extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
                            widthOffset, retRows, retCols, /*loopNorFIdx=*/0,
                            /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
    if (leftTransform) {
      auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
      Value init = outInitVal;
      if (rightTransform || scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
      // Multiply A^T x y.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (rightTransform) {
      auto matmulType =
          RankedTensorType::get({retRows, AMatrix.cols}, elementType);
      Value init = outInitVal;
      if (scalarFactor != 1) {
        auto empty = builder
                         .create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                  elementType)
                         .getResult();
        init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
      }

      Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
      // Multiply (A^T x y) x A.
      auto matmulOp = builder.create<linalg::MatmulOp>(
          loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init});
      matmulRetValue = matmulOp.getResult(0);
    }

    if (scalarFactor != 1) {
      // Multiply by the scalar factor and add the output slice.
      Value scalarFactorValue = builder.create<arith::ConstantOp>(
          loc, FloatAttr::get(elementType, scalarFactor));
      auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
      auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
      SmallVector<AffineMap> affineMaps = {
          AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};

      matmulRetValue =
          rewriter
              .create<linalg::GenericOp>(
                  loc, matmulType,
                  ValueRange{scalarFactorValue, matmulRetValue},
                  ValueRange{outInitVal}, affineMaps,
                  llvm::ArrayRef<utils::IteratorType>{
                      utils::IteratorType::parallel,
                      utils::IteratorType::parallel},
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    auto mulf = nestedBuilder.create<arith::MulFOp>(
                        nestedLoc, args[0], args[1]);
                    auto addf = nestedBuilder.create<arith::AddFOp>(
                        nestedLoc, mulf.getResult(), args[2]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                          addf.getResult());
                  })
              .getResult(0);
    }

    // Insert the m x m result into the (N, H, W, F) output.
    Value combinedVal =
        insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
                         heightOffset, widthOffset, retRows, retCols,
                         /*loopNorFIdx=*/0, /*loopCorFIdx=*/3,
                         /*heightIdx=*/1, /*widthIdx=*/2);

    return {combinedVal};
  };

  int64_t tileH = valueShape[2];
  int64_t tileW = valueShape[3];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
  auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  scf::LoopNest loops = scf::buildLoopNest(
      rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
      {tileHBound, tileWBound, nUpperBound, fUpperBound},
      {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
  return loops.results[0];
}
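/// Zero-pad a tensor to the statically aligned shape.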
static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
                                Value value, ArrayRef<int64_t> alignedShape) {
  auto valueType = cast<ShapedType>(value.getType());
  Type elementType = valueType.getElementType();
  auto alignedType = RankedTensorType::get(alignedShape, elementType);
  Value padValue = rewriter.create<arith::ConstantOp>(
      loc, elementType, rewriter.getZeroAttr(elementType));

  return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value,
                                       padValue, /*nofold=*/false);
}
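/// Extract a sub-tensor of the original (unaligned) shape from an aligned
/// tensor.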
static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
                                      Value value,
                                      RankedTensorType extractedType) {
  OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
  OpFoldResult oneIndex = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
  SmallVector<OpFoldResult, 4> strides(4, oneIndex);

  ArrayRef<int64_t> extractedShape = extractedType.getShape();
  SmallVector<OpFoldResult> sizes =
      getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape));

  return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
                                                 offsets, sizes, strides);
}
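/// Return true if all values in the integer attribute are 1 (used to require
/// unit strides and dilations).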
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  return llvm::all_of(
      attr, [](const APInt &element) { return element.getSExtValue() == 1; });
}
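/// Rewrite linalg.conv_2d_nhwc_fhwc as the Winograd pipeline:
/// winograd_filter_transform and winograd_input_transform, a batched matrix
/// multiply, and winograd_output_transform. The input and output are padded
/// to tile-aligned sizes when necessary and the result is extracted back to
/// the original output shape.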
static FailureOr<Operation *>
winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
                     int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];
  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  if (!hasAllOneValues(convOp.getStrides()))
    return rewriter.notifyMatchFailure(convOp, "expected all ones for strides");

  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputN = inputShape[0];
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int64_t inputC = inputShape[3];
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputN = outputShape[0];
  int64_t outputH = outputShape[1];
  int64_t outputW = outputShape[2];
  int64_t outputF = outputShape[3];

  // Only support filters of shape (r x r), (r x 1), or (1 x r).
  bool isSupportedFilter = false;
  if (filterH == filterW && filterH == r)
    isSupportedFilter = true;
  if (filterH == r && filterW == 1)
    isSupportedFilter = true;
  if (filterH == 1 && filterW == r)
    isSupportedFilter = true;

  if (!isSupportedFilter)
    return rewriter.notifyMatchFailure(
        convOp, "only support filter (r x r), (r x 1) or (1 x r)");

  // Only the F(m, r) configurations with precomputed transform matrices are
  // supported.
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If the configuration value is not supported, return failure.
  if (it == validConfigs.end())
    return failure();

  Location loc = convOp.getLoc();

  // For a 1-D filter (r x 1 or 1 x r), only transform the non-unit dimension.
  bool leftTransform = filterH != 1;
  bool rightTransform = filterW != 1;
  int64_t heightM = leftTransform ? m : 1;
  int64_t widthM = rightTransform ? m : 1;
  int64_t heightR = leftTransform ? r : 1;
  int64_t widthR = rightTransform ? r : 1;

  // Transform the filter: (F, H, W, C) -> (alphaH, alphaW, C, F).
  Type filterElementType = filterType.getElementType();
  int64_t alphaH = heightM + heightR - 1;
  int64_t alphaW = widthM + widthR - 1;
  int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
  int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
  auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF},
                                       filterElementType);
  Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                                    filterElementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, m, r);

  // Transform the input: (N, H, W, C) -> (alphaH, alphaW, tileH, tileW, N, C).
  // Pad the input first if its size is not tile-aligned.
  Type inputElementType = inputType.getElementType();
  int64_t alignedInputH = tileH * heightM + (heightR - 1);
  int64_t alignedInputW = tileW * widthM + (widthR - 1);
  if (alignedInputH != inputH || alignedInputW != inputW) {
    input = padToAlignedTensor(rewriter, loc, input,
                               {inputN, alignedInputH, alignedInputW, inputC});
  }

  retType = RankedTensorType::get(
      {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
  retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
                                              inputElementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, m, r);

  Type outputElementType = outputType.getElementType();
  Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
                                   transformedInput, outputElementType);

  // Transform the result back: (alphaH, alphaW, tileH, tileW, N, F) ->
  // (N, H, W, F). Pad the output if its size is not tile-aligned.
  int64_t alignedOutputH = tileH * heightM;
  int64_t alignedOutputW = tileW * widthM;
  bool isOutputUnaligned =
      ((alignedOutputH != outputH) || (alignedOutputW != outputW));
  if (isOutputUnaligned) {
    auto alignedOutputType = RankedTensorType::get(
        {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
    output =
        padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
    outputType = alignedOutputType;
  }

  Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  // Extract the original output shape from the aligned result.
  if (isOutputUnaligned) {
    transformedOutput = extractFromAlignedTensor(
        rewriter, loc, transformedOutput,
        RankedTensorType::get({outputN, outputH, outputW, outputF},
                              outputElementType));
  }

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getDefiningOp();
}
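/// Decompose linalg.winograd_filter_transform into the loop nest built by
/// filterTransform.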
FailureOr<Operation *>
decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradFilterTransformOp op) {
  Location loc = op.getLoc();
  Value filter = op.getFilter();
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape();
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];

  // For a 1-D filter, transform only the non-unit dimension.
  bool leftTransform = filterH != 1;
  bool rightTransform = filterW != 1;
  Value transformedFilter =
      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedFilter)
    return failure();

  rewriter.replaceOp(op, transformedFilter);

  return transformedFilter.getDefiningOp();
}
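/// Decompose linalg.winograd_input_transform into the loop nest built by
/// inputTransform.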
FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                      linalg::WinogradInputTransformOp op) {
  Location loc = op.getLoc();
  Value output = op.getOutput();
  auto outputType = cast<ShapedType>(output.getType());
  auto outputShape = outputType.getShape();
  int64_t outputH = outputShape[0];
  int64_t outputW = outputShape[1];

  // For a 1-D filter, transform only the non-unit dimension.
  bool leftTransform = outputH != 1;
  bool rightTransform = outputW != 1;
  Value transformedInput =
      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                     op.getR(), leftTransform, rightTransform);
  if (!transformedInput)
    return failure();

  rewriter.replaceOp(op, transformedInput);

  return transformedInput.getDefiningOp();
}
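/// Decompose linalg.winograd_output_transform into the loop nest built by
/// outputTransform.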
FailureOr<Operation *>
decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradOutputTransformOp op) {
  Location loc = op.getLoc();
  Value value = op.getValue();
  auto valueType = cast<ShapedType>(value.getType());
  auto valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];

  // For a 1-D filter, transform only the non-unit dimension.
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;
  Value transformedOutput =
      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
  if (!transformedOutput)
    return failure();

  rewriter.replaceOp(op, transformedOutput);

  return transformedOutput.getDefiningOp();
}
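/// Rewrite patterns that decompose the Winograd transform ops into loops, and
/// the pattern that rewrites a conv_2d_nhwc_fhwc into the Winograd pipeline.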
class DecomposeWinogradFilterTransform final
    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradFilterTransformHelper(rewriter, op);
  }
};

class DecomposeWinogradInputTransform final
    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradInputTransformHelper(rewriter, op);
  }
};

class DecomposeWinogradOutputTransform final
    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
                                PatternRewriter &rewriter) const override {
    return decomposeWinogradOutputTransformHelper(rewriter, op);
  }
};

class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  WinogradConv2DNhwcFhwc(MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();
    return success();
  }

private:
  int64_t m;
  int64_t r;
};

} // end anonymous namespace
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
                                      int64_t r) {
  return winogradConv2DHelper(rewriter, op, m, r);
}

FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradFilterTransformOp op) {
  return decomposeWinogradFilterTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
                                  linalg::WinogradInputTransformOp op) {
  return decomposeWinogradInputTransformHelper(rewriter, op);
}

FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
                                   linalg::WinogradOutputTransformOp op) {
  return decomposeWinogradOutputTransformHelper(rewriter, op);
}

void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}

void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns
      .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
              DecomposeWinogradOutputTransform>(context);
}

} // end namespace linalg
} // end namespace mlir