22 #include "llvm/Support/MathExtras.h"
41 constexpr
float G_2x2_3x3[] = {
48 constexpr
float GT_2x2_3x3[] = {
54 constexpr
float BT_2x2_3x3[] = {
61 constexpr
float B_2x2_3x3[] = {
68 constexpr
float AT_2x2_3x3[] = {
73 constexpr
float A_2x2_3x3[] = {
80 constexpr
float G_4x4_3x3[] = {
89 constexpr
float GT_4x4_3x3[] = {
90 1, -1./3, -1./3, 1./12, 1./12, 0,
91 0, 1./3, -1./3, -1./6, 1./6, 0,
92 0, -1./3, -1./3, 1./3, 1./3, 1
95 constexpr
float BT_4x4_3x3[] = {
96 1./4, 0, -5./16, 0, 1./16, 0,
97 0, 1./4, -1./4, -1./16, 1./16, 0,
98 0, -1./4, -1./4, 1./16, 1./16, 0,
99 0, 1./4, -1./8, -1./4, 1./8, 0,
100 0, -1./4, -1./8, 1./4, 1./8, 0,
101 0, 1./4, 0, -5./16, 0, 1./16
104 constexpr
float B_4x4_3x3[] = {
106 0, 1./4, -1./4, 1./4, -1./4, 1./4,
107 -5./16, -1./4, -1./4, -1./8, -1./8, 0,
108 0, -1./16, 1./16, -1./4, 1./4, -5./16,
109 1./16, 1./16, 1./16, 1./8, 1./8, 0,
113 constexpr
float AT_4x4_3x3[] = {
114 1./8, 1./4, 1./4, 1./8, 1./8, 0,
115 0, -1./4, 1./4, -1./4, 1./4, 0,
116 0, 1./4, 1./4, 1./2, 1./2, 0,
117 0, -1./4, 1./4, -1, 1, 1./2
120 constexpr
float A_4x4_3x3[] = {
122 1./4, -1./4, 1./4, -1./4,
123 1./4, 1./4, 1./4, 1./4,
124 1./8, -1./4, 1./2, -1,
129 constexpr
float G_2x2_5x5[] = {
131 1./6, -1./6, 1./6, -1./6, 1./6,
132 -1./6, -1./6, -1./6, -1./6, -1./6,
133 -4./15, 2./15, -1./15, 1./30, -1./60,
134 1./60, 1./30, 1./15, 2./15, 4./15,
138 constexpr
float GT_2x2_5x5[] = {
139 1, 1./6, -1./6, -4./15, 1./60, 0,
140 0, -1./6, -1./6, 2./15, 1./30, 0,
141 0, 1./6, -1./6, -1./15, 1./15, 0,
142 0, -1./6, -1./6, 1./30, 2./15, 0,
143 0, 1./6, -1./6, -1./60, 4./15, 1
146 constexpr
float BT_2x2_5x5[] = {
147 1./8, 3./16, -1./4, -3./16, 1./8, 0,
148 0, 1./8, 1./16, -5./16, 1./8, 0,
149 0, -1./8, -5./16, -1./16, 1./8, 0,
150 0, 1./4, -1./8, -1./4, 1./8, 0,
151 0, -1./8, -1./4, 1./8, 1./4, 0,
152 0, 1./8, 3./16, -1./4, -3./16, 1./8
155 constexpr
float B_2x2_5x5[] = {
157 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
158 -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
159 -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
160 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
164 constexpr
float AT_2x2_5x5[] = {
166 0, -1, 1, -1, 2, 1./2
169 constexpr
float A_2x2_5x5[] = {
179 using TransformMapKeyTy = std::pair<int, int>;
188 constexpr TransformMapKeyTy F_2_3{2, 3};
189 constexpr TransformMapKeyTy F_4_3{4, 3};
190 constexpr TransformMapKeyTy F_2_5{2, 5};
193 struct TransformMatrix {
194 TransformMatrix(
const float *
table, int64_t
rows, int64_t
cols,
205 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
206 TransformMatrix transform, Type type) {
207 ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
209 return builder.create<arith::ConstantOp>(
212 SmallVector<int64_t>{transform.rows, transform.cols}, type),
// Extracts a 2-D (extractHeight x extractWidth) tile from the 4-D tensor
// `source` with tensor.extract_slice.  The four index Values position the
// slice along the dimensions named by the matching *Idx parameters; every
// other dimension gets size 1 and stride 1.
// NOTE(review): this listing is an extraction with elided lines — the
// RankedTensorType::get(...) initializer of `extractFilterType` (which is
// where `elementType` is used) and the closing brace are missing.
217 Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
218 Value loopNorFIndex, Value loopCorFIndex,
219 Value heightOffset, Value widthOffset,
220 int64_t extractHeight, int64_t extractWidth,
221 int64_t loopNorFIdx, int64_t loopCorFIdx,
222 int64_t heightIdx, int64_t widthIdx) {
// Rank and element type of the source tensor.
223 auto sourceType = cast<ShapedType>(source.getType());
224 Type elementType = sourceType.getElementType();
225 int64_t srcSize = sourceType.getRank();
227 auto oneIndex = builder.getIndexAttr(1);
// Offsets: loop indices on the N/F and C/F dims, explicit offsets on the
// height/width dims; remaining entries are default-constructed.
228 SmallVector<OpFoldResult> offsets;
229 offsets.resize(srcSize);
230 offsets[loopNorFIdx] = loopNorFIndex;
231 offsets[loopCorFIdx] = loopCorFIndex;
232 offsets[heightIdx] = heightOffset;
233 offsets[widthIdx] = widthOffset;
// Sizes: 1 everywhere except the 2-D tile being extracted.
234 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
235 sizes[heightIdx] = builder.getIndexAttr(extractHeight);
236 sizes[widthIdx] = builder.getIndexAttr(extractWidth);
// Unit strides on all dimensions.
237 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
239 auto extractFilterType =
241 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
242 loc, extractFilterType, source, offsets, sizes, strides);
244 return extractFilterOp;
// Extracts a 2-D (height x width) tile from the 6-D tensor `source`.
// Unlike extract2DDataFrom4D, the tile extent is read from the source's
// own shape at heightIdx/widthIdx, and two tile-loop indices
// (tileHIndex/tileWIndex) also position the slice.
// NOTE(review): extraction has elided lines — the definition of
// `extractFilterType` (used on the ExtractSliceOp below) and the closing
// brace are missing from this listing.
248 Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
249 Value tileHIndex, Value tileWIndex,
250 Value loopNorFIndex, Value loopCorFIndex,
251 int64_t tileHIdx, int64_t tileWIdx,
252 int64_t loopNorFIdx, int64_t loopCorFIdx,
253 int64_t heightIdx, int64_t widthIdx) {
254 auto sourceType = cast<ShapedType>(source.getType());
255 Type elementType = sourceType.getElementType();
256 auto sourceShape = sourceType.getShape();
257 int64_t srcSize = sourceType.getRank();
// The full extent of the 2-D tile comes from the source's shape.
258 int64_t height = sourceShape[heightIdx];
259 int64_t width = sourceShape[widthIdx];
261 auto zeroIndex = builder.getIndexAttr(0);
262 auto oneIndex = builder.getIndexAttr(1);
// Offsets default to 0; the four loop indices override their dims.
// (The resize() right after construction is a no-op — size is already
// srcSize.)
263 SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
264 offsets.resize(srcSize);
265 offsets[tileHIdx] = tileHIndex;
266 offsets[tileWIdx] = tileWIndex;
267 offsets[loopNorFIdx] = loopNorFIndex;
268 offsets[loopCorFIdx] = loopCorFIndex;
// Sizes: 1 everywhere except the full height/width dims; unit strides.
269 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
270 sizes[heightIdx] = builder.getIndexAttr(height);
271 sizes[widthIdx] = builder.getIndexAttr(width);
272 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
275 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
276 loc, extractFilterType, source, offsets, sizes, strides);
278 return extractFilterOp;
// Inserts the 2-D (height x width) tile `source` into the 4-D tensor
// `dest` with tensor.insert_slice, positioned by the two loop indices and
// the height/width offsets.  Mirror image of extract2DDataFrom4D.
// NOTE(review): the closing brace of this function is elided in this
// extraction.
283 Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
284 Value dest, Value loopNorFIndex, Value loopCorFIndex,
285 Value heightOffset, Value widthOffset, int64_t height,
286 int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
287 int64_t heightIdx, int64_t widthIdx) {
288 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
289 auto oneIndex = builder.getIndexAttr(1);
// Destination offsets: loop indices on the N/F and C/F dims, explicit
// offsets on the height/width dims.
290 SmallVector<OpFoldResult> retOffsets;
291 retOffsets.resize(destSize);
292 retOffsets[loopNorFIdx] = loopNorFIndex;
293 retOffsets[loopCorFIdx] = loopCorFIndex;
294 retOffsets[heightIdx] = heightOffset;
295 retOffsets[widthIdx] = widthOffset;
// Sizes: 1 everywhere except the inserted tile; unit strides.
296 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
297 retSizes[heightIdx] = builder.getIndexAttr(height);
298 retSizes[widthIdx] = builder.getIndexAttr(width);
299 SmallVector<OpFoldResult> strides(destSize, oneIndex);
301 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
302 loc, source, dest, retOffsets, retSizes, strides);
304 return insertSliceOp;
// Inserts the 2-D (height x width) tile `source` into the 6-D tensor
// `dest`, positioned by the two tile-loop indices and the two N/F, C/F
// loop indices.  Mirror image of extract2DDataFrom6D.
// NOTE(review): the closing brace of this function is elided in this
// extraction.
309 Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
310 Value dest, Value tileHIndex, Value tileWIndex,
311 Value loopNorFIndex, Value loopCorFIndex, int64_t height,
312 int64_t width, int64_t tileHIdx, int64_t tileWIdx,
313 int64_t loopNorFIdx, int64_t loopCorFIdx,
314 int64_t heightIdx, int64_t widthIdx) {
315 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
316 auto zeroIndex = builder.getIndexAttr(0);
317 auto oneIndex = builder.getIndexAttr(1);
// Offsets default to 0; the four loop indices override their dims.
// (The resize() right after construction is a no-op — size is already
// destSize.)
318 SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
319 retOffsets.resize(destSize);
320 retOffsets[tileHIdx] = tileHIndex;
321 retOffsets[tileWIdx] = tileWIndex;
322 retOffsets[loopNorFIdx] = loopNorFIndex;
323 retOffsets[loopCorFIdx] = loopCorFIndex;
// Sizes: 1 everywhere except the inserted tile; unit strides.
324 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
325 retSizes[heightIdx] = builder.getIndexAttr(height);
326 retSizes[widthIdx] = builder.getIndexAttr(width);
327 SmallVector<OpFoldResult> strides(destSize, oneIndex);
329 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
330 loc, source, dest, retOffsets, retSizes, strides);
332 return insertSliceOp;
346 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
347 Value retValue, int64_t m, int64_t r,
348 bool leftTransform =
true,
bool rightTransform =
true) {
350 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
352 {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
353 {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
354 {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
358 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
360 {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
361 {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
362 {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
365 auto filterType = cast<ShapedType>(filter.getType());
366 Type elementType = filterType.getElementType();
367 auto filterShape = filterType.getShape();
368 int64_t filterF = filterShape[0];
369 int64_t filterH = filterShape[1];
370 int64_t filterW = filterShape[2];
371 int64_t filterC = filterShape[3];
373 if (filterH != r && filterH != 1)
375 if (filterW != r && filterW != 1)
378 Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
379 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
381 Value FIter = ivs[0];
382 Value CIter = ivs[1];
386 extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
387 zeroIdx, filterH, filterW, 0,
390 TransformMapKeyTy key = {m, r};
392 Value matmulRetValue = extractFilter;
393 Value zero = builder.create<arith::ConstantOp>(
394 loc, rewriter.getZeroAttr(elementType));
397 auto it = GMatrices.find(key);
398 if (it == GMatrices.end())
400 const TransformMatrix &GMatrix = it->second;
402 retRows = GMatrix.rows;
406 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
408 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
410 Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
412 auto matmulOp = builder.create<linalg::MatmulOp>(
413 loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
414 matmulRetValue = matmulOp.getResult(0);
417 if (rightTransform) {
419 auto it = GTMatrices.find(key);
420 if (it == GTMatrices.end())
422 const TransformMatrix >Matrix = it->second;
428 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
430 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
432 Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
434 auto matmulOp = builder.create<linalg::MatmulOp>(
435 loc, matmulType, ValueRange{matmulRetValue,
GT}, ValueRange{init});
436 matmulRetValue = matmulOp.getResult(0);
440 int64_t retHeight = leftTransform ? m + r - 1 : 1;
441 int64_t retWidth = rightTransform ? m + r - 1 : 1;
444 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
445 zeroIdx, zeroIdx, retHeight, retWidth,
449 return {insertSliceOp};
452 auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
453 auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
454 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
456 rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
457 {oneStep, oneStep}, {retValue}, buildBody);
478 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
479 Value retValue, int64_t m, int64_t r,
480 bool leftTransform =
true,
bool rightTransform =
true) {
482 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
484 {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
485 {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
486 {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
490 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
492 {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
493 {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
494 {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
497 auto inputType = cast<ShapedType>(input.getType());
498 Type elementType = inputType.getElementType();
499 auto inputShape = inputType.getShape();
500 int64_t inputN = inputShape[0];
501 int64_t inputC = inputShape[3];
502 auto valueType = cast<ShapedType>(retValue.getType());
503 auto valueShape = valueType.getShape();
504 int64_t tileH = valueShape[2];
505 int64_t tileW = valueShape[3];
506 int64_t alphaH = leftTransform ? m + r - 1 : 1;
507 int64_t alphaW = rightTransform ? m + r - 1 : 1;
509 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
511 Value tileHIter = ivs[0];
512 Value tileWIter = ivs[1];
513 Value NIter = ivs[2];
514 Value CIter = ivs[3];
516 auto context = builder.getContext();
520 builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
522 builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
526 extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
527 widthOffset, alphaH, alphaW, 0,
530 TransformMapKeyTy key = {m, r};
533 Value matmulRetValue = extractInput;
534 Value zero = builder.create<arith::ConstantOp>(
535 loc, rewriter.getZeroAttr(elementType));
538 auto it = BTMatrices.find(key);
539 if (it == BTMatrices.end())
541 const TransformMatrix &BTMatrix = it->second;
543 retRows = BTMatrix.rows;
547 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
549 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
552 create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
554 auto matmulOp = builder.create<linalg::MatmulOp>(
555 loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
556 matmulRetValue = matmulOp.getResult(0);
559 if (rightTransform) {
561 auto it = BMatrices.find(key);
562 if (it == BMatrices.end())
564 const TransformMatrix &BMatrix = it->second;
566 retCols = BMatrix.cols;
570 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
572 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
574 create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
576 auto matmulOp = builder.create<linalg::MatmulOp>(
577 loc, matmulType, ValueRange{matmulRetValue,
B}, ValueRange{init});
578 matmulRetValue = matmulOp.getResult(0);
582 auto combinedVal = insert2DDataTo6D(
583 builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
584 CIter, retRows, retCols, 2, 3, 4, 5,
587 return {combinedVal};
590 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
591 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
592 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
593 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
594 auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
595 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
597 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
598 {tileHBound, tileWBound, nUpperBound, cUpperBound},
599 {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
621 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
622 Value transformedFilter, Value transformedInput,
623 Type outputElementType) {
625 auto filterType = cast<ShapedType>(transformedFilter.getType());
626 assert(filterType.hasStaticShape() &&
"only support static shapes.");
627 ArrayRef<int64_t> filterShape = filterType.getShape();
628 Type filterElementType = filterType.getElementType();
630 {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
632 SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
633 Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
634 loc, filterReassocType, transformedFilter, filterReassoc);
638 auto inputType = cast<ShapedType>(transformedInput.getType());
639 assert(inputType.hasStaticShape() &&
"only support static shapes.");
640 ArrayRef<int64_t> inputShape = inputType.getShape();
641 Type inputElementType = inputType.getElementType();
643 {inputShape[0] * inputShape[1],
644 inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
646 SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
647 Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
648 loc, inputReassocType, transformedInput, inputReassoc);
652 {inputShape[0] * inputShape[1],
653 inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
655 Value empty = rewriter
656 .create<tensor::EmptyOp>(loc, matmulType.getShape(),
659 Value zero = rewriter.create<arith::ConstantOp>(
660 loc, rewriter.getZeroAttr(outputElementType));
661 Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
663 auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
664 loc, matmulType, ValueRange({collapseInput, collapseFilter}),
669 SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
670 auto outputReassocType =
672 inputShape[3], inputShape[4], filterShape[3]},
674 auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
675 loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
696 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
697 Value output, int64_t m, int64_t r,
698 bool leftTransform =
true,
bool rightTransform =
true) {
700 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
702 {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
703 {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
704 {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
708 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
710 {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
711 {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
712 {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
715 auto valueType = cast<ShapedType>(value.getType());
716 Type elementType = valueType.getElementType();
717 auto valueShape = valueType.getShape();
718 int64_t valueH = valueShape[0];
719 int64_t valueW = valueShape[1];
720 int64_t valueN = valueShape[4];
721 int64_t valueF = valueShape[5];
722 int64_t alphaH = leftTransform ? m + r - 1 : 1;
723 int64_t alphaW = rightTransform ? m + r - 1 : 1;
725 if (valueH != alphaH && valueH != 1)
727 if (valueW != alphaW && valueW != 1)
730 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
732 auto context = builder.getContext();
733 Value tileHIter = ivs[0];
734 Value tileWIter = ivs[1];
735 Value NIter = ivs[2];
736 Value FIter = ivs[3];
740 extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
744 const TransformMapKeyTy key = {m, r};
745 const TransformMatrix &AMatrix = AMatrices.at(key);
746 const TransformMatrix &ATMatrix = ATMatrices.at(key);
747 int64_t
scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
748 (leftTransform ? ATMatrix.scalarFactor : 1);
749 int64_t retCols = rightTransform ? AMatrix.cols : 1;
750 int64_t retRows = leftTransform ? ATMatrix.rows : 1;
752 Value matmulRetValue = extractValue;
753 Value zero = builder.create<arith::ConstantOp>(
754 loc, rewriter.getZeroAttr(elementType));
759 builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
761 builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
764 extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
765 widthOffset, retRows, retCols,
771 Value init = outInitVal;
774 .create<tensor::EmptyOp>(loc, matmulType.getShape(),
777 init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
780 Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
782 auto matmulOp = builder.create<linalg::MatmulOp>(
783 loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
784 matmulRetValue = matmulOp.getResult(0);
787 if (rightTransform) {
790 Value init = outInitVal;
793 .create<tensor::EmptyOp>(loc, matmulType.getShape(),
796 init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
799 Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
801 auto matmulOp = builder.create<linalg::MatmulOp>(
802 loc, matmulType, ValueRange{matmulRetValue,
A}, ValueRange{init});
803 matmulRetValue = matmulOp.getResult(0);
808 Value scalarFactorValue = builder.create<arith::ConstantOp>(
811 auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
812 SmallVector<AffineMap> affineMaps = {
813 AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
817 .create<linalg::GenericOp>(
819 ValueRange{scalarFactorValue, matmulRetValue},
820 ValueRange{outInitVal}, affineMaps,
822 utils::IteratorType::parallel,
823 utils::IteratorType::parallel},
824 [&](OpBuilder &nestedBuilder, Location nestedLoc,
826 auto mulf = nestedBuilder.create<arith::MulFOp>(
827 nestedLoc, args[0], args[1]);
828 auto addf = nestedBuilder.create<arith::AddFOp>(
829 nestedLoc, mulf.getResult(), args[2]);
830 nestedBuilder.create<linalg::YieldOp>(nestedLoc,
838 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
839 heightOffset, widthOffset, retRows, retCols,
844 return {combinedVal};
847 int64_t tilwH = valueShape[2];
848 int64_t tileW = valueShape[3];
849 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
850 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
851 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
852 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
853 auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
854 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
856 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
857 {tileHBound, tileWBound, nUpperBound, fUpperBound},
858 {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
864 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
865 Value value, ArrayRef<int64_t> alignedShape) {
866 auto valueType = cast<ShapedType>(value.getType());
867 Type elementType = valueType.getElementType();
869 Value padValue = rewriter.create<arith::ConstantOp>(
870 loc, elementType, rewriter.getZeroAttr(elementType));
877 static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
879 RankedTensorType extractedType) {
880 OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
881 OpFoldResult oneIndex = rewriter.getIndexAttr(1);
882 SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
883 SmallVector<OpFoldResult, 4> strides(4, oneIndex);
885 ArrayRef<int64_t> extractedShape = extractedType.getShape();
886 SmallVector<OpFoldResult> sizes =
889 return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
890 offsets, sizes, strides);
896 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
901 static FailureOr<Operation *>
902 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
903 int64_t m, int64_t r) {
904 Value input = convOp.getInputs()[0];
905 Value filter = convOp.getInputs()[1];
906 Value output = convOp.getOutputs()[0];
907 auto inputType = cast<ShapedType>(input.getType());
908 auto filterType = cast<ShapedType>(filter.getType());
909 auto outputType = cast<ShapedType>(output.getType());
911 if (!inputType.hasStaticShape())
912 return rewriter.notifyMatchFailure(convOp,
913 "expected a static shape for the input");
915 if (!filterType.hasStaticShape())
916 return rewriter.notifyMatchFailure(
917 convOp,
"expected a static shape for the filter");
920 return rewriter.notifyMatchFailure(convOp,
921 "expected all ones for dilations");
924 return rewriter.notifyMatchFailure(convOp,
"expected all ones for strides");
926 ArrayRef<int64_t> filterShape = filterType.getShape();
927 int64_t filterF = filterShape[0];
928 int64_t filterH = filterShape[1];
929 int64_t filterW = filterShape[2];
930 int64_t filterC = filterShape[3];
931 ArrayRef<int64_t> inputShape = inputType.getShape();
932 int64_t inputN = inputShape[0];
933 int64_t inputH = inputShape[1];
934 int64_t inputW = inputShape[2];
935 int64_t inputC = inputShape[3];
936 ArrayRef<int64_t> outputShape = outputType.getShape();
937 int64_t outputN = outputShape[0];
938 int64_t outputH = outputShape[1];
939 int64_t outputW = outputShape[2];
940 int64_t outputF = outputShape[3];
943 bool isSupportedFilter =
false;
944 if (filterH == filterW && filterH == r)
945 isSupportedFilter =
true;
946 if (filterH == r && filterW == 1)
947 isSupportedFilter =
true;
948 if (filterH == 1 && filterW == r)
949 isSupportedFilter =
true;
951 if (!isSupportedFilter)
952 return rewriter.notifyMatchFailure(
953 convOp,
"only support filter (r x r), (r x 1) or (1 x r)");
957 F_2_3, F_4_3, F_2_5};
959 TransformMapKeyTy key = {m, r};
960 auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
963 if (it == validConfigs.end())
967 Location loc = convOp.getLoc();
970 bool leftTransform = filterH != 1;
972 bool rightTransform = filterW != 1;
973 int64_t heightM = leftTransform ? m : 1;
974 int64_t widthM = rightTransform ? m : 1;
975 int64_t heightR = leftTransform ? r : 1;
976 int64_t widthR = rightTransform ? r : 1;
979 Type filterElementType = filterType.getElementType();
980 int64_t alphaH = heightM + heightR - 1;
981 int64_t alphaW = widthM + widthR - 1;
982 int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
983 int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
986 Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
988 auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
989 loc, retType, filter, retValue, m, r);
995 Type inputElementType = inputType.getElementType();
996 int64_t alignedInputH = tileH * heightM + (heightR - 1);
997 int64_t alignedInputW = tileW * widthM + (widthR - 1);
998 if (alignedInputH != inputH || alignedInputW != inputW) {
999 input = padToAlignedTensor(rewriter, loc, input,
1000 {inputN, alignedInputH, alignedInputW, inputC});
1004 {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
1005 retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
1007 auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
1008 loc, retType, input, retValue, m, r);
1010 Type outputElementType = outputType.getElementType();
1011 Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
1012 transformedInput, outputElementType);
1018 int64_t alignedOutputH = tileH * heightM;
1019 int64_t alignedOutputW = tileW * widthM;
1020 bool isOutputUnaligned =
1021 ((alignedOutputH != outputH) || (alignedOutputW != outputW));
1022 if (isOutputUnaligned) {
1024 {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
1026 padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
1027 outputType = alignedOutputType;
1030 Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
1031 loc, outputType, matmulRet, output, m, r);
1035 if (isOutputUnaligned) {
1036 transformedOutput = extractFromAlignedTensor(
1037 rewriter, loc, transformedOutput,
1039 outputElementType));
1042 rewriter.replaceOp(convOp, transformedOutput);
1044 return transformedOutput.getDefiningOp();
// Decomposes a linalg.winograd_filter_transform op into the explicit loop
// nest built by filterTransform(), then replaces the op with the result
// and returns the producing operation.
// A filter H or W dimension of 1 disables the corresponding left/right
// transform (the (r x 1) / (1 x r) cases).
// NOTE(review): the body of the `if (!transformedFilter)` guard
// (presumably `return failure();`) and the closing brace are elided in
// this extraction.
1048 FailureOr<Operation *>
1049 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1050 linalg::WinogradFilterTransformOp op) {
1051 Location loc = op.getLoc();
1052 Value filter = op.getFilter();
1053 auto filterType = cast<ShapedType>(filter.getType());
1054 auto filterShape = filterType.getShape();
1055 int64_t filterH = filterShape[1];
1056 int64_t filterW = filterShape[2];
// Height/width of 1 means no transform along that axis.
1059 bool leftTransform = filterH != 1;
1061 bool rightTransform = filterW != 1;
1062 Value transformedFilter =
1063 filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
1064 op.getR(), leftTransform, rightTransform);
1065 if (!transformedFilter)
1068 rewriter.replaceOp(op, transformedFilter);
1070 return transformedFilter.getDefiningOp();
// Decomposes a linalg.winograd_input_transform op into the explicit loop
// nest built by inputTransform(), then replaces the op with the result
// and returns the producing operation.
// An input H or W dimension of 1 disables the corresponding left/right
// transform.
// NOTE(review): the body of the `if (!transformedInput)` guard
// (presumably `return failure();`) and the closing brace are elided in
// this extraction.
1074 FailureOr<Operation *>
1075 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1076 linalg::WinogradInputTransformOp op) {
1077 Location loc = op.getLoc();
1078 Value input = op.getInput();
1079 auto inputType = cast<ShapedType>(input.getType());
1080 auto inputShape = inputType.getShape();
1081 int64_t inputH = inputShape[1];
1082 int64_t inputW = inputShape[2];
// Height/width of 1 means no transform along that axis.
1085 bool leftTransform = inputH != 1;
1087 bool rightTransform = inputW != 1;
1088 Value transformedInput =
1089 inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
1090 op.getR(), leftTransform, rightTransform);
1091 if (!transformedInput)
1094 rewriter.replaceOp(op, transformedInput);
1096 return transformedInput.getDefiningOp();
// Decomposes a linalg.winograd_output_transform op into the explicit loop
// nest built by outputTransform(), then replaces the op with the result
// and returns the producing operation.
// A value H or W dimension (alpha dims, shape indices 0/1) of 1 disables
// the corresponding left/right transform.
// NOTE(review): the body of the `if (!transformedOutput)` guard
// (presumably `return failure();`) and the closing brace are elided in
// this extraction.
1100 FailureOr<Operation *>
1101 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1102 linalg::WinogradOutputTransformOp op) {
1103 Location loc = op.getLoc();
1104 Value value = op.getValue();
1105 auto valueType = cast<ShapedType>(value.getType());
1106 auto valueShape = valueType.getShape();
1107 int64_t valueH = valueShape[0];
1108 int64_t valueW = valueShape[1];
// Height/width of 1 means no transform along that axis.
1111 bool leftTransform = valueH != 1;
1113 bool rightTransform = valueW != 1;
1114 Value transformedOutput =
1115 outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
1116 op.getR(), leftTransform, rightTransform);
1117 if (!transformedOutput)
1120 rewriter.replaceOp(op, transformedOutput);
1122 return transformedOutput.getDefiningOp();
1126 class DecomposeWinogradFilterTransform final
1127 :
public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1131 LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1132 PatternRewriter &rewriter)
const override {
1133 return decomposeWinogradFilterTransformHelper(rewriter, op);
1138 class DecomposeWinogradInputTransform final
1139 :
public OpRewritePattern<linalg::WinogradInputTransformOp> {
1143 LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1144 PatternRewriter &rewriter)
const override {
1145 return decomposeWinogradInputTransformHelper(rewriter, op);
1150 class DecomposeWinogradOutputTransform final
1151 :
public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1155 LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1156 PatternRewriter &rewriter)
const override {
1157 return decomposeWinogradOutputTransformHelper(rewriter, op);
1162 class WinogradConv2DNhwcFhwc final
1163 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1167 : OpRewritePattern(context), m(m), r(r) {}
1169 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1170 PatternRewriter &rewriter)
const override {
1171 if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
1185 linalg::Conv2DNhwcFhwcOp op, int64_t m,
1187 return winogradConv2DHelper(rewriter, op, m, r);
1190 FailureOr<Operation *>
1192 linalg::WinogradFilterTransformOp op) {
1193 return decomposeWinogradFilterTransformHelper(rewriter, op);
1196 FailureOr<Operation *>
1198 linalg::WinogradInputTransformOp op) {
1199 return decomposeWinogradInputTransformHelper(rewriter, op);
1202 FailureOr<Operation *>
1204 linalg::WinogradOutputTransformOp op) {
1205 return decomposeWinogradOutputTransformHelper(rewriter, op);
1212 patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
1218 .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1219 DecomposeWinogradOutputTransform>(context);
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns)
Patterns to decompose Winograd operators.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold)
Create a tensor::PadOp that pads source to the size of the statically sized type whose static sizes a...
static bool hasAllOneValues(DenseIntElementsAttr attr)
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r)
Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
@ Type
An inlay hint that is for a type annotation.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...