22 #include "llvm/Support/MathExtras.h"
41 constexpr
float G_2x2_3x3[] = {
48 constexpr
float GT_2x2_3x3[] = {
54 constexpr
float BT_2x2_3x3[] = {
61 constexpr
float B_2x2_3x3[] = {
68 constexpr
float AT_2x2_3x3[] = {
73 constexpr
float A_2x2_3x3[] = {
/// G for F(4, 3): 6x3 left filter-transform matrix.
/// NOTE(review): array body was lost in extraction; reconstructed as the
/// transpose of the fully-visible GT_4x4_3x3 — confirm against upstream.
constexpr float G_4x4_3x3[] = {
      1,     0,     0,
  -1./3,  1./3, -1./3,
  -1./3, -1./3, -1./3,
  1./12, -1./6,  1./3,
  1./12,  1./6,  1./3,
      0,     0,     1
};
/// GT for F(4, 3): 3x6 right filter-transform matrix (transpose of G_4x4_3x3).
constexpr float GT_4x4_3x3[] = {
  1, -1./3, -1./3, 1./12, 1./12, 0,
  0,  1./3, -1./3, -1./6,  1./6, 0,
  0, -1./3, -1./3,  1./3,  1./3, 1
};
/// BT for F(4, 3): 6x6 left input-transform matrix.
constexpr float BT_4x4_3x3[] = {
  1./4,     0, -5./16,      0, 1./16,     0,
     0,  1./4,  -1./4, -1./16, 1./16,     0,
     0, -1./4,  -1./4,  1./16, 1./16,     0,
     0,  1./4,  -1./8,  -1./4,  1./8,     0,
     0, -1./4,  -1./8,   1./4,  1./8,     0,
     0,  1./4,      0, -5./16,     0, 1./16
};
/// B for F(4, 3): 6x6 right input-transform matrix.
/// NOTE(review): rows 0 and 5 were lost in extraction; reconstructed as the
/// transpose of BT_4x4_3x3 (the surviving rows 1-4 match that transpose
/// exactly) — confirm against upstream.
constexpr float B_4x4_3x3[] = {
    1./4,      0,     0,     0,     0,      0,
       0,   1./4, -1./4,  1./4, -1./4,   1./4,
  -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
       0, -1./16, 1./16, -1./4,  1./4, -5./16,
   1./16,  1./16, 1./16,  1./8,  1./8,      0,
       0,      0,     0,     0,     0,  1./16
};
/// AT for F(4, 3): 4x6 left output-transform matrix.
constexpr float AT_4x4_3x3[] = {
  1./8,  1./4, 1./4,  1./8, 1./8,    0,
     0, -1./4, 1./4, -1./4, 1./4,    0,
     0,  1./4, 1./4,  1./2, 1./2,    0,
     0, -1./4, 1./4,    -1,    1, 1./2
};
/// A for F(4, 3): 6x4 right output-transform matrix.
/// NOTE(review): rows 0, 4 and 5 were lost in extraction; reconstructed as
/// the transpose of AT_4x4_3x3 (the surviving rows match that transpose
/// exactly) — confirm against upstream.
constexpr float A_4x4_3x3[] = {
  1./8,     0,    0,     0,
  1./4, -1./4, 1./4, -1./4,
  1./4,  1./4, 1./4,  1./4,
  1./8, -1./4, 1./2,    -1,
  1./8,  1./4, 1./2,     1,
     0,     0,    0,  1./2
};
/// G for F(2, 5): 6x5 left filter-transform matrix.
/// NOTE(review): rows 0 and 5 were lost in extraction; reconstructed as the
/// transpose of the fully-visible GT_2x2_5x5 (surviving rows 1-4 match that
/// transpose exactly) — confirm against upstream.
constexpr float G_2x2_5x5[] = {
       1,     0,      0,     0,      0,
    1./6, -1./6,   1./6, -1./6,   1./6,
   -1./6, -1./6,  -1./6, -1./6,  -1./6,
  -4./15, 2./15, -1./15, 1./30, -1./60,
   1./60, 1./30,  1./15, 2./15,  4./15,
       0,     0,      0,     0,      1
};
/// GT for F(2, 5): 5x6 right filter-transform matrix (transpose of G_2x2_5x5).
constexpr float GT_2x2_5x5[] = {
  1,  1./6, -1./6, -4./15, 1./60, 0,
  0, -1./6, -1./6,  2./15, 1./30, 0,
  0,  1./6, -1./6, -1./15, 1./15, 0,
  0, -1./6, -1./6,  1./30, 2./15, 0,
  0,  1./6, -1./6, -1./60, 4./15, 1
};
/// BT for F(2, 5): 6x6 left input-transform matrix.
constexpr float BT_2x2_5x5[] = {
  1./8, 3./16,  -1./4, -3./16,   1./8,    0,
     0,  1./8,  1./16, -5./16,   1./8,    0,
     0, -1./8, -5./16, -1./16,   1./8,    0,
     0,  1./4,  -1./8,  -1./4,   1./8,    0,
     0, -1./8,  -1./4,   1./8,   1./4,    0,
     0,  1./8,  3./16,  -1./4, -3./16, 1./8
};
/// B for F(2, 5): 6x6 right input-transform matrix.
/// NOTE(review): rows 0 and 5 were lost in extraction; reconstructed as the
/// transpose of BT_2x2_5x5 (surviving rows 1-4 match that transpose exactly)
/// — confirm against upstream.
constexpr float B_2x2_5x5[] = {
    1./8,      0,      0,     0,     0,      0,
   3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
   -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
  -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
    1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
       0,      0,      0,     0,     0,   1./8
};
164 constexpr
float AT_2x2_5x5[] = {
166 0, -1, 1, -1, 2, 1./2
169 constexpr
float A_2x2_5x5[] = {
/// Key identifying a Winograd configuration F(m, r): output tile size m and
/// filter size r.
using TransformMapKeyTy = std::pair<int, int>;

/// The configurations supported by this transform: F(2,3), F(4,3), F(2,5).
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
/// A 2-D Winograd transform matrix stored as a flat row-major float array.
/// NOTE(review): the declaration was shredded by extraction; fields and the
/// constructor are reconstructed from the visible uses (3- and 4-argument
/// constructor calls and reads of .table/.rows/.cols/.scalarFactor).
struct TransformMatrix {
  TransformMatrix(const float *table, int64_t rows, int64_t cols,
                  int64_t scalarFactor = 1)
      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}

  const float *table;   // Row-major matrix entries; not owned.
  int64_t rows;
  int64_t cols;
  int64_t scalarFactor; // Constant factored out of the entries; 1 if none.
};
205 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
206 TransformMatrix transform, Type type) {
207 ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
209 return builder.create<arith::ConstantOp>(
212 SmallVector<int64_t>{transform.rows, transform.cols}, type),
217 Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
218 Value loopNorFIndex, Value loopCorFIndex,
219 Value heightOffset, Value widthOffset,
220 int64_t extractHeight, int64_t extractWidth,
221 int64_t loopNorFIdx, int64_t loopCorFIdx,
222 int64_t heightIdx, int64_t widthIdx) {
223 auto sourceType = cast<ShapedType>(source.getType());
224 Type elementType = sourceType.getElementType();
225 int64_t srcSize = sourceType.getRank();
227 auto oneIndex = builder.getIndexAttr(1);
228 SmallVector<OpFoldResult> offsets;
229 offsets.resize(srcSize);
230 offsets[loopNorFIdx] = loopNorFIndex;
231 offsets[loopCorFIdx] = loopCorFIndex;
232 offsets[heightIdx] = heightOffset;
233 offsets[widthIdx] = widthOffset;
234 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
235 sizes[heightIdx] = builder.getIndexAttr(extractHeight);
236 sizes[widthIdx] = builder.getIndexAttr(extractWidth);
237 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
239 auto extractFilterType =
241 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
242 loc, extractFilterType, source, offsets, sizes, strides);
244 return extractFilterOp;
248 Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
249 Value tileHIndex, Value tileWIndex,
250 Value loopNorFIndex, Value loopCorFIndex,
251 int64_t tileHIdx, int64_t tileWIdx,
252 int64_t loopNorFIdx, int64_t loopCorFIdx,
253 int64_t heightIdx, int64_t widthIdx) {
254 auto sourceType = cast<ShapedType>(source.getType());
255 Type elementType = sourceType.getElementType();
256 auto sourceShape = sourceType.getShape();
257 int64_t srcSize = sourceType.getRank();
258 int64_t height = sourceShape[heightIdx];
259 int64_t width = sourceShape[widthIdx];
261 auto zeroIndex = builder.getIndexAttr(0);
262 auto oneIndex = builder.getIndexAttr(1);
263 SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
264 offsets.resize(srcSize);
265 offsets[tileHIdx] = tileHIndex;
266 offsets[tileWIdx] = tileWIndex;
267 offsets[loopNorFIdx] = loopNorFIndex;
268 offsets[loopCorFIdx] = loopCorFIndex;
269 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
270 sizes[heightIdx] = builder.getIndexAttr(height);
271 sizes[widthIdx] = builder.getIndexAttr(width);
272 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
275 auto extractFilterOp = builder.create<tensor::ExtractSliceOp>(
276 loc, extractFilterType, source, offsets, sizes, strides);
278 return extractFilterOp;
283 Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
284 Value dest, Value loopNorFIndex, Value loopCorFIndex,
285 Value heightOffset, Value widthOffset, int64_t height,
286 int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
287 int64_t heightIdx, int64_t widthIdx) {
288 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
289 auto oneIndex = builder.getIndexAttr(1);
290 SmallVector<OpFoldResult> retOffsets;
291 retOffsets.resize(destSize);
292 retOffsets[loopNorFIdx] = loopNorFIndex;
293 retOffsets[loopCorFIdx] = loopCorFIndex;
294 retOffsets[heightIdx] = heightOffset;
295 retOffsets[widthIdx] = widthOffset;
296 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
297 retSizes[heightIdx] = builder.getIndexAttr(height);
298 retSizes[widthIdx] = builder.getIndexAttr(width);
299 SmallVector<OpFoldResult> strides(destSize, oneIndex);
301 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
302 loc, source, dest, retOffsets, retSizes, strides);
304 return insertSliceOp;
309 Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
310 Value dest, Value tileHIndex, Value tileWIndex,
311 Value loopNorFIndex, Value loopCorFIndex, int64_t height,
312 int64_t width, int64_t tileHIdx, int64_t tileWIdx,
313 int64_t loopNorFIdx, int64_t loopCorFIdx,
314 int64_t heightIdx, int64_t widthIdx) {
315 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
316 auto zeroIndex = builder.getIndexAttr(0);
317 auto oneIndex = builder.getIndexAttr(1);
318 SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
319 retOffsets.resize(destSize);
320 retOffsets[tileHIdx] = tileHIndex;
321 retOffsets[tileWIdx] = tileWIndex;
322 retOffsets[loopNorFIdx] = loopNorFIndex;
323 retOffsets[loopCorFIdx] = loopCorFIndex;
324 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
325 retSizes[heightIdx] = builder.getIndexAttr(height);
326 retSizes[widthIdx] = builder.getIndexAttr(width);
327 SmallVector<OpFoldResult> strides(destSize, oneIndex);
329 auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
330 loc, source, dest, retOffsets, retSizes, strides);
332 return insertSliceOp;
346 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
347 Value retValue, int64_t m, int64_t r,
348 bool leftTransform =
true,
bool rightTransform =
true) {
350 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
352 {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
353 {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
354 {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
358 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
360 {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
361 {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
362 {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
365 auto filterType = cast<ShapedType>(filter.getType());
366 Type elementType = filterType.getElementType();
367 auto filterShape = filterType.getShape();
368 int64_t filterF = filterShape[0];
369 int64_t filterH = filterShape[1];
370 int64_t filterW = filterShape[2];
371 int64_t filterC = filterShape[3];
373 if (filterH != r && filterH != 1)
375 if (filterW != r && filterW != 1)
378 Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
379 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
381 Value FIter = ivs[0];
382 Value CIter = ivs[1];
386 extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
387 zeroIdx, filterH, filterW, 0,
390 TransformMapKeyTy key = {m, r};
392 Value matmulRetValue = extractFilter;
393 Value zero = builder.create<arith::ConstantOp>(
394 loc, rewriter.getZeroAttr(elementType));
397 auto it = GMatrices.find(key);
398 if (it == GMatrices.end())
400 const TransformMatrix &GMatrix = it->second;
402 retRows = GMatrix.rows;
406 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
408 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
410 Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
412 auto matmulOp = builder.create<linalg::MatmulOp>(
413 loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
414 matmulRetValue = matmulOp.getResult(0);
417 if (rightTransform) {
419 auto it = GTMatrices.find(key);
420 if (it == GTMatrices.end())
422 const TransformMatrix >Matrix = it->second;
428 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
430 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
432 Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
434 auto matmulOp = builder.create<linalg::MatmulOp>(
435 loc, matmulType, ValueRange{matmulRetValue,
GT}, ValueRange{init});
436 matmulRetValue = matmulOp.getResult(0);
440 int64_t retHeight = leftTransform ? m + r - 1 : 1;
441 int64_t retWidth = rightTransform ? m + r - 1 : 1;
444 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
445 zeroIdx, zeroIdx, retHeight, retWidth,
449 return {insertSliceOp};
452 auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
453 auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
454 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
456 rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
457 {oneStep, oneStep}, {retValue}, buildBody);
478 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
479 Value retValue, int64_t m, int64_t r,
480 bool leftTransform =
true,
bool rightTransform =
true) {
482 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
484 {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
485 {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
486 {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
490 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
492 {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
493 {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
494 {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
497 auto inputType = cast<ShapedType>(input.getType());
498 Type elementType = inputType.getElementType();
499 auto inputShape = inputType.getShape();
500 int64_t inputN = inputShape[0];
501 int64_t inputC = inputShape[3];
502 auto valueType = cast<ShapedType>(retValue.getType());
503 auto valueShape = valueType.getShape();
504 int64_t tileH = valueShape[2];
505 int64_t tileW = valueShape[3];
506 int64_t alphaH = leftTransform ? m + r - 1 : 1;
507 int64_t alphaW = rightTransform ? m + r - 1 : 1;
509 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
511 Value tileHIter = ivs[0];
512 Value tileWIter = ivs[1];
513 Value NIter = ivs[2];
514 Value CIter = ivs[3];
516 auto context = builder.getContext();
520 builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
522 builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
526 extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
527 widthOffset, alphaH, alphaW, 0,
530 TransformMapKeyTy key = {m, r};
533 Value matmulRetValue = extractInput;
534 Value zero = builder.create<arith::ConstantOp>(
535 loc, rewriter.getZeroAttr(elementType));
538 auto it = BTMatrices.find(key);
539 if (it == BTMatrices.end())
541 const TransformMatrix &BTMatrix = it->second;
543 retRows = BTMatrix.rows;
547 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
549 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
552 create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
554 auto matmulOp = builder.create<linalg::MatmulOp>(
555 loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
556 matmulRetValue = matmulOp.getResult(0);
559 if (rightTransform) {
561 auto it = BMatrices.find(key);
562 if (it == BMatrices.end())
564 const TransformMatrix &BMatrix = it->second;
566 retCols = BMatrix.cols;
570 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
572 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
574 create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
576 auto matmulOp = builder.create<linalg::MatmulOp>(
577 loc, matmulType, ValueRange{matmulRetValue,
B}, ValueRange{init});
578 matmulRetValue = matmulOp.getResult(0);
582 auto combinedVal = insert2DDataTo6D(
583 builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
584 CIter, retRows, retCols, 2, 3, 4, 5,
587 return {combinedVal};
590 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
591 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH);
592 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
593 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
594 auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
595 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
597 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
598 {tileHBound, tileWBound, nUpperBound, cUpperBound},
599 {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
621 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
622 Value transformedFilter, Value transformedInput,
623 Type outputElementType) {
625 auto filterType = cast<ShapedType>(transformedFilter.getType());
626 assert(filterType.hasStaticShape() &&
"only support static shapes.");
627 ArrayRef<int64_t> filterShape = filterType.getShape();
628 Type filterElementType = filterType.getElementType();
630 {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
632 SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
633 Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>(
634 loc, filterReassocType, transformedFilter, filterReassoc);
638 auto inputType = cast<ShapedType>(transformedInput.getType());
639 assert(inputType.hasStaticShape() &&
"only support static shapes.");
640 ArrayRef<int64_t> inputShape = inputType.getShape();
641 Type inputElementType = inputType.getElementType();
643 {inputShape[0] * inputShape[1],
644 inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
646 SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
647 Value collapseInput = rewriter.create<tensor::CollapseShapeOp>(
648 loc, inputReassocType, transformedInput, inputReassoc);
652 {inputShape[0] * inputShape[1],
653 inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
655 Value empty = rewriter
656 .create<tensor::EmptyOp>(loc, matmulType.getShape(),
659 Value zero = rewriter.create<arith::ConstantOp>(
660 loc, rewriter.getZeroAttr(outputElementType));
661 Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
663 auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
664 loc, matmulType, ValueRange({collapseInput, collapseFilter}),
669 SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
670 auto outputReassocType =
672 inputShape[3], inputShape[4], filterShape[3]},
674 auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
675 loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
696 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
697 Value output, int64_t m, int64_t r,
698 bool leftTransform =
true,
bool rightTransform =
true) {
700 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
702 {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
703 {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
704 {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
708 static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
710 {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
711 {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
712 {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
715 auto valueType = cast<ShapedType>(value.getType());
716 Type elementType = valueType.getElementType();
717 auto valueShape = valueType.getShape();
718 int64_t valueH = valueShape[0];
719 int64_t valueW = valueShape[1];
720 int64_t valueN = valueShape[4];
721 int64_t valueF = valueShape[5];
722 int64_t alphaH = leftTransform ? m + r - 1 : 1;
723 int64_t alphaW = rightTransform ? m + r - 1 : 1;
725 if (valueH != alphaH && valueH != 1)
727 if (valueW != alphaW && valueW != 1)
730 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
732 Value tileHIter = ivs[0];
733 Value tileWIter = ivs[1];
734 Value NIter = ivs[2];
735 Value FIter = ivs[3];
739 extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
743 TransformMapKeyTy key = {m, r};
746 int64_t leftScalarFactor = 1;
747 int64_t rightScalarFactor = 1;
748 Value matmulRetValue = extractValue;
749 Value zero = builder.create<arith::ConstantOp>(
750 loc, rewriter.getZeroAttr(elementType));
753 auto it = ATMatrices.find(key);
754 if (it == ATMatrices.end())
756 const TransformMatrix &ATMatrix = it->second;
758 leftScalarFactor = ATMatrix.scalarFactor;
759 retRows = ATMatrix.rows;
763 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
765 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
767 Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
769 auto matmulOp = builder.create<linalg::MatmulOp>(
770 loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init});
771 matmulRetValue = matmulOp.getResult(0);
774 if (rightTransform) {
776 auto it = AMatrices.find(key);
777 if (it == AMatrices.end())
779 const TransformMatrix &AMatrix = it->second;
781 rightScalarFactor = AMatrix.scalarFactor;
784 retCols = AMatrix.cols;
787 .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
789 auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
791 Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
793 auto matmulOp = builder.create<linalg::MatmulOp>(
794 loc, matmulType, ValueRange{matmulRetValue,
A}, ValueRange{init});
795 matmulRetValue = matmulOp.getResult(0);
798 if (leftScalarFactor * rightScalarFactor != 1) {
802 FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor));
804 auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
807 auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
808 SmallVector<AffineMap> affineMaps = {
810 auto broadcastedScalar =
812 .create<linalg::GenericOp>(
813 loc, matmulType, ValueRange{
scalarFactor}, ValueRange{init},
816 utils::IteratorType::parallel,
817 utils::IteratorType::parallel},
818 [&](OpBuilder &nestedBuilder, Location nestedLoc,
820 nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
824 matmulRetValue = builder
825 .create<linalg::MulOp>(
827 ValueRange{broadcastedScalar, matmulRetValue},
832 auto context = builder.getContext();
836 builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
838 builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
842 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
843 heightOffset, widthOffset, retRows, retCols,
848 return {combinedVal};
851 int64_t tilwH = valueShape[2];
852 int64_t tileW = valueShape[3];
853 auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
854 auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH);
855 auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW);
856 auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN);
857 auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF);
858 auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
860 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
861 {tileHBound, tileWBound, nUpperBound, fUpperBound},
862 {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
868 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
869 Value value, ArrayRef<int64_t> alignedShape) {
870 auto valueType = cast<ShapedType>(value.getType());
871 Type elementType = valueType.getElementType();
873 Value padValue = rewriter.create<arith::ConstantOp>(
874 loc, elementType, rewriter.getZeroAttr(elementType));
881 static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
883 RankedTensorType extractedType) {
884 OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
885 OpFoldResult oneIndex = rewriter.getIndexAttr(1);
886 SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
887 SmallVector<OpFoldResult, 4> strides(4, oneIndex);
889 ArrayRef<int64_t> extractedShape = extractedType.getShape();
890 SmallVector<OpFoldResult> sizes =
893 return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value,
894 offsets, sizes, strides);
900 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
905 static FailureOr<Operation *>
906 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
907 int64_t m, int64_t r) {
908 Value input = convOp.getInputs()[0];
909 Value filter = convOp.getInputs()[1];
910 Value output = convOp.getOutputs()[0];
911 auto inputType = cast<ShapedType>(input.getType());
912 auto filterType = cast<ShapedType>(filter.getType());
913 auto outputType = cast<ShapedType>(output.getType());
915 if (!inputType.hasStaticShape())
916 return rewriter.notifyMatchFailure(convOp,
917 "expected a static shape for the input");
919 if (!filterType.hasStaticShape())
920 return rewriter.notifyMatchFailure(
921 convOp,
"expected a static shape for the filter");
924 return rewriter.notifyMatchFailure(convOp,
925 "expected all ones for dilations");
928 return rewriter.notifyMatchFailure(convOp,
"expected all ones for strides");
930 ArrayRef<int64_t> filterShape = filterType.getShape();
931 int64_t filterF = filterShape[0];
932 int64_t filterH = filterShape[1];
933 int64_t filterW = filterShape[2];
934 int64_t filterC = filterShape[3];
935 ArrayRef<int64_t> inputShape = inputType.getShape();
936 int64_t inputN = inputShape[0];
937 int64_t inputH = inputShape[1];
938 int64_t inputW = inputShape[2];
939 int64_t inputC = inputShape[3];
940 ArrayRef<int64_t> outputShape = outputType.getShape();
941 int64_t outputN = outputShape[0];
942 int64_t outputH = outputShape[1];
943 int64_t outputW = outputShape[2];
944 int64_t outputF = outputShape[3];
947 bool isSupportedFilter =
false;
948 if (filterH == filterW && filterH == r)
949 isSupportedFilter =
true;
950 if (filterH == r && filterW == 1)
951 isSupportedFilter =
true;
952 if (filterH == 1 && filterW == r)
953 isSupportedFilter =
true;
955 if (!isSupportedFilter)
956 return rewriter.notifyMatchFailure(
957 convOp,
"only support filter (r x r), (r x 1) or (1 x r)");
961 F_2_3, F_4_3, F_2_5};
963 TransformMapKeyTy key = {m, r};
964 auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
967 if (it == validConfigs.end())
971 Location loc = convOp.getLoc();
974 bool leftTransform = filterH != 1;
976 bool rightTransform = filterW != 1;
977 int64_t heightM = leftTransform ? m : 1;
978 int64_t widthM = rightTransform ? m : 1;
979 int64_t heightR = leftTransform ? r : 1;
980 int64_t widthR = rightTransform ? r : 1;
983 Type filterElementType = filterType.getElementType();
984 int64_t alphaH = heightM + heightR - 1;
985 int64_t alphaW = widthM + widthR - 1;
986 int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
987 int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
990 Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
992 auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
993 loc, retType, filter, retValue, m, r);
999 Type inputElementType = inputType.getElementType();
1000 int64_t alignedInputH = tileH * heightM + (heightR - 1);
1001 int64_t alignedInputW = tileW * widthM + (widthR - 1);
1002 if (alignedInputH != inputH || alignedInputW != inputW) {
1003 input = padToAlignedTensor(rewriter, loc, input,
1004 {inputN, alignedInputH, alignedInputW, inputC});
1008 {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
1009 retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(),
1011 auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
1012 loc, retType, input, retValue, m, r);
1014 Type outputElementType = outputType.getElementType();
1015 Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
1016 transformedInput, outputElementType);
1022 int64_t alignedOutputH = tileH * heightM;
1023 int64_t alignedOutputW = tileW * widthM;
1024 bool isOutputUnaligned =
1025 ((alignedOutputH != outputH) || (alignedOutputW != outputW));
1026 if (isOutputUnaligned) {
1028 {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
1030 padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
1031 outputType = alignedOutputType;
1034 Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
1035 loc, outputType, matmulRet, output, m, r);
1039 if (isOutputUnaligned) {
1040 transformedOutput = extractFromAlignedTensor(
1041 rewriter, loc, transformedOutput,
1043 outputElementType));
1046 rewriter.replaceOp(convOp, transformedOutput);
1048 return transformedOutput.getDefiningOp();
1052 FailureOr<Operation *>
1053 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1054 linalg::WinogradFilterTransformOp op) {
1055 Location loc = op.getLoc();
1056 Value filter = op.getFilter();
1057 auto filterType = cast<ShapedType>(filter.getType());
1058 auto filterShape = filterType.getShape();
1059 int64_t filterH = filterShape[1];
1060 int64_t filterW = filterShape[2];
1063 bool leftTransform = filterH != 1;
1065 bool rightTransform = filterW != 1;
1066 Value transformedFilter =
1067 filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
1068 op.getR(), leftTransform, rightTransform);
1069 if (!transformedFilter)
1072 rewriter.replaceOp(op, transformedFilter);
1074 return transformedFilter.getDefiningOp();
1078 FailureOr<Operation *>
1079 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1080 linalg::WinogradInputTransformOp op) {
1081 Location loc = op.getLoc();
1082 Value input = op.getInput();
1083 auto inputType = cast<ShapedType>(input.getType());
1084 auto inputShape = inputType.getShape();
1085 int64_t inputH = inputShape[1];
1086 int64_t inputW = inputShape[2];
1089 bool leftTransform = inputH != 1;
1091 bool rightTransform = inputW != 1;
1092 Value transformedInput =
1093 inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
1094 op.getR(), leftTransform, rightTransform);
1095 if (!transformedInput)
1098 rewriter.replaceOp(op, transformedInput);
1100 return transformedInput.getDefiningOp();
1104 FailureOr<Operation *>
1105 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1106 linalg::WinogradOutputTransformOp op) {
1107 Location loc = op.getLoc();
1108 Value value = op.getValue();
1109 auto valueType = cast<ShapedType>(value.getType());
1110 auto valueShape = valueType.getShape();
1111 int64_t valueH = valueShape[0];
1112 int64_t valueW = valueShape[1];
1115 bool leftTransform = valueH != 1;
1117 bool rightTransform = valueW != 1;
1118 Value transformedOutput =
1119 outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
1120 op.getR(), leftTransform, rightTransform);
1121 if (!transformedOutput)
1124 rewriter.replaceOp(op, transformedOutput);
1126 return transformedOutput.getDefiningOp();
1130 class DecomposeWinogradFilterTransform final
1131 :
public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1135 LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1136 PatternRewriter &rewriter)
const override {
1137 return decomposeWinogradFilterTransformHelper(rewriter, op);
1142 class DecomposeWinogradInputTransform final
1143 :
public OpRewritePattern<linalg::WinogradInputTransformOp> {
1147 LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1148 PatternRewriter &rewriter)
const override {
1149 return decomposeWinogradInputTransformHelper(rewriter, op);
1154 class DecomposeWinogradOutputTransform final
1155 :
public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1159 LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1160 PatternRewriter &rewriter)
const override {
1161 return decomposeWinogradOutputTransformHelper(rewriter, op);
1166 class WinogradConv2DNhwcFhwc final
1167 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1171 : OpRewritePattern(context), m(m), r(r) {}
1173 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1174 PatternRewriter &rewriter)
const override {
1175 if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
1189 linalg::Conv2DNhwcFhwcOp op, int64_t m,
1191 return winogradConv2DHelper(rewriter, op, m, r);
1194 FailureOr<Operation *>
1196 linalg::WinogradFilterTransformOp op) {
1197 return decomposeWinogradFilterTransformHelper(rewriter, op);
1200 FailureOr<Operation *>
1202 linalg::WinogradInputTransformOp op) {
1203 return decomposeWinogradInputTransformHelper(rewriter, op);
1206 FailureOr<Operation *>
1208 linalg::WinogradOutputTransformOp op) {
1209 return decomposeWinogradOutputTransformHelper(rewriter, op);
1216 patterns.
insert<WinogradConv2DNhwcFhwc>(context, m, r);
1222 .
insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1223 DecomposeWinogradOutputTransform>(context);
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns)
Patterns to decompose Winograd operators.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold)
Create a tensor::PadOp that pads source to the size of the statically sized type whose static sizes a...
static bool hasAllOneValues(DenseIntElementsAttr attr)
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r)
Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
@ Type
An inlay hint that for a type annotation.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...