21 #include "llvm/Support/MathExtras.h"
// Winograd transform matrices, stored as flat float arrays laid out row by
// row (e.g. GT_4x4_3x3 below shows 3 rows of 6 values, matching its use as
// TransformMatrix(GT_4x4_3x3, 3, 6)).
// Naming: G/GT are used by filterTransform, BT/B by inputTransform and AT/A by
// outputTransform. The suffix selects the variant: *_2x2_3x3 -> F_2_3,
// *_4x4_3x3 -> F_4_3, *_2x2_5x5 -> F_2_5.
// NOTE(review): this listing omits table rows. For example, G_2x2_5x5 shows
// 4 rows here but is used as a 6x5 matrix. Verify against the upstream source
// before editing any values.
// --- F_2_3 tables ---
49 constexpr
float G_2x2_3x3[] = {
56 constexpr
float GT_2x2_3x3[] = {
62 constexpr
float BT_2x2_3x3[] = {
69 constexpr
float B_2x2_3x3[] = {
76 constexpr
float AT_2x2_3x3[] = {
81 constexpr
float A_2x2_3x3[] = {
// --- F_4_3 tables ---
88 constexpr
float G_4x4_3x3[] = {
97 constexpr
float GT_4x4_3x3[] = {
98 1, -1./3, -1./3, 1./12, 1./12, 0,
99 0, 1./3, -1./3, -1./6, 1./6, 0,
100 0, -1./3, -1./3, 1./3, 1./3, 1
103 constexpr
float BT_4x4_3x3[] = {
104 1./4, 0, -5./16, 0, 1./16, 0,
105 0, 1./4, -1./4, -1./16, 1./16, 0,
106 0, -1./4, -1./4, 1./16, 1./16, 0,
107 0, 1./4, -1./8, -1./4, 1./8, 0,
108 0, -1./4, -1./8, 1./4, 1./8, 0,
109 0, 1./4, 0, -5./16, 0, 1./16
112 constexpr
float B_4x4_3x3[] = {
114 0, 1./4, -1./4, 1./4, -1./4, 1./4,
115 -5./16, -1./4, -1./4, -1./8, -1./8, 0,
116 0, -1./16, 1./16, -1./4, 1./4, -5./16,
117 1./16, 1./16, 1./16, 1./8, 1./8, 0,
// AT/A for F_4_3 are constructed with a 4th TransformMatrix argument of 32
// in outputTransform. That value is multiplied into the scalar applied to
// each output tile.
121 constexpr
float AT_4x4_3x3[] = {
122 1./8, 1./4, 1./4, 1./8, 1./8, 0,
123 0, -1./4, 1./4, -1./4, 1./4, 0,
124 0, 1./4, 1./4, 1./2, 1./2, 0,
125 0, -1./4, 1./4, -1, 1, 1./2
128 constexpr
float A_4x4_3x3[] = {
130 1./4, -1./4, 1./4, -1./4,
131 1./4, 1./4, 1./4, 1./4,
132 1./8, -1./4, 1./2, -1,
// --- F_2_5 tables ---
137 constexpr
float G_2x2_5x5[] = {
139 1./6, -1./6, 1./6, -1./6, 1./6,
140 -1./6, -1./6, -1./6, -1./6, -1./6,
141 -4./15, 2./15, -1./15, 1./30, -1./60,
142 1./60, 1./30, 1./15, 2./15, 4./15,
146 constexpr
float GT_2x2_5x5[] = {
147 1, 1./6, -1./6, -4./15, 1./60, 0,
148 0, -1./6, -1./6, 2./15, 1./30, 0,
149 0, 1./6, -1./6, -1./15, 1./15, 0,
150 0, -1./6, -1./6, 1./30, 2./15, 0,
151 0, 1./6, -1./6, -1./60, 4./15, 1
154 constexpr
float BT_2x2_5x5[] = {
155 1./8, 3./16, -1./4, -3./16, 1./8, 0,
156 0, 1./8, 1./16, -5./16, 1./8, 0,
157 0, -1./8, -5./16, -1./16, 1./8, 0,
158 0, 1./4, -1./8, -1./4, 1./8, 0,
159 0, -1./8, -1./4, 1./8, 1./4, 0,
160 0, 1./8, 3./16, -1./4, -3./16, 1./8
163 constexpr
float B_2x2_5x5[] = {
165 3./16, 1./8, -1./8, 1./4, -1./8, 1./8,
166 -1./4, 1./16, -5./16, -1./8, -1./4, 3./16,
167 -3./16, -5./16, -1./16, -1./4, 1./8, -1./4,
168 1./8, 1./8, 1./8, 1./8, 1./4, -3./16,
// AT/A for F_2_5 are constructed with a 4th TransformMatrix argument of 16
// in outputTransform.
172 constexpr
float AT_2x2_5x5[] = {
174 0, -1, 1, -1, 2, 1./2
177 constexpr
float A_2x2_5x5[] = {
// Bundles a flat transform table with its logical dimensions.
// Visible uses read `table`, `rows` and `cols` (create2DTransformMatrix) and
// `scalarFactor` (outputTransform).
// Constructions appear with 3 arguments (e.g. TransformMatrix(G_2x2_3x3, 4, 3))
// and with 4 (e.g. TransformMatrix(AT_4x4_3x3, 4, 6, 32)). The 4th argument is
// therefore defaulted and is presumably the scalar factor.
// NOTE(review): the rest of the struct definition is missing from this listing.
188 struct TransformMatrix {
189 TransformMatrix(ArrayRef<float>
table, int64_t
rows, int64_t
cols,
// Materializes `transform` as an arith.constant with shape
// {transform.rows, transform.cols}. Each table entry is converted to `type`
// through builder.getFloatAttr.
// Asserts that the table holds exactly rows * cols entries and that `type`
// is a float type.
200 Value create2DTransformMatrix(OpBuilder &builder, Location loc,
201 TransformMatrix transform, Type type) {
202 assert(transform.table.size() ==
203 static_cast<size_t>(transform.rows * transform.cols));
204 assert(type.isFloat() &&
"Only floats are supported by Winograd");
205 ArrayRef<float> constVec(transform.table.data(),
206 transform.rows * transform.cols);
208 llvm::map_to_vector<>(constVec, [&](
const float v) -> Attribute {
209 return builder.getFloatAttr(type, v);
211 SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
212 return arith::ConstantOp::create(
// Extracts an extractHeight x extractWidth slice from `source` with
// tensor.extract_slice. All strides are 1.
// - Dims loopNorFIdx / loopCorFIdx have size 1, at offsets loopNorFIndex /
//   loopCorFIndex.
// - Dims heightIdx / widthIdx start at heightOffset / widthOffset.
// NOTE(review): `offsets` is resized to the source rank with
// default-constructed entries, and only the four indexed positions are
// filled. This assumes `source` has rank 4 — confirm for every caller.
// The result type is built on a line not shown in this listing.
219 Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source,
220 Value loopNorFIndex, Value loopCorFIndex,
221 Value heightOffset, Value widthOffset,
222 int64_t extractHeight, int64_t extractWidth,
223 int64_t loopNorFIdx, int64_t loopCorFIdx,
224 int64_t heightIdx, int64_t widthIdx) {
225 auto sourceType = cast<ShapedType>(source.getType());
226 Type elementType = sourceType.getElementType();
227 int64_t srcSize = sourceType.getRank();
229 auto oneIndex = builder.getIndexAttr(1);
230 SmallVector<OpFoldResult> offsets;
231 offsets.resize(srcSize);
232 offsets[loopNorFIdx] = loopNorFIndex;
233 offsets[loopCorFIdx] = loopCorFIndex;
234 offsets[heightIdx] = heightOffset;
235 offsets[widthIdx] = widthOffset;
236 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
237 sizes[heightIdx] = builder.getIndexAttr(extractHeight);
238 sizes[widthIdx] = builder.getIndexAttr(extractWidth);
239 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
241 auto extractFilterType =
243 auto extractFilterOp = tensor::ExtractSliceOp::create(
244 builder, loc, extractFilterType, source, offsets, sizes, strides);
246 return extractFilterOp;
// Extracts the full height x width plane of `source` with
// tensor.extract_slice. Height and width are read from the source shape at
// heightIdx / widthIdx.
// - Offsets are zero except at tileHIdx, tileWIdx, loopNorFIdx and
//   loopCorFIdx, which take the given index values.
// - All other sizes, and all strides, are 1.
// NOTE(review): calling `offsets.resize(srcSize)` right after constructing
// `offsets` with srcSize elements is a no-op and could be dropped.
250 Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source,
251 Value tileHIndex, Value tileWIndex,
252 Value loopNorFIndex, Value loopCorFIndex,
253 int64_t tileHIdx, int64_t tileWIdx,
254 int64_t loopNorFIdx, int64_t loopCorFIdx,
255 int64_t heightIdx, int64_t widthIdx) {
256 auto sourceType = cast<ShapedType>(source.getType());
257 Type elementType = sourceType.getElementType();
258 auto sourceShape = sourceType.getShape();
259 int64_t srcSize = sourceType.getRank();
260 int64_t height = sourceShape[heightIdx];
261 int64_t width = sourceShape[widthIdx];
263 auto zeroIndex = builder.getIndexAttr(0);
264 auto oneIndex = builder.getIndexAttr(1);
265 SmallVector<OpFoldResult> offsets(srcSize, zeroIndex);
266 offsets.resize(srcSize);
267 offsets[tileHIdx] = tileHIndex;
268 offsets[tileWIdx] = tileWIndex;
269 offsets[loopNorFIdx] = loopNorFIndex;
270 offsets[loopCorFIdx] = loopCorFIndex;
271 SmallVector<OpFoldResult> sizes(srcSize, oneIndex);
272 sizes[heightIdx] = builder.getIndexAttr(height);
273 sizes[widthIdx] = builder.getIndexAttr(width);
274 SmallVector<OpFoldResult> strides(srcSize, oneIndex);
277 auto extractFilterOp = tensor::ExtractSliceOp::create(
278 builder, loc, extractFilterType, source, offsets, sizes, strides);
280 return extractFilterOp;
// Inserts `source` into `dest` with tensor.insert_slice and returns the
// updated tensor. All strides are 1.
// - Sizes are 1 except `height` / `width` at heightIdx / widthIdx.
// - Offsets are set only at loopNorFIdx, loopCorFIdx, heightIdx and widthIdx.
// NOTE(review): retOffsets is resized with default-constructed entries, so
// this assumes `dest` has rank 4 — confirm for every caller.
285 Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source,
286 Value dest, Value loopNorFIndex, Value loopCorFIndex,
287 Value heightOffset, Value widthOffset, int64_t height,
288 int64_t width, int64_t loopNorFIdx, int64_t loopCorFIdx,
289 int64_t heightIdx, int64_t widthIdx) {
290 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
291 auto oneIndex = builder.getIndexAttr(1);
292 SmallVector<OpFoldResult> retOffsets;
293 retOffsets.resize(destSize);
294 retOffsets[loopNorFIdx] = loopNorFIndex;
295 retOffsets[loopCorFIdx] = loopCorFIndex;
296 retOffsets[heightIdx] = heightOffset;
297 retOffsets[widthIdx] = widthOffset;
298 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
299 retSizes[heightIdx] = builder.getIndexAttr(height);
300 retSizes[widthIdx] = builder.getIndexAttr(width);
301 SmallVector<OpFoldResult> strides(destSize, oneIndex);
303 auto insertSliceOp = tensor::InsertSliceOp::create(
304 builder, loc, source, dest, retOffsets, retSizes, strides);
306 return insertSliceOp;
// Inserts a height x width `source` into `dest` with tensor.insert_slice and
// returns the updated tensor. All strides are 1.
// - Offsets are zero except at tileHIdx, tileWIdx, loopNorFIdx and
//   loopCorFIdx.
// - Sizes are 1 except at heightIdx / widthIdx.
// NOTE(review): calling `retOffsets.resize(destSize)` right after
// constructing retOffsets with destSize elements is a no-op.
311 Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
312 Value dest, Value tileHIndex, Value tileWIndex,
313 Value loopNorFIndex, Value loopCorFIndex, int64_t height,
314 int64_t width, int64_t tileHIdx, int64_t tileWIdx,
315 int64_t loopNorFIdx, int64_t loopCorFIdx,
316 int64_t heightIdx, int64_t widthIdx) {
317 int64_t destSize = cast<ShapedType>(dest.getType()).getRank();
318 auto zeroIndex = builder.getIndexAttr(0);
319 auto oneIndex = builder.getIndexAttr(1);
320 SmallVector<OpFoldResult> retOffsets(destSize, zeroIndex);
321 retOffsets.resize(destSize);
322 retOffsets[tileHIdx] = tileHIndex;
323 retOffsets[tileWIdx] = tileWIndex;
324 retOffsets[loopNorFIdx] = loopNorFIndex;
325 retOffsets[loopCorFIdx] = loopCorFIndex;
326 SmallVector<OpFoldResult> retSizes(destSize, oneIndex);
327 retSizes[heightIdx] = builder.getIndexAttr(height);
328 retSizes[widthIdx] = builder.getIndexAttr(width);
329 SmallVector<OpFoldResult> strides(destSize, oneIndex);
331 auto insertSliceOp = tensor::InsertSliceOp::create(
332 builder, loc, source, dest, retOffsets, retSizes, strides);
334 return insertSliceOp;
// Winograd filter transform. For every (F, C) pair of `filter` (read as
// F x H x W x C from filterShape[0..3]) it extracts the filterH x filterW
// kernel, computes G * g (when leftTransform) and then (.) * GT (when
// rightTransform), and inserts the result into `retValue`. The work runs in a
// loop nest over F and C.
// G / GT are selected by `fmr` from the tables below. A disabled side keeps
// that output dimension at 1 (see retHeight / retWidth).
// NOTE(review): the early-exit statements after the filterH/filterW checks are
// not shown. Callers test the result for null, so these paths presumably
// return an empty Value.
348 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
349 Value retValue, WinogradConv2DFmr fmr,
350 bool leftTransform =
true,
bool rightTransform =
true) {
352 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
354 {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
355 {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
356 {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
360 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
362 {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
363 {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
364 {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
367 auto filterType = cast<ShapedType>(filter.getType());
368 Type elementType = filterType.getElementType();
369 auto filterShape = filterType.getShape();
370 int64_t filterF = filterShape[0];
371 int64_t filterH = filterShape[1];
372 int64_t filterW = filterShape[2];
373 int64_t filterC = filterShape[3];
// Only kernels whose H and W are each either r or 1 are accepted.
377 if (filterH != r && filterH != 1)
379 if (filterW != r && filterW != 1)
// Loop body: ivs[0] iterates F, ivs[1] iterates C.
383 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
385 Value FIter = ivs[0];
386 Value CIter = ivs[1];
390 extract2DDataFrom4D(builder, loc, filter, FIter, CIter, zeroIdx,
391 zeroIdx, filterH, filterW, 0,
395 Value matmulRetValue = extractFilter;
396 Value zero = arith::ConstantOp::create(builder, loc,
397 rewriter.getZeroAttr(elementType));
400 auto it = GMatrices.find(fmr);
401 if (it == GMatrices.end())
403 const TransformMatrix &GMatrix = it->second;
405 retRows = GMatrix.rows;
407 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
411 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
413 Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
// Left multiply: G * filterSlice.
415 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
416 ValueRange{G, extractFilter},
418 matmulRetValue = matmulOp.getResult(0);
421 if (rightTransform) {
423 auto it = GTMatrices.find(fmr);
424 if (it == GTMatrices.end())
// NOTE(review): `>Matrix` below is an HTML-entity decoding artifact of
// `&GTMatrix` (the name GTMatrix is used a few lines later). The upstream
// declaration binds a `const TransformMatrix &GTMatrix`.
426 const TransformMatrix >Matrix = it->second;
430 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
434 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
436 Value
GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
// Right multiply: (previous result) * GT.
438 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
439 ValueRange{matmulRetValue,
GT},
441 matmulRetValue = matmulOp.getResult(0);
// Transformed tile extent: m + r - 1 on transformed sides, 1 otherwise.
445 int64_t retHeight = leftTransform ? m + r - 1 : 1;
446 int64_t retWidth = rightTransform ? m + r - 1 : 1;
449 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], FIter, CIter,
450 zeroIdx, zeroIdx, retHeight, retWidth,
454 return {insertSliceOp};
// Loop nest over [0, F) x [0, C) with unit steps, carrying retValue.
461 rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound},
462 {oneStep, oneStep}, {retValue}, buildBody);
// Winograd input transform. For each (tileH, tileW, N, C) iteration it:
// - extracts an alphaH x alphaW window of `input` at offsets derived from the
//   tile indices by affine maps;
// - computes BT * d (when leftTransform) and then (.) * B (when
//   rightTransform);
// - inserts the result into `retValue` with tileH / tileW / N / C at
//   positions 2, 3, 4, 5.
// BT / B are selected by `fmr`. tileH / tileW are read from retValue's shape
// at dims 2 and 3.
483 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
484 Value retValue, WinogradConv2DFmr fmr,
485 bool leftTransform =
true,
bool rightTransform =
true) {
487 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
489 {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
490 {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
491 {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
495 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
497 {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
498 {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
499 {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
504 auto inputType = cast<ShapedType>(input.getType());
505 Type elementType = inputType.getElementType();
506 auto inputShape = inputType.getShape();
507 int64_t inputN = inputShape[0];
508 int64_t inputC = inputShape[3];
509 auto valueType = cast<ShapedType>(retValue.getType());
510 auto valueShape = valueType.getShape();
511 int64_t tileH = valueShape[2];
512 int64_t tileW = valueShape[3];
// Window extent: m + r - 1 on transformed sides, 1 otherwise.
513 int64_t alphaH = leftTransform ? m + r - 1 : 1;
514 int64_t alphaW = rightTransform ? m + r - 1 : 1;
516 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
518 Value tileHIter = ivs[0];
519 Value tileWIter = ivs[1];
520 Value NIter = ivs[2];
521 Value CIter = ivs[3];
523 auto *context = builder.getContext();
// Untransformed sides use the identity map, so the tile index is the offset.
// The definition of `affineMap` is on lines not shown here.
525 auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
528 Value heightOffset = affine::AffineApplyOp::create(
529 builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
530 Value widthOffset = affine::AffineApplyOp::create(
531 builder, loc, rightTransform ? affineMap : identityAffineMap,
536 extract2DDataFrom4D(builder, loc, input, NIter, CIter, heightOffset,
537 widthOffset, alphaH, alphaW, 0,
542 Value matmulRetValue = extractInput;
543 Value zero = arith::ConstantOp::create(builder, loc,
544 rewriter.getZeroAttr(elementType));
547 auto it = BTMatrices.find(fmr);
548 if (it == BTMatrices.end())
550 const TransformMatrix &BTMatrix = it->second;
552 retRows = BTMatrix.rows;
554 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
558 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
560 Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
// Left multiply: BT * inputWindow.
562 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
563 ValueRange{BT, matmulRetValue},
565 matmulRetValue = matmulOp.getResult(0);
568 if (rightTransform) {
570 auto it = BMatrices.find(fmr);
571 if (it == BMatrices.end())
573 const TransformMatrix &BMatrix = it->second;
575 retCols = BMatrix.cols;
577 auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(),
581 linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
582 Value
B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
// Right multiply: (previous result) * B.
584 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
585 ValueRange{matmulRetValue,
B},
587 matmulRetValue = matmulOp.getResult(0);
// Dims 2, 3, 4, 5 of the destination index tileH, tileW, N, C.
591 auto combinedVal = insert2DDataTo6D(
592 builder, loc, matmulRetValue, args[0], tileHIter, tileWIter, NIter,
593 CIter, retRows, retCols, 2, 3, 4, 5,
596 return {combinedVal};
// Loop nest over tileH x tileW x N x C with unit steps, carrying retValue.
606 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
607 {tileHBound, tileWBound, nUpperBound, cUpperBound},
608 {oneStep, oneStep, oneStep, oneStep}, {retValue}, buildBody);
// Computes the element-wise (per transformed position) products of the
// transformed input and filter with one linalg.batch_matmul.
// - The filter is collapsed with reassociation {0,1},{2},{3} into
//   [d0*d1, d2, d3].
// - The input is collapsed with {0,1},{2,3,4},{5} into
//   [d0*d1, d2*d3*d4, d5].
// - The batch_matmul (input x filter) writes into a zero-filled
//   [in0*in1, in2*in3*in4, filter3] tensor of `outputElementType`.
// - That result is expanded back with {0,1},{2,3,4},{5}.
// Both operands must have static shapes (asserted).
630 static Value matrixMultiply(RewriterBase &rewriter, Location loc,
631 Value transformedFilter, Value transformedInput,
632 Type outputElementType) {
634 auto filterType = cast<ShapedType>(transformedFilter.getType());
635 assert(filterType.hasStaticShape() &&
"only support static shapes.");
636 ArrayRef<int64_t> filterShape = filterType.getShape();
637 Type filterElementType = filterType.getElementType();
639 {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]},
641 SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}};
642 Value collapseFilter = tensor::CollapseShapeOp::create(
643 rewriter, loc, filterReassocType, transformedFilter, filterReassoc);
647 auto inputType = cast<ShapedType>(transformedInput.getType());
648 assert(inputType.hasStaticShape() &&
"only support static shapes.");
649 ArrayRef<int64_t> inputShape = inputType.getShape();
650 Type inputElementType = inputType.getElementType();
652 {inputShape[0] * inputShape[1],
653 inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]},
655 SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}};
656 Value collapseInput = tensor::CollapseShapeOp::create(
657 rewriter, loc, inputReassocType, transformedInput, inputReassoc);
661 {inputShape[0] * inputShape[1],
662 inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
664 Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(),
// The accumulator starts at zero of the output element type.
667 Value zero = arith::ConstantOp::create(
668 rewriter, loc, rewriter.getZeroAttr(outputElementType));
669 Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0);
671 auto matmulOp = linalg::BatchMatmulOp::create(
672 rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}),
// Restore the 6-D layout. The trailing dims are inputShape[3],
// inputShape[4] and filterShape[3].
677 SmallVector<ReassociationIndices> outputReassoc = {{0, 1}, {2, 3, 4}, {5}};
678 auto outputReassocType =
680 inputShape[3], inputShape[4], filterShape[3]},
682 auto expandOutput = tensor::ExpandShapeOp::create(
683 rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc);
// Winograd output transform. `value` is read as [H, W, tileH, tileW, N, F]
// (dims 0, 1, 4, 5 here; dims 2 and 3 near the end). For each
// (tileH, tileW, N, F) iteration it:
// - extracts the H x W plane;
// - computes AT * y (when leftTransform) and then (.) * A (when
//   rightTransform);
// - combines that result with the matching retRows x retCols window of
//   `output` (outInitVal) through a linalg.generic computing
//   args[2] + mulf(...), using scalarFactor and the matmul result as inputs;
// - inserts the window back.
// NOTE(review): this function looks matrices up with `.at(fmr)`, unlike
// filterTransform/inputTransform, which use find() plus an explicit failure
// path.
704 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
705 Value output, WinogradConv2DFmr fmr,
706 bool leftTransform =
true,
bool rightTransform =
true) {
708 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
710 {WinogradConv2DFmr::F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
711 {WinogradConv2DFmr::F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
712 {WinogradConv2DFmr::F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
716 static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
718 {WinogradConv2DFmr::F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
719 {WinogradConv2DFmr::F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
720 {WinogradConv2DFmr::F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
725 auto valueType = cast<ShapedType>(value.getType());
726 Type elementType = valueType.getElementType();
727 auto valueShape = valueType.getShape();
728 int64_t valueH = valueShape[0];
729 int64_t valueW = valueShape[1];
730 int64_t valueN = valueShape[4];
731 int64_t valueF = valueShape[5];
732 int64_t alphaH = leftTransform ? m + r - 1 : 1;
733 int64_t alphaW = rightTransform ? m + r - 1 : 1;
// Each input plane dimension must be alpha on transformed sides, or 1.
735 if (valueH != alphaH && valueH != 1)
737 if (valueW != alphaW && valueW != 1)
740 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
742 auto *context = builder.getContext();
743 Value tileHIter = ivs[0];
744 Value tileWIter = ivs[1];
745 Value NIter = ivs[2];
746 Value FIter = ivs[3];
750 extract2DDataFrom6D(builder, loc, value, tileHIter, tileWIter, NIter,
754 const TransformMatrix &AMatrix = AMatrices.at(fmr);
755 const TransformMatrix &ATMatrix = ATMatrices.at(fmr);
// Only the scalar factors of the sides that are actually transformed count.
756 int64_t
scalarFactor = (rightTransform ? AMatrix.scalarFactor : 1) *
757 (leftTransform ? ATMatrix.scalarFactor : 1);
758 int64_t retCols = rightTransform ? AMatrix.cols : 1;
759 int64_t retRows = leftTransform ? ATMatrix.rows : 1;
761 Value matmulRetValue = extractValue;
762 Value zero = arith::ConstantOp::create(builder, loc,
763 rewriter.getZeroAttr(elementType));
765 auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
768 Value heightOffset = affine::AffineApplyOp::create(
769 builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
770 Value widthOffset = affine::AffineApplyOp::create(
771 builder, loc, rightTransform ? affineMap : identityAffineMap,
// Presumably binds outInitVal: the current retRows x retCols window of the
// output (the assignment target is on a line not shown).
775 extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
776 widthOffset, retRows, retCols,
782 Value init = outInitVal;
784 auto empty = tensor::EmptyOp::create(builder, loc,
785 matmulType.getShape(), elementType)
787 init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
790 Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
// Left multiply: AT * valuePlane.
792 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
793 ValueRange{AT, matmulRetValue},
795 matmulRetValue = matmulOp.getResult(0);
798 if (rightTransform) {
801 Value init = outInitVal;
803 auto empty = tensor::EmptyOp::create(builder, loc,
804 matmulType.getShape(), elementType)
806 init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
809 Value
A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
// Right multiply: (previous result) * A.
811 auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
812 ValueRange{matmulRetValue,
A},
814 matmulRetValue = matmulOp.getResult(0);
// The scalar is broadcast via a zero-result map; the matmul result and
// outInitVal use identity maps. The body adds args[2] (the output element) to
// a mulf whose operands are on a line not shown, presumably
// scalar * matmul element.
819 Value scalarFactorValue = arith::ConstantOp::create(
822 auto identityAffineMap = rewriter.getMultiDimIdentityMap(2);
823 SmallVector<AffineMap> affineMaps = {
824 AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap};
827 linalg::GenericOp::create(
828 rewriter, loc, matmulType,
829 ValueRange{scalarFactorValue, matmulRetValue},
830 ValueRange{outInitVal}, affineMaps,
832 utils::IteratorType::parallel, utils::IteratorType::parallel},
833 [&](OpBuilder &nestedBuilder, Location nestedLoc,
835 auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc,
837 auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc,
838 mulf.getResult(), args[2]);
839 linalg::YieldOp::create(nestedBuilder, nestedLoc,
847 insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
848 heightOffset, widthOffset, retRows, retCols,
853 return {combinedVal};
// NOTE(review): `tilwH` is a typo for `tileH`. Rename it together with its
// use site (on a line not shown here).
856 int64_t tilwH = valueShape[2];
857 int64_t tileW = valueShape[3];
// Loop nest over tileH x tileW x N x F with unit steps, carrying `output`.
865 rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx},
866 {tileHBound, tileWBound, nUpperBound, fUpperBound},
867 {oneStep, oneStep, oneStep, oneStep}, {output}, buildBody);
// Pads `value` up to `alignedShape`, using a zero constant of value's
// element type as the pad value.
// NOTE(review): the pad op construction is on lines not shown. It presumably
// high-pads via linalg::makeComposedPadHighOp — confirm.
873 static Value padToAlignedTensor(RewriterBase &rewriter, Location loc,
874 Value value, ArrayRef<int64_t> alignedShape) {
875 auto valueType = cast<ShapedType>(value.getType());
876 Type elementType = valueType.getElementType();
878 Value padValue = arith::ConstantOp::create(rewriter, loc, elementType,
879 rewriter.getZeroAttr(elementType));
// Slices an `extractedType`-shaped region out of a padded rank-4 `value`.
// All offsets are 0 and all strides are 1. Sizes are derived from
// extractedType's shape (the conversion line is not shown here).
886 static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc,
888 RankedTensorType extractedType) {
889 OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
890 OpFoldResult oneIndex = rewriter.getIndexAttr(1);
891 SmallVector<OpFoldResult, 4> offsets(4, zeroIndex);
892 SmallVector<OpFoldResult, 4> strides(4, oneIndex);
894 ArrayRef<int64_t> extractedShape = extractedType.getShape();
895 SmallVector<OpFoldResult> sizes =
898 return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value,
899 offsets, sizes, strides);
// Fragment of hasAllOneValues: the predicate accepts an element only if its
// sign-extended value equals 1.
905 attr, [](
const APInt &element) {
return element.getSExtValue() == 1; });
// Rewrites linalg.conv_2d_nhwc_fhwc with Winograd F(m, r), where m and r
// come from `fmr` on lines not shown. The pipeline is:
// 1. winograd_filter_transform on the filter;
// 2. winograd_input_transform on the input, zero-padded to a tile-aligned
//    height/width if needed;
// 3. a batched matmul (matrixMultiply);
// 4. winograd_output_transform into the output, padded when the output is
//    not tile aligned and then sliced back.
// The conv op is replaced with the final value.
// Match failures are reported for: missing pure tensor semantics, non-static
// input/filter shapes, non-unit dilations or strides, and filters that are
// not r x r, r x 1 or 1 x r.
// NOTE(review): no static-shape check on the output type is visible, yet
// outputH/outputW feed divideCeilSigned below. Confirm the output is
// guaranteed static.
910 static FailureOr<Operation *>
911 winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
912 WinogradConv2DFmr fmr) {
913 if (!convOp.hasPureTensorSemantics())
914 return rewriter.notifyMatchFailure(
915 convOp,
"expected pure tensor semantics for linalg.conv_2d_nhwc_fhwc");
917 Value input = convOp.getInputs()[0];
918 Value filter = convOp.getInputs()[1];
919 Value output = convOp.getOutputs()[0];
920 auto inputType = cast<ShapedType>(input.getType());
921 auto filterType = cast<ShapedType>(filter.getType());
922 auto outputType = cast<ShapedType>(output.getType());
924 if (!inputType.hasStaticShape())
925 return rewriter.notifyMatchFailure(convOp,
926 "expected a static shape for the input");
928 if (!filterType.hasStaticShape())
929 return rewriter.notifyMatchFailure(
930 convOp,
"expected a static shape for the filter");
933 return rewriter.notifyMatchFailure(convOp,
934 "expected all ones for dilations");
937 return rewriter.notifyMatchFailure(convOp,
"expected all ones for strides");
939 ArrayRef<int64_t> filterShape = filterType.getShape();
940 int64_t filterF = filterShape[0];
941 int64_t filterH = filterShape[1];
942 int64_t filterW = filterShape[2];
943 int64_t filterC = filterShape[3];
944 ArrayRef<int64_t> inputShape = inputType.getShape();
945 int64_t inputN = inputShape[0];
946 int64_t inputH = inputShape[1];
947 int64_t inputW = inputShape[2];
948 int64_t inputC = inputShape[3];
949 ArrayRef<int64_t> outputShape = outputType.getShape();
950 int64_t outputN = outputShape[0];
951 int64_t outputH = outputShape[1];
952 int64_t outputW = outputShape[2];
953 int64_t outputF = outputShape[3];
// Accept r x r, r x 1 and 1 x r kernels only.
958 bool isSupportedFilter =
false;
959 if (filterH == filterW && filterH == r)
960 isSupportedFilter =
true;
961 if (filterH == r && filterW == 1)
962 isSupportedFilter =
true;
963 if (filterH == 1 && filterW == r)
964 isSupportedFilter =
true;
966 if (!isSupportedFilter)
967 return rewriter.notifyMatchFailure(
968 convOp,
"only support filter (r x r), (r x 1) or (1 x r)");
971 Location loc = convOp.getLoc();
// A size-1 kernel dimension needs no transform on that side, so its
// m and r collapse to 1.
974 bool leftTransform = filterH != 1;
976 bool rightTransform = filterW != 1;
977 int64_t heightM = leftTransform ? m : 1;
978 int64_t widthM = rightTransform ? m : 1;
979 int64_t heightR = leftTransform ? r : 1;
980 int64_t widthR = rightTransform ? r : 1;
983 Type filterElementType = filterType.getElementType();
984 int64_t alphaH = heightM + heightR - 1;
985 int64_t alphaW = widthM + widthR - 1;
// Number of m-sized output tiles per dimension, rounded up.
986 int64_t tileH = llvm::divideCeilSigned(outputH, heightM);
987 int64_t tileW = llvm::divideCeilSigned(outputW, widthM);
990 Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
992 auto transformedFilter = linalg::WinogradFilterTransformOp::create(
993 rewriter, loc, retType, filter, retValue, fmr);
// The input must cover tile * m + (r - 1) rows/cols; zero-pad when it does not.
999 Type inputElementType = inputType.getElementType();
1000 int64_t alignedInputH = tileH * heightM + (heightR - 1);
1001 int64_t alignedInputW = tileW * widthM + (widthR - 1);
1002 if (alignedInputH != inputH || alignedInputW != inputW) {
1003 input = padToAlignedTensor(rewriter, loc, input,
1004 {inputN, alignedInputH, alignedInputW, inputC});
1008 {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType);
1009 retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(),
1011 auto transformedInput = linalg::WinogradInputTransformOp::create(
1012 rewriter, loc, retType, input, retValue, fmr);
1014 Type outputElementType = outputType.getElementType();
1015 Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter,
1016 transformedInput, outputElementType);
// When the output size is not a multiple of m, compute into a padded output
// and slice the real result back out afterwards.
1022 int64_t alignedOutputH = tileH * heightM;
1023 int64_t alignedOutputW = tileW * widthM;
1024 bool isOutputUnaligned =
1025 ((alignedOutputH != outputH) || (alignedOutputW != outputW));
1026 if (isOutputUnaligned) {
1028 {outputN, alignedOutputH, alignedOutputW, outputF}, outputElementType);
1030 padToAlignedTensor(rewriter, loc, output, alignedOutputType.getShape());
1031 outputType = alignedOutputType;
1034 Value transformedOutput = linalg::WinogradOutputTransformOp::create(
1035 rewriter, loc, outputType, matmulRet, output, fmr);
1039 if (isOutputUnaligned) {
1040 transformedOutput = extractFromAlignedTensor(
1041 rewriter, loc, transformedOutput,
1043 outputElementType));
1046 rewriter.replaceOp(convOp, transformedOutput);
1048 return transformedOutput.getDefiningOp();
// Lowers linalg.winograd_filter_transform with filterTransform, then replaces
// the op and returns the defining op of the new value.
// Each side is transformed only when the filter's H (dim 1) or W (dim 2) is
// not 1. A null result from filterTransform is handled on a line not shown.
1052 FailureOr<Operation *>
1053 decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
1054 linalg::WinogradFilterTransformOp op) {
1055 Location loc = op.getLoc();
1056 Value filter = op.getFilter();
1057 auto filterType = cast<ShapedType>(filter.getType());
1058 auto filterShape = filterType.getShape();
1059 int64_t filterH = filterShape[1];
1060 int64_t filterW = filterShape[2];
1063 bool leftTransform = filterH != 1;
1065 bool rightTransform = filterW != 1;
1066 Value transformedFilter =
1067 filterTransform(rewriter, loc, filter, op.getOutput(), op.getFmr(),
1068 leftTransform, rightTransform);
1069 if (!transformedFilter)
1072 rewriter.replaceOp(op, transformedFilter);
1074 return transformedFilter.getDefiningOp();
// Lowers linalg.winograd_input_transform with inputTransform, then replaces
// the op and returns the defining op of the new value.
// The transform directions come from the op's output shape: each side is
// transformed only when dim 0 (height) or dim 1 (width) is not 1.
1078 FailureOr<Operation *>
1079 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
1080 linalg::WinogradInputTransformOp op) {
1081 Location loc = op.getLoc();
1082 Value output = op.getOutput();
1083 auto outputType = cast<ShapedType>(output.getType());
1084 auto outputShape = outputType.getShape();
1086 int64_t outputH = outputShape[0];
1087 int64_t outputW = outputShape[1];
1090 bool leftTransform = outputH != 1;
1092 bool rightTransform = outputW != 1;
1093 Value transformedInput =
1094 inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getFmr(),
1095 leftTransform, rightTransform);
1096 if (!transformedInput)
1099 rewriter.replaceOp(op, transformedInput);
1101 return transformedInput.getDefiningOp();
// Lowers linalg.winograd_output_transform with outputTransform, then replaces
// the op and returns the defining op of the new value.
// The transform directions come from the input value's shape: each side is
// transformed only when dim 0 or dim 1 is not 1.
1105 FailureOr<Operation *>
1106 decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
1107 linalg::WinogradOutputTransformOp op) {
1108 Location loc = op.getLoc();
1109 Value value = op.getValue();
1110 auto valueType = cast<ShapedType>(value.getType());
1111 auto valueShape = valueType.getShape();
1112 int64_t valueH = valueShape[0];
1113 int64_t valueW = valueShape[1];
1116 bool leftTransform = valueH != 1;
1118 bool rightTransform = valueW != 1;
1119 Value transformedOutput =
1120 outputTransform(rewriter, loc, value, op.getOutput(), op.getFmr(),
1121 leftTransform, rightTransform);
1122 if (!transformedOutput)
1125 rewriter.replaceOp(op, transformedOutput);
1127 return transformedOutput.getDefiningOp();
// Rewrite pattern that decomposes linalg.winograd_filter_transform through
// decomposeWinogradFilterTransformHelper. The FailureOr result is returned
// as this pattern's LogicalResult.
1131 class DecomposeWinogradFilterTransform final
1132 :
public OpRewritePattern<linalg::WinogradFilterTransformOp> {
1136 LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
1137 PatternRewriter &rewriter)
const override {
1138 return decomposeWinogradFilterTransformHelper(rewriter, op);
// Rewrite pattern that decomposes linalg.winograd_input_transform through
// decomposeWinogradInputTransformHelper.
1143 class DecomposeWinogradInputTransform final
1144 :
public OpRewritePattern<linalg::WinogradInputTransformOp> {
1148 LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
1149 PatternRewriter &rewriter)
const override {
1150 return decomposeWinogradInputTransformHelper(rewriter, op);
// Rewrite pattern that decomposes linalg.winograd_output_transform through
// decomposeWinogradOutputTransformHelper.
1155 class DecomposeWinogradOutputTransform final
1156 :
public OpRewritePattern<linalg::WinogradOutputTransformOp> {
1160 LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
1161 PatternRewriter &rewriter)
const override {
1162 return decomposeWinogradOutputTransformHelper(rewriter, op);
// Rewrite pattern that applies winogradConv2DHelper with the `fmr` given at
// construction to linalg.conv_2d_nhwc_fhwc ops.
// The result handling after the failed() check is on lines not shown.
1167 class WinogradConv2DNhwcFhwc final
1168 :
public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
1172 : OpRewritePattern(context), fmr(fmr) {}
1174 LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
1175 PatternRewriter &rewriter)
const override {
1176 if (
failed(winogradConv2DHelper(rewriter, convOp, fmr)))
// Winograd variant F(m, r) used by this pattern instance.
1183 WinogradConv2DFmr fmr;
// Fragment of the public entry point linalg::winogradConv2D (its first
// signature line is not shown). It forwards to winogradConv2DHelper.
1190 linalg::Conv2DNhwcFhwcOp op,
1191 linalg::WinogradConv2DFmr fmr) {
1192 return winogradConv2DHelper(rewriter, op, fmr);
// Public entry point (name line not shown; presumably
// decomposeWinogradFilterTransformOp). Forwards to
// decomposeWinogradFilterTransformHelper.
1195 FailureOr<Operation *>
1197 linalg::WinogradFilterTransformOp op) {
1198 return decomposeWinogradFilterTransformHelper(rewriter, op);
// Public entry point (name line not shown; presumably
// decomposeWinogradInputTransformOp). Forwards to
// decomposeWinogradInputTransformHelper.
1201 FailureOr<Operation *>
1203 linalg::WinogradInputTransformOp op) {
1204 return decomposeWinogradInputTransformHelper(rewriter, op);
// Public entry point (name line not shown; presumably
// decomposeWinogradOutputTransformOp). Forwards to
// decomposeWinogradOutputTransformHelper.
1207 FailureOr<Operation *>
1209 linalg::WinogradOutputTransformOp op) {
1210 return decomposeWinogradOutputTransformHelper(rewriter, op);
// Fragment of populateWinogradConv2DPatterns. It registers
// WinogradConv2DNhwcFhwc configured with `fmr`.
1214 WinogradConv2DFmr fmr) {
1217 patterns.insert<WinogradConv2DNhwcFhwc>(context, fmr);
// Fragment of populateDecomposeWinogradOpsPatterns. It registers the three
// Winograd transform decomposition patterns.
1223 .insert<DecomposeWinogradFilterTransform, DecomposeWinogradInputTransform,
1224 DecomposeWinogradOutputTransform>(context);
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
MLIRContext is the top-level object for a collection of MLIR operations.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
FailureOr< Operation * > decomposeWinogradFilterTransformOp(RewriterBase &rewriter, linalg::WinogradFilterTransformOp op)
Rewrite linalg.winograd_filter_transform.
FailureOr< Operation * > decomposeWinogradOutputTransformOp(RewriterBase &rewriter, linalg::WinogradOutputTransformOp op)
Rewrite linalg.winograd_output_transform.
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, WinogradConv2DFmr fmr)
Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns)
Patterns to decompose Winograd operators.
FailureOr< Operation * > winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, WinogradConv2DFmr fmr)
Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm F(m x m, r x r).
static bool hasAllOneValues(DenseIntElementsAttr attr)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
FailureOr< Operation * > decomposeWinogradInputTransformOp(RewriterBase &rewriter, linalg::WinogradInputTransformOp op)
Rewrite linalg.winograd_input_transform.
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...