#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Repeated.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <cstdint>
#include <optional>
52 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
54 if (complex::ConstantOp::isBuildableWith(value, type))
55 return complex::ConstantOp::create(builder, loc, type,
56 llvm::cast<ArrayAttr>(value));
62 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
63 if (tensorType.isDynamicDim(dim))
64 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
71 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
73 for (
int64_t i = 0; i < tensorType.getRank(); ++i)
80 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
81 assert(tensorType &&
"expected tensor type");
85 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
87 return destOp.getTiedOpOperand(opResult)->get();
95 if (!tensorType.hasStaticShape()) {
103 for (
int64_t sz : tensorType.getShape())
104 mixedSizes.push_back(
b.getIndexAttr(sz));
109 tensor::EmptyOp::create(
b, loc, mixedSizes, tensorType.getElementType());
117 if (llvm::isa<TensorType>(opResult.getType())) {
119 if (failed(destination))
121 result.push_back(*destination);
128 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
129 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
130 return rtp1.getShape() == rtp2.getShape() &&
131 rtp1.getElementType() == rtp2.getElementType();
141 llvm::SmallBitVector droppedDims(mixedSizes.size());
142 int64_t shapePos = reducedShape.size() - 1;
144 for (
const auto &size : enumerate(llvm::reverse(mixedSizes))) {
145 size_t idx = mixedSizes.size() - size.index() - 1;
147 bool isStaticUnitSize =
148 isa<Attribute>(size.value()) &&
149 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
154 assert(isStaticUnitSize &&
"expected unit dim");
155 droppedDims.set(idx);
160 if (!isStaticUnitSize) {
166 if (reducedShape[shapePos] == 1) {
172 droppedDims.set(idx);
175 assert(shapePos < 0 &&
"dimension mismatch");
182static RankedTensorType
186 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
187 "incorrect number of dynamic sizes");
191 for (
int64_t i = 0, e = type.getRank(); i < e; ++i) {
192 if (type.isDynamicDim(i)) {
193 Value dynamicSize = dynamicSizes[ctr++];
195 if (cst.has_value()) {
197 if (cst.value() < 0) {
198 foldedDynamicSizes.push_back(dynamicSize);
201 staticShape[i] = *cst;
203 foldedDynamicSizes.push_back(dynamicSize);
208 return RankedTensorType::get(staticShape, type.getElementType(),
217 if (inputs.size() != 1 || outputs.size() != 1)
219 Type a = inputs.front(),
b = outputs.front();
220 auto aT = dyn_cast<TensorType>(a);
221 auto bT = dyn_cast<TensorType>(
b);
225 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
236 using OpRewritePattern<BitcastOp>::OpRewritePattern;
238 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
239 PatternRewriter &rewriter)
const final {
240 auto tensorBitcastOperand =
241 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
242 if (!tensorBitcastOperand)
245 auto resultType = cast<TensorType>(tensorBitcast.getType());
246 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
247 tensorBitcastOperand.getOperand());
256 results.
add<ChainedTensorBitcast>(context);
264 setNameFn(getResult(),
"cast");
270 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
271 auto targetType = llvm::dyn_cast<RankedTensorType>(
target);
274 if (!sourceType || !targetType)
278 if (sourceType.getElementType() != targetType.getElementType())
282 if (sourceType.getRank() != targetType.getRank())
286 if (sourceType.getEncoding() != targetType.getEncoding())
290 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
291 if (ShapedType::isStatic(std::get<0>(t)) &&
292 ShapedType::isDynamic(std::get<1>(t)))
328 castOp.getSource().getType());
361 if (llvm::isa<BlockArgument>(opOperand.get()))
363 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
364 return castOp && canFoldIntoConsumerOp(castOp);
371 newOperands.reserve(op->getNumOperands());
377 for (
OpOperand &opOperand : op->getOpOperands()) {
378 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
380 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
381 if (op.isDpsInit(&opOperand) &&
382 !llvm::isa<MemRefType>(newOperands.back().getType()))
383 newResTy[dpsInitIdx++] = newOperands.back().getType();
393 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
395 operand.set(castOp.getOperand());
403 if (inputs.size() != 1 || outputs.size() != 1)
405 Type a = inputs.front(),
b = outputs.front();
406 auto aT = llvm::dyn_cast<TensorType>(a);
407 auto bT = llvm::dyn_cast<TensorType>(
b);
411 if (aT.getElementType() != bT.getElementType())
428 if (rank != two.getRank())
433 for (
int64_t i = 0; i < rank; ++i) {
434 if (one.isDynamicDim(i)) {
435 join.push_back(two.getDimSize(i));
438 if (two.isDynamicDim(i)) {
439 join.push_back(one.getDimSize(i));
442 if (one.getDimSize(i) != two.getDimSize(i))
444 join.push_back(one.getDimSize(i));
454 using OpRewritePattern<CastOp>::OpRewritePattern;
456 LogicalResult matchAndRewrite(CastOp tensorCast,
457 PatternRewriter &rewriter)
const final {
458 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
460 if (!tensorCastOperand)
464 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
465 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
466 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
480 auto newJoin =
joinShapes(sourceType, resultType);
481 if (firstJoin != newJoin)
484 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
485 tensorCastOperand.getOperand());
503 using OpRewritePattern<CastOp>::OpRewritePattern;
505 LogicalResult matchAndRewrite(CastOp tensorCast,
506 PatternRewriter &rewriter)
const final {
507 auto extractOperand =
508 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
511 auto rankedResultType =
512 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
513 if (!rankedResultType)
517 rankedResultType.getShape() ==
518 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
522 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
524 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
526 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
527 if (dimMask && dimMask->count(i))
529 int64_t dim = rankedResultType.getShape()[dimIndex++];
530 if (ShapedType::isDynamic(dim))
532 sizes[i] = rewriter.getIndexAttr(dim);
535 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
536 tensorCast, rankedResultType, extractOperand.getSource(),
537 extractOperand.getMixedOffsets(), sizes,
538 extractOperand.getMixedStrides());
547 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
554RankedTensorType ConcatOp::inferResultType(
int64_t dim,
TypeRange inputTypes) {
555 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
557 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
558 int64_t concatRank = tensorTypes[0].getRank();
561 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
564 for (
int64_t i = 0, e = concatRank; i < e; ++i) {
568 for (
auto tensorType : tensorTypes)
573 for (
auto tensorType : tensorTypes)
576 sizes[dim] = concatSize.asInteger();
577 return RankedTensorType::get(sizes, tensorTypes[0].
getElementType());
582 FailureOr<RankedTensorType> resultType =
583 inferResultType(dim, inputs.
getTypes());
584 assert(succeeded(resultType) &&
"failed to infer concatenation result type");
585 build(builder,
result, *resultType, dim, inputs);
588LogicalResult ConcatOp::verify() {
589 if (getInputs().size() < 1)
593 for (
auto input : getInputs())
594 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
596 RankedTensorType resultType = getResultType();
597 int64_t resultRank = getRank();
598 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
599 return type.getRank() != resultRank;
601 return emitOpError(
"rank of concatenated inputs must match result rank");
603 Type resultElementType = resultType.getElementType();
604 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
605 return type.getElementType() != resultElementType;
607 return emitOpError(
"inputs and result element type must match");
610 if (dim >= resultRank)
611 return emitOpError(
"concatenation dim must be less than the tensor rank");
614 for (
int64_t i = 0, e = resultRank; i < e; ++i) {
618 for (
auto tensorType : inputTypes) {
619 FailureOr<SaturatedInteger> maybeSize =
622 return emitOpError(
"static concatenation size mismatch along ")
623 <<
"non-concatenated dimension " << i;
629 for (
auto tensorType : inputTypes)
632 sizes[dim] = concatSize.asInteger();
633 auto inferredResultType =
636 for (
auto [inferredSize, actualSize] :
637 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
638 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
639 ShapedType::isDynamic(actualSize);
640 if (!hasDynamic && inferredSize != actualSize)
642 << resultType <<
"does not match inferred shape "
643 << inferredResultType <<
" static sizes";
649FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(
OpBuilder &builder) {
650 size_t numInputs = getInputs().size();
651 uint64_t concatDim = getDim();
654 inputShapes.reserve(numInputs);
656 concatOffsets.reserve(numInputs);
663 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
667 outputShape = inputShape;
668 concatOffsets.push_back(zero);
670 concatOffsets.push_back(outputShape[concatDim]);
672 builder, loc, addExpr,
673 {outputShape[concatDim], inputShape[concatDim]});
675 inputShapes.emplace_back(std::move(inputShape));
685 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
686 offsets[concatDim] = concatOffsets[
index];
687 auto insertSlice = tensor::InsertSliceOp::create(
698ConcatOp::reifyResultShapes(
OpBuilder &builder,
702 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
704 Value init = inputs[0];
712 for (
int64_t i = 0; i < rank; ++i) {
715 if (!
getType().isDynamicDim(i)) {
717 }
else if (!inferredResultType.isDynamicDim(i)) {
720 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
722 reifiedReturnShapes[0][i] =
723 tensor::DimOp::create(builder, init.
getLoc(), init, i).getResult();
727 if (
getType().isDynamicDim(dim)) {
732 for (
auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
735 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
743 reifiedReturnShapes[0][dim] =
749void ConcatOp::getAsmResultNames(
751 setNameFn(getResult(),
"concat");
756 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
764 using OpRewritePattern<ConcatOp>::OpRewritePattern;
766 LogicalResult matchAndRewrite(ConcatOp concatOp,
767 PatternRewriter &rewriter)
const override {
768 if (concatOp.getInputs().size() != 1)
771 concatOp.getInputs()[0]);
796 using OpRewritePattern<ConcatOp>::OpRewritePattern;
798 LogicalResult matchAndRewrite(ConcatOp concatOp,
799 PatternRewriter &rewriter)
const override {
800 int64_t dim = concatOp.getDim();
801 RankedTensorType inferredResultType =
802 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
805 LogicalResult matched = failure();
808 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
809 for (
auto [operandIdx, operandType] :
810 llvm::enumerate(concatOp->getOperandTypes())) {
812 inferredOperandShape[dim] =
813 cast<RankedTensorType>(operandType).getDimSize(dim);
814 auto inferredOperandType = RankedTensorType::get(
815 inferredOperandShape, inferredResultType.getElementType());
823 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
824 concatOp.getOperand(operandIdx));
826 concatOp->setOperand(operandIdx, castOp->getResult(0));
850 using OpRewritePattern<ConcatOp>::OpRewritePattern;
852 LogicalResult matchAndRewrite(ConcatOp concatOp,
853 PatternRewriter &rewriter)
const override {
854 int64_t dim = concatOp.getDim();
855 RankedTensorType inferredResultType =
856 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
860 concatOp.getResultType())) {
865 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
866 concatOp->getOperands());
878 .
add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
887 setNameFn(getResult(),
"dim");
892 auto loc =
result.location;
894 build(builder,
result, source, indexValue);
897std::optional<int64_t> DimOp::getConstantIndex() {
906 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().
getType());
907 if (!rankedSourceType)
910 if (rankedSourceType.getRank() <= constantIndex)
918 setResultRange(getResult(),
924 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
929 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().
getType());
936 if (indexVal < 0 || indexVal >= tensorType.getRank())
940 if (!tensorType.isDynamicDim(
index.getInt())) {
945 Operation *definingOp = getSource().getDefiningOp();
948 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
950 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
953 assert(ShapedType::isDynamic(resultType.getShape()[
index.getInt()]));
956 auto dynExtents = fromElements.getDynamicExtents().begin();
957 for (
auto dim : resultType.getShape().take_front(
index.getInt()))
958 if (ShapedType::isDynamic(dim))
961 return Value{*dynExtents};
965 unsigned unsignedIndex =
index.getValue().getZExtValue();
967 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
970 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
971 sliceOp.isDynamicSize(unsignedIndex)) {
972 return {sliceOp.getDynamicSize(unsignedIndex)};
986 using OpRewritePattern<DimOp>::OpRewritePattern;
988 LogicalResult matchAndRewrite(DimOp dimOp,
989 PatternRewriter &rewriter)
const override {
990 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
993 Value newSource = castOp.getOperand();
1002 using OpRewritePattern<DimOp>::OpRewritePattern;
1004 LogicalResult matchAndRewrite(DimOp dimOp,
1005 PatternRewriter &rewriter)
const override {
1006 auto source = dimOp.getSource();
1007 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1011 auto resultIndex = cast<OpResult>(source).getResultNumber();
1012 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1015 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1023 using OpRewritePattern<DimOp>::OpRewritePattern;
1025 LogicalResult matchAndRewrite(DimOp dim,
1026 PatternRewriter &rewriter)
const override {
1027 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1035 Location loc = dim.getLoc();
1037 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1038 if (extract.
getType() != dim.getType())
1040 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1049 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1059 assert(none_of(staticShape, ShapedType::isDynamic) &&
1060 "expected only static sizes");
1064void EmptyOp::build(OpBuilder &builder, OperationState &
result,
1065 ArrayRef<int64_t> staticShape, Type elementType,
1066 ValueRange dynamicSizes, Attribute encoding) {
1067 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1068 build(builder,
result, tensorType, dynamicSizes);
1071void EmptyOp::build(OpBuilder &builder, OperationState &
result,
1072 ArrayRef<OpFoldResult> sizes, Type elementType,
1073 Attribute encoding) {
1074 SmallVector<int64_t> staticShape;
1075 SmallVector<Value> dynamicSizes;
1077 build(builder,
result, staticShape, elementType, dynamicSizes, encoding);
1080LogicalResult EmptyOp::verify() {
1086EmptyOp::reifyResultShapes(OpBuilder &builder,
1088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
1090 for (int64_t i = 0; i <
getType().getRank(); ++i) {
1091 if (
getType().isDynamicDim(i)) {
1100Value EmptyOp::getDynamicSize(
unsigned idx) {
1101 assert(
getType().isDynamicDim(idx) &&
"expected dynamic dim");
1103 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1104 if (
getType().isDynamicDim(i))
1109SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1110 SmallVector<OpFoldResult>
result;
1114 if (ShapedType::isDynamic(dim)) {
1117 result.push_back(
b.getIndexAttr(dim));
1135struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1136 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1138 LogicalResult matchAndRewrite(EmptyOp op,
1139 PatternRewriter &rewriter)
const override {
1140 SmallVector<Value> foldedDynamicSizes;
1142 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1145 if (foldedTensorType == op.getType())
1148 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1149 foldedDynamicSizes);
1155struct FoldEmptyTensorWithDimOp :
public OpRewritePattern<DimOp> {
1156 using OpRewritePattern<DimOp>::OpRewritePattern;
1158 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1159 PatternRewriter &rewriter)
const override {
1160 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1161 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1162 if (!emptyTensorOp || !maybeConstantIndex)
1164 auto emptyTensorType = emptyTensorOp.getType();
1165 if (*maybeConstantIndex < 0 ||
1166 *maybeConstantIndex >= emptyTensorType.getRank() ||
1167 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1170 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1190struct FoldEmptyTensorWithCastOp :
public OpRewritePattern<CastOp> {
1191 using OpRewritePattern<CastOp>::OpRewritePattern;
1193 LogicalResult matchAndRewrite(CastOp castOp,
1194 PatternRewriter &rewriter)
const override {
1197 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1202 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1203 ArrayRef<int64_t> resultShape = resultType.getShape();
1204 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1205 SmallVector<OpFoldResult> newMixedSizes;
1206 newMixedSizes.reserve(currMixedSizes.size());
1207 assert(resultShape.size() == currMixedSizes.size() &&
1208 "mismatch in result shape and sizes of empty op");
1209 for (
auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
1212 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1213 if (ShapedType::isDynamic(newDim) ||
1214 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1219 producer,
"mismatch in static value of shape of empty tensor "
1220 "result and cast result");
1222 newMixedSizes.push_back(attr);
1228 if (ShapedType::isStatic(newDim)) {
1229 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
1235 newMixedSizes.push_back(currDim);
1239 resultType.getElementType(),
1240 resultType.getEncoding());
1247void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1248 MLIRContext *context) {
1249 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1250 ReplaceEmptyTensorStaticShapeDims>(context);
1267struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
1268 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1270 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1271 PatternRewriter &rewriter)
const final {
1272 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1275 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1278 extract, tensorCast.getSource(), extract.getIndices());
1293struct ExtractFromCollapseShape :
public OpRewritePattern<tensor::ExtractOp> {
1294 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1296 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1297 PatternRewriter &rewriter)
const final {
1299 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1302 if (!collapseOp.getSrcType().hasStaticShape())
1305 auto sourceSizes = collapseOp.getSrcType().getShape();
1307 SmallVector<Value>
indices(extractOp.getIndices().begin(),
1308 extractOp.getIndices().end());
1309 SmallVector<Value> sourceIndices;
1310 for (
auto [index, group] :
1311 llvm::zip(
indices, collapseOp.getReassociationIndices())) {
1312 assert(!group.empty() &&
"association indices groups cannot be empty");
1313 auto groupSize = group.size();
1315 if (groupSize == 1) {
1316 sourceIndices.push_back(index);
1320 SmallVector<int64_t> basis =
1321 llvm::map_to_vector(group, [&](int64_t d) {
return sourceSizes[d]; });
1322 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1323 rewriter, extractOp.getLoc(), index, basis,
true);
1324 llvm::append_range(sourceIndices,
delinearize.getResults());
1326 if (collapseOp.getReassociationIndices().empty()) {
1329 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1331 rewriter, extractOp.getLoc(), zeroAffineMap,
1332 ArrayRef<OpFoldResult>{});
1333 for (int64_t i = 0; i < srcRank; i++) {
1334 sourceIndices.push_back(
1340 extractOp, collapseOp.getSrc(), sourceIndices);
1347void ExtractOp::getAsmResultNames(
1349 setNameFn(getResult(),
"extracted");
1352LogicalResult ExtractOp::verify() {
1354 auto tensorType = llvm::cast<RankedTensorType>(getTensor().
getType());
1355 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1356 return emitOpError(
"incorrect number of indices for extract_element");
1365 auto insertOp = extractOp.getTensor().
getDefiningOp<InsertOp>();
1370 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1371 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1372 return insertOp.getScalar();
1377OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1378 if (Attribute tensor = adaptor.getTensor()) {
1381 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1382 return splatTensor.getSplatValue<Attribute>();
1385 if (isa<DenseResourceElementsAttr>(tensor))
1390 SmallVector<uint64_t, 8>
indices;
1391 for (Attribute indice : adaptor.getIndices()) {
1392 if (!indice || !llvm::isa<IntegerAttr>(indice))
1394 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1398 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1399 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1400 auto rank = tensorType.getRank();
1401 assert(
static_cast<int64_t
>(
indices.size()) == tensorType.getRank() &&
1405 for (
int i = rank - 1; i >= 0; --i) {
1406 flatIndex +=
indices[i] * stride;
1407 stride *= tensorType.getDimSize(i);
1411 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1414 return fromElementsOp.getElements()[flatIndex];
1418 if (Attribute tensor = adaptor.getTensor()) {
1419 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1420 if (elementsAttr && elementsAttr.isValidIndex(
indices))
1421 return elementsAttr.getValues<Attribute>()[
indices];
1430void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1431 MLIRContext *context) {
1432 results.
add<ExtractFromTensorCast>(context);
1437 patterns.
add<ExtractFromCollapseShape>(patterns.
getContext());
1444void FromElementsOp::getAsmResultNames(
1446 setNameFn(getResult(),
"from_elements");
1451 assert(!elements.empty() &&
"expected at least one element");
1452 Type resultType = RankedTensorType::get(
1453 {
static_cast<int64_t>(elements.size())}, elements.front().
getType());
1454 build(builder,
result, resultType, elements);
1457OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1462 Type eltType =
getType().getElementType();
1465 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
1488struct ExtractElementFromIndexCast
1489 :
public OpRewritePattern<tensor::ExtractOp> {
1490 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1492 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1493 PatternRewriter &rewriter)
const final {
1494 Location loc = extract.getLoc();
1495 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1501 auto newExtract = tensor::ExtractOp::create(
1502 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1513void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1514 MLIRContext *context) {
1515 results.
add<ExtractElementFromIndexCast>(context);
1522void GatherOp::getAsmResultNames(
1524 setNameFn(getResult(),
"gather");
1539RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1540 RankedTensorType indicesType,
1541 ArrayRef<int64_t> gatherDims,
1543 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1544 resultShape.reserve(resultShape.size() + sourceType.getRank());
1545 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1546 if (llvm::binary_search(gatherDims, idx)) {
1548 resultShape.push_back(1);
1551 resultShape.push_back(sourceType.getDimSize(idx));
1553 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1559 StringRef gatherOrScatter, StringRef sourceOrDest) {
1561 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
1563 int64_t numGatherDims = dims.size();
1564 if (numGatherDims > rank)
1566 <<
"_dims overflow " << sourceOrDest <<
" rank";
1569 <<
"_dims length must match the size of last dimension of indices";
1573 <<
"_dims value must be non-negative";
1576 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
1578 for (
int64_t i = 1; i < numGatherDims; ++i) {
1579 if (dims[i - 1] >= dims[i])
1581 <<
"_dims values must be strictly increasing";
1586LogicalResult GatherOp::verify() {
1587 int64_t sourceRank = getSourceType().getRank();
1588 ArrayRef<int64_t> gatherDims = getGatherDims();
1590 getIndicesType().
getShape(), sourceRank,
1591 "gather",
"source")))
1594 RankedTensorType expectedResultType = GatherOp::inferResultType(
1595 getSourceType(), getIndicesType(), gatherDims,
false);
1596 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1597 getSourceType(), getIndicesType(), gatherDims,
true);
1598 if (getResultType() != expectedResultType &&
1599 getResultType() != expectedRankReducedResultType) {
1603 << expectedResultType <<
" or its rank-reduced variant "
1604 << expectedRankReducedResultType <<
" (got: " << getResultType()
1611OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1612 if (OpFoldResult reshapedSource = reshapeConstantSource(
1613 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1615 return reshapedSource;
1623void InsertOp::getAsmResultNames(
1625 setNameFn(getResult(),
"inserted");
1628LogicalResult InsertOp::verify() {
1630 auto destType = llvm::cast<RankedTensorType>(getDest().
getType());
1631 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1632 return emitOpError(
"incorrect number of indices");
1636OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1637 Attribute scalar = adaptor.getScalar();
1638 Attribute dest = adaptor.getDest();
1640 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1641 if (scalar == splatDest.getSplatValue<Attribute>())
1650void GenerateOp::getAsmResultNames(
1652 setNameFn(getResult(),
"generated");
1655LogicalResult GenerateOp::reifyResultShapes(
1657 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
1659 for (
auto dim : llvm::seq<int64_t>(0,
getType().getRank())) {
1660 if (
getType().isDynamicDim(dim)) {
1661 reifiedReturnShapes[0][dim] = getOperand(idx++);
1663 reifiedReturnShapes[0][dim] =
1670LogicalResult GenerateOp::verify() {
1673 RankedTensorType resultType = llvm::cast<RankedTensorType>(
getType());
1680LogicalResult GenerateOp::verifyRegions() {
1681 RankedTensorType resultTy = llvm::cast<RankedTensorType>(
getType());
1683 if (!llvm::all_of(getBody().getArgumentTypes(),
1684 [](Type ty) {
return ty.
isIndex(); }))
1685 return emitError(
"all body arguments must be index");
1686 if (getBody().getNumArguments() != resultTy.getRank())
1687 return emitError(
"must have one body argument per input dimension");
1690 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1692 if (yieldOp.getValue().getType() != resultTy.getElementType())
1694 "body must be terminated with a `yield` operation of the tensor "
1700void GenerateOp::build(
1701 OpBuilder &
b, OperationState &
result, Type resultTy,
1704 build(
b,
result, resultTy, dynamicExtents);
1707 OpBuilder::InsertionGuard guard(
b);
1708 Region *bodyRegion =
result.regions.front().get();
1709 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1710 SmallVector<Type, 2> argumentTypes(rank,
b.getIndexType());
1711 SmallVector<Location, 2> argumentLocs(rank,
result.location);
1713 b.createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1723struct StaticTensorGenerate :
public OpRewritePattern<GenerateOp> {
1724 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1726 LogicalResult matchAndRewrite(GenerateOp generateOp,
1727 PatternRewriter &rewriter)
const final {
1728 SmallVector<Value> foldedDynamicSizes;
1730 generateOp.getType(), generateOp.getDynamicExtents(),
1731 foldedDynamicSizes);
1734 if (foldedTensorType == generateOp.getType())
1737 auto loc = generateOp.getLoc();
1739 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1741 newOp.getBody().begin());
1743 generateOp.getType(), newOp);
1759struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1760 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1762 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1763 PatternRewriter &rewriter)
const final {
1764 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1769 Block *body = &tensorFromElements.getBody().front();
1772 rewriter.
clone(op, mapping);
1783void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1784 MLIRContext *context) {
1786 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1793void RankOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1794 setNameFn(getResult(),
"rank");
1797OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1799 auto type = getOperand().getType();
1800 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1801 if (shapedType && shapedType.hasRank())
1802 return IntegerAttr::get(IndexType::get(
getContext()), shapedType.getRank());
1803 return IntegerAttr();
1810void ReshapeOp::getAsmResultNames(
1812 setNameFn(getResult(),
"reshape");
1817 for (
auto dim : type.getShape())
1822LogicalResult ReshapeOp::verify() {
1823 TensorType operandType = llvm::cast<TensorType>(getSource().
getType());
1824 TensorType resultType = llvm::cast<TensorType>(getResult().
getType());
1827 return emitOpError(
"element types of source and destination tensor "
1828 "types should be the same");
1832 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1833 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1835 if (resultRankedType) {
1836 if (operandRankedType && resultRankedType.hasStaticShape() &&
1837 operandRankedType.hasStaticShape()) {
1839 return emitOpError(
"source and destination tensor should have the "
1840 "same number of elements");
1842 if (ShapedType::isDynamic(shapeSize))
1843 return emitOpError(
"cannot use shape operand with dynamic length to "
1844 "reshape to statically-ranked tensor type");
1845 if (shapeSize != resultRankedType.getRank())
1847 "length of shape operand differs from the result's tensor rank");
// Folds tensor.reshape:
//  - reshape of a constant source folds to a reshaped constant;
//  - reshape-of-reshape forwards the inner source;
//  - a reshape whose shape operand provably matches the (identical) source
//    type is a no-op — checked via a from_elements shape whose elements are
//    matching constants or tensor.dim ops on the same source.
// NOTE(review): several original lines (return statements, the dynamicNoop
// declaration, closing braces) are missing from this extraction; code kept
// byte-identical to the view.
1852OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1853 if (OpFoldResult reshapedSource = reshapeConstantSource(
1854 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1856 return reshapedSource;
1861 if (
auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1862 getSourceMutable().assign(reshapeOpProducer.getSource());
1866 auto source = getSource();
1867 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1868 auto resultTy = dyn_cast<RankedTensorType>(
getType());
1869 if (!sourceTy || !resultTy || sourceTy != resultTy)
1874 if (sourceTy.getRank() <= 1)
1877 if (
auto fromElements =
getShape().getDefiningOp<tensor::FromElementsOp>()) {
1878 auto elements = fromElements.getElements();
1880 sourceTy.getRank() ==
static_cast<int64_t
>(elements.size());
1881 for (
int id = 0, s = elements.size();
id < s && dynamicNoop; ++
id) {
1882 auto element = elements[id];
// Constant shape element: must equal the static dim size at this position.
1885 dynamicNoop &= cst.value() == sourceTy.getDimSize(
id);
// Dynamic shape element: accept only tensor.dim of the same source at the
// matching (constant) index.
1889 if (
auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1890 dynamicNoop &= dimOp.getSource() == source;
1894 cst.has_value() && cst.value() ==
static_cast<int64_t
>(id);
// Any other shape element defeats the no-op proof.
1898 dynamicNoop =
false;
// NOTE(review): extraction dropped some original lines in this region; code
// kept byte-identical to the view.
// Suggests "collapsed" as the printed SSA name for CollapseShapeOp results.
1913void CollapseShapeOp::getAsmResultNames(
1915 setNameFn(getResult(),
"collapsed");
// Suggests "expanded" as the printed SSA name for ExpandShapeOp results.
1918void ExpandShapeOp::getAsmResultNames(
1920 setNameFn(getResult(),
"expanded");
// Maps a result (expanded) dimension back to the source dimension whose
// reassociation group contains it; asserts the index is in range.
1923int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1924 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1925 "invalid resultDim");
1926 for (
const auto &it : llvm::enumerate(getReassociationIndices()))
1927 if (llvm::is_contained(it.value(), resultDim))
1929 llvm_unreachable(
"could not find reassociation group");
// Computes the full output shape of an expand_shape from the input shape and
// reassociation; the failure path is not visible in this extraction.
1932FailureOr<SmallVector<OpFoldResult>>
1933ExpandShapeOp::inferOutputShape(OpBuilder &
b, Location loc,
1934 RankedTensorType expandedType,
1935 ArrayRef<ReassociationIndices> reassociation,
1936 ArrayRef<OpFoldResult> inputShape) {
1937 std::optional<SmallVector<OpFoldResult>> outputShape =
1942 return *outputShape;
// Returns the mixed (static + dynamic) output shape; body not visible here.
1945SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
// Build overload taking an explicit mixed output shape: decomposes it into
// static and dynamic parts and forwards to the primary builder.
1949void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
1950 Type resultType, Value src,
1951 ArrayRef<ReassociationIndices> reassociation,
1952 ArrayRef<OpFoldResult> outputShape) {
1953 auto [staticOutputShape, dynamicOutputShape] =
1955 build(builder,
result, cast<RankedTensorType>(resultType), src,
1957 dynamicOutputShape, staticOutputShape);
// Build overload without an output shape: infers it from the source's shape
// and the reassociation, falling back to an empty shape on failure.
1960void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
1961 Type resultType, Value src,
1962 ArrayRef<ReassociationIndices> reassociation) {
1963 SmallVector<OpFoldResult> inputShape =
1965 auto tensorResultTy = cast<RankedTensorType>(resultType);
1966 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1967 builder,
result.location, tensorResultTy, reassociation, inputShape);
1968 SmallVector<OpFoldResult> outputShapeOrEmpty;
1969 if (succeeded(outputShape)) {
1970 outputShapeOrEmpty = *outputShape;
1972 build(builder,
result, tensorResultTy, src, reassociation,
1973 outputShapeOrEmpty);
// Reassociation accessors: expose the grouping as affine maps / expr lists.
// Bodies are partially missing in this extraction.
1976SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1979SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1981 getReassociationIndices());
1984SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1987SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1989 getReassociationIndices());
// NOTE(review): extraction dropped some original lines here; code kept
// byte-identical to the view.
// Convenience overload: converts reassociation indices to affine maps and
// delegates to the AffineMap-based overload below.
1992RankedTensorType CollapseShapeOp::inferCollapsedType(
1993 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1994 return inferCollapsedType(
1996 type.getContext(), reassociation)));
// Computes the collapsed tensor type: each reassociation group's dims are
// multiplied together; any dynamic dim in a group makes the collapsed dim
// dynamic.
2002CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2003 ArrayRef<AffineMap> reassociation) {
2004 auto shape = type.getShape();
2005 SmallVector<int64_t, 4> newShape;
2006 newShape.reserve(reassociation.size());
2011 unsigned currentDim = 0;
2012 for (AffineMap m : reassociation) {
2013 unsigned dim = m.getNumResults();
2014 auto band = shape.slice(currentDim, dim);
// A dynamic dim anywhere in the group poisons the product.
2016 if (llvm::is_contained(band, ShapedType::kDynamic))
2017 size = ShapedType::kDynamic;
2019 for (
unsigned d = 0; d < dim; ++d)
2020 size *= shape[currentDim + d];
2021 newShape.push_back(size);
2025 return RankedTensorType::get(newShape, type.getElementType());
// Builder: infers the collapsed result type (preserving the source encoding)
// and attaches the reassociation attribute.
2028void CollapseShapeOp::build(OpBuilder &
b, OperationState &
result, Value src,
2029 ArrayRef<ReassociationIndices> reassociation,
2030 ArrayRef<NamedAttribute> attrs) {
2031 auto srcType = llvm::cast<RankedTensorType>(src.
getType());
2032 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2034 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2035 srcType.getEncoding());
2036 result.addAttribute(getReassociationAttrStrName(),
2038 build(
b,
result, resultType, src, attrs);
// Shared verifier for Expand/CollapseShapeOp: checks reshape-like type
// compatibility, equal element counts for static shapes, and that the
// collapsed type matches the one inferred from the reassociation maps.
// isExpansion is deduced from whether the op type is ExpandShapeOp.
// NOTE(review): the function signature line and several guards are missing
// from this extraction; code kept byte-identical to the view.
2041template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
2042 TensorReshapeOp, ExpandShapeOp>::value>
2044 RankedTensorType expandedType,
2045 RankedTensorType collapsedType) {
2047 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2051 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2052 int64_t expandedNumElements = expandedType.getNumElements();
2053 int64_t collapsedNumElements = collapsedType.getNumElements();
2054 if (expandedNumElements != collapsedNumElements) {
2055 return op.emitOpError(
"number of elements must be preserved: ")
2056 << expandedNumElements <<
" != " << collapsedNumElements;
2060 auto maps = op.getReassociationMaps();
2061 RankedTensorType expectedType =
2062 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2064 return op.emitOpError(
"expected collapsed type to be ")
2065 << expectedType <<
", but got " << collapsedType;
// Verifies tensor.expand_shape: static_output_shape length must equal the
// result rank, the number of dynamic output_shape operands must match the
// number of kDynamic entries, and static entries must agree with the result
// type's static dims.
// NOTE(review): extraction dropped some original lines (e.g. the
// staticOutputShapes declaration and the final delegation to the shared
// reshape verifier); code kept byte-identical to the view.
2069LogicalResult ExpandShapeOp::verify() {
2070 RankedTensorType srcType = getSrc().getType();
2071 RankedTensorType resultType = getResult().getType();
2073 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2074 return emitOpError(
"expected number of static shape dims to be equal to "
2075 "the output rank (")
2076 << resultType.getRank() <<
") but found "
2077 << getStaticOutputShape().size() <<
" inputs instead";
2079 if ((int64_t)getOutputShape().size() !=
2080 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2081 return emitOpError(
"mismatch in dynamic dims in output_shape and "
2082 "static_output_shape: static_output_shape has ")
2083 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2084 <<
" dynamic dims while output_shape has " << getOutputShape().size()
2095 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2096 for (
auto [pos, shape] : llvm::enumerate(resShape))
2097 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
2098 return emitOpError(
"invalid output shape provided at pos ") << pos;
// Verifies tensor.collapse_shape: reassociation groups must be non-empty;
// the rest delegates to the shared reshape verifier (not visible here).
2103LogicalResult CollapseShapeOp::verify() {
2104 CollapseShapeOp op = *
this;
2105 if (llvm::any_of(op.getReassociationIndices(),
2107 return op.emitOpError(
"reassociation indices must not be empty");
2109 RankedTensorType srcType = op.getSrc().getType();
2110 RankedTensorType resultType = op.getResult().getType();
// NOTE(review): extraction dropped some original lines (pattern-match guards,
// replaceOpWith* calls, return statements); code kept byte-identical to the
// view.
// Folds an expand/collapse of a constant into a reshaped constant (static
// result shapes only; reuses the constant's raw data).
2118template <
typename TensorReshapeOp>
2119struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2120 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2121 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2122 PatternRewriter &rewriter)
const override {
2123 DenseElementsAttr attr;
2130 if (!reshapeOp.getResultType().hasStaticShape())
2133 reshapeOp.getResultType(), attr.
getRawData());
// Folds an expand/collapse of a tensor.splat into a splat of the reshaped
// type.
2140template <
typename TensorReshapeOp>
2141class FoldReshapeWithSplat :
public OpRewritePattern<TensorReshapeOp> {
2143 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2145 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2146 PatternRewriter &rewriter)
const override {
2147 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2148 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2152 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
// Folds an expand/collapse of tensor.from_elements into a from_elements of
// the reshaped (static) type.
2159template <
typename TensorReshapeOp>
2160struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2161 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2162 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2163 PatternRewriter &rewriter)
const override {
2165 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2169 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2171 if (!shapedTy.hasStaticShape())
2175 fromElements.getElements());
// Pushes a tensor.cast below a collapse_shape: collapse the cast's source
// directly, then cast the result back if the types differ.
2181struct FoldCollapseOfCastOp :
public OpRewritePattern<CollapseShapeOp> {
2182 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2184 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2185 PatternRewriter &rewriter)
const override {
2186 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2190 RankedTensorType srcType =
2191 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2192 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2193 srcType, collapseShapeOp.getReassociationMaps());
// Same result type: just rewire the source in place.
2195 if (newResultType == collapseShapeOp.getResultType()) {
2197 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2200 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2201 newResultType, castOp.getSource(),
2202 collapseShapeOp.getReassociation());
2204 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
// Canonicalizes expand_shape(cast): pulls static shape information from the
// cast's source (and from constant output_shape operands) to build a more
// static expand_shape, casting the input up front and expanding to the
// refined output type.
// NOTE(review): extraction dropped some original lines (guards, `continue`s,
// the constant-match for `cst`, the final replace call); code kept
// byte-identical to the view.
2214struct ConvertToStaticExpandShape :
public OpRewritePattern<ExpandShapeOp> {
2215 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2217 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2218 PatternRewriter &rewriter)
const override {
2219 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2223 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2224 SmallVector<ReassociationIndices, 4> reassoc =
2225 expandOp.getReassociationIndices();
2227 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2228 SmallVector<Value> dynamicOutputShape;
2229 auto outputIt = expandOp.getOutputShape().begin();
// Pass 1: resolve each dynamic output dim to a constant where possible.
2231 for (
const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2232 for (uint64_t outDim : innerReassoc) {
2233 if (ShapedType::isStatic(newOutputShape[outDim]))
2240 Value val = *outputIt;
2242 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2243 dynamicOutputShape.push_back(val);
2249 newOutputShape[outDim] = cst.getSExtValue();
2251 dynamicOutputShape.push_back(val);
// Nothing became static: the rewrite would be a no-op.
2257 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
// Pass 2: recompute the input shape as the product of each group's output
// dims (dynamic if any member is dynamic).
2261 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2262 for (
auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2263 for (
auto outDim : reassoc[inDim]) {
2264 auto ofr = newOutputShape[outDim];
2265 if (ShapedType::isDynamic(ofr)) {
2266 newInputShape[inDim] = ShapedType::kDynamic;
2269 newInputShape[inDim] *= ofr;
2273 SmallVector<OpFoldResult> outputOfr =
2275 auto inputType = RankedTensorType::get(
2276 newInputShape, expandOp.getSrcType().getElementType());
2277 auto outputType = RankedTensorType::get(
2278 newOutputShape, expandOp.getSrcType().getElementType());
2279 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2281 auto newExpand = ExpandShapeOp::create(
2282 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2283 expandOp.getReassociationIndices(), outputOfr);
2285 newExpand.getResult());
// NOTE(review): extraction dropped some original lines (e.g. `results.add<`
// openers and fold helper calls); code kept byte-identical to the view.
// Registers all expand_shape canonicalization patterns.
2291void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2292 MLIRContext *context) {
2294 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2295 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2296 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2297 FoldReshapeWithSplat<ExpandShapeOp>,
2298 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
// Registers all collapse_shape canonicalization patterns.
2301void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2302 MLIRContext *context) {
2304 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2305 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2306 tensor::DimOp, RankedTensorType>,
2307 FoldReshapeWithConstant<CollapseShapeOp>,
2308 FoldReshapeWithSplat<CollapseShapeOp>,
2309 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
// Folders delegate to a shared reshape-op fold helper (call site truncated
// in this extraction).
2313OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2315 adaptor.getOperands());
2318OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2320 adaptor.getOperands());
// Suggests "extracted_slice" as the printed SSA name for ExtractSliceOp.
2327void ExtractSliceOp::getAsmResultNames(
2329 setNameFn(getResult(),
"extracted_slice");
// NOTE(review): extraction dropped some original lines (return-type lines,
// the OpFoldResult decomposition call, the rank-reduction mask computation);
// code kept byte-identical to the view.
// Infers the unreduced extract_slice result type: same rank as the source,
// shape taken from the static sizes, element type and encoding preserved.
2336ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2337 ArrayRef<int64_t> staticSizes) {
2341 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2342 sourceTensorType.getRank() &&
2343 "unexpected staticSizes not equal to rank of source");
2344 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2345 sourceTensorType.getEncoding());
// OpFoldResult overload: decomposes mixed sizes into static values first.
2350ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2351 ArrayRef<OpFoldResult> sizes) {
2352 SmallVector<int64_t> staticSizes;
2355 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2356 sourceTensorType.getRank() &&
2357 "unexpected staticSizes not equal to rank of source");
2358 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2359 sourceTensorType.getEncoding());
// Infers a rank-reduced result type by projecting out unit dims until the
// desired rank is reached.
2370RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2371 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2372 ArrayRef<int64_t> sizes) {
2374 auto inferredType = llvm::cast<RankedTensorType>(
2375 inferResultType(sourceRankedTensorType, sizes));
2376 int rankDiff = inferredType.getRank() - desiredResultRank;
2378 auto shape = inferredType.getShape();
2379 llvm::SmallBitVector dimsToProject =
2381 SmallVector<int64_t> projectedShape;
2383 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2384 if (!dimsToProject.test(pos))
2385 projectedShape.push_back(shape[pos]);
2387 RankedTensorType::get(projectedShape, inferredType.getElementType());
2389 return inferredType;
// OpFoldResult overload of the rank-reduced inference above.
2392RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2393 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2394 ArrayRef<OpFoldResult> sizes) {
2395 SmallVector<int64_t> staticSizes;
2396 SmallVector<Value> dynamicSizes;
2398 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2399 desiredResultRank, sourceRankedTensorType, staticSizes);
// NOTE(review): extraction dropped some original lines throughout this
// region; code kept byte-identical to the view.
// Primary ExtractSliceOp builder: decomposes mixed offsets/sizes/strides
// into static and dynamic parts, infers the result type when none is given,
// and forwards to the generated builder.
2404void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2405 RankedTensorType resultType, Value source,
2406 ArrayRef<OpFoldResult> offsets,
2407 ArrayRef<OpFoldResult> sizes,
2408 ArrayRef<OpFoldResult> strides,
2409 ArrayRef<NamedAttribute> attrs) {
2410 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2411 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2415 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
2418 resultType = llvm::cast<RankedTensorType>(
2419 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2421 result.addAttributes(attrs);
2422 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
2423 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2424 b.getDenseI64ArrayAttr(staticSizes),
2425 b.getDenseI64ArrayAttr(staticStrides));
// Convenience builder: result type is inferred (passed as a null type).
2430void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2431 ArrayRef<OpFoldResult> offsets,
2432 ArrayRef<OpFoldResult> sizes,
2433 ArrayRef<OpFoldResult> strides,
2434 ArrayRef<NamedAttribute> attrs) {
2435 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// Convenience builder from Range structs (offset/size/stride triples).
2440void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2441 ArrayRef<Range> ranges,
2442 ArrayRef<NamedAttribute> attrs) {
2444 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// ValueRange builder: wraps each Value operand as an OpFoldResult and
// delegates to the mixed builder.
2449void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2450 RankedTensorType resultType, Value source,
2452 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2453 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2454 offsets, [](Value v) -> OpFoldResult {
return v; });
2455 SmallVector<OpFoldResult> sizeValues =
2456 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
2457 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2458 strides, [](Value v) -> OpFoldResult {
return v; });
2459 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
// ValueRange builder with inferred result type.
2463void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2465 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2466 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// Emits the diagnostic matching a SliceVerificationResult failure kind
// (rank too large / size mismatch / element-type mismatch).
2471 RankedTensorType expectedType) {
2476 return op->
emitError(
"expected rank to be smaller or equal to ")
2477 <<
"the other rank. ";
2479 return op->
emitError(
"expected type to be ")
2480 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2482 return op->
emitError(
"expected element type to be ")
2483 << expectedType.getElementType();
2485 llvm_unreachable(
"unexpected extract_slice op verification result");
// Sizes-only builder: zero offsets and unit strides are synthesized.
2491void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2492 RankedTensorType resultType, Value source,
2493 ArrayRef<OpFoldResult> sizes,
2494 ArrayRef<NamedAttribute> attrs) {
2495 Attribute zeroIdxAttr =
b.getIndexAttr(0);
2496 Attribute oneIdxAttr =
b.getIndexAttr(1);
2497 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2498 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2499 build(
b,
result, resultType, source, readOffsets, sizes, readStrides, attrs);
// NOTE(review): extraction dropped some original lines; code kept
// byte-identical to the view.
// Verifies extract_slice: the result type must be the inferred type or a
// rank-reduced version, and offsets/sizes/strides must stay in bounds of
// the source shape.
2503LogicalResult ExtractSliceOp::verify() {
2504 RankedTensorType sourceType = getSourceType();
2507 RankedTensorType expectedType =
2508 ExtractSliceOp::inferResultType(sourceType,
getMixedSizes());
2516 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2517 getStaticStrides(),
true);
2519 return getOperation()->emitError(boundsResult.
errorMessage);
// Returns the set of source dims dropped by rank reduction (body truncated).
2524llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
// Inserts a rank-reducing extract_slice when `value`'s shape differs from
// the desired shape; returns the value unchanged when shapes already match.
2529ExtractSliceOp::rankReduceIfNeeded(OpBuilder &
b, Location loc, Value value,
2530 ArrayRef<int64_t> desiredShape) {
2531 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
2532 assert(sourceTensorType &&
"not a ranked tensor type");
2533 auto sourceShape = sourceTensorType.getShape();
2534 if (sourceShape.equals(desiredShape))
2536 auto maybeRankReductionMask =
2538 if (!maybeRankReductionMask)
2542 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
// Reifies the result shape from the mixed sizes, skipping dropped dims.
2545LogicalResult ExtractSliceOp::reifyResultShapes(
2547 reifiedReturnShapes.resize(1);
2548 reifiedReturnShapes[0].reserve(
getType().getRank());
2551 for (
const auto &size :
enumerate(mixedSizes)) {
2552 if (droppedDims.test(size.index()))
2554 reifiedReturnShapes[0].push_back(size.value());
// Folds extract_slice(cast) by slicing the cast's source directly, provided
// the slice stays in bounds of the pre-cast shape; the extracted result keeps
// the original slice type.
// NOTE(review): extraction dropped some original lines (guards, the final
// replace call, `return` statements); code kept byte-identical to the view.
2575class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
2577 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2579 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2580 PatternRewriter &rewriter)
const override {
// Bail if any operand can fold to a constant index first (other patterns
// should run before this one).
2582 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2583 return matchPattern(operand, matchConstantIndex());
2587 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2596 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2597 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2598 sliceOp.getStaticStrides());
2603 Location loc = sliceOp.getLoc();
2604 Value newResult = ExtractSliceOp::create(
2605 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2606 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2607 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2608 sliceOp.getStaticStrides());
// Recursively copies a strided sub-slice of a flattened constant's elements
// into outValues; `counts` holds the per-dim linearized strides of the
// source. Base case iterates the innermost dim directly.
2617template <
typename IterTy,
typename ElemTy>
2618static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2619 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2620 ArrayRef<int64_t> strides,
2621 llvm::SmallVectorImpl<ElemTy> *outValues) {
2622 assert(offsets.size() == sizes.size());
2623 assert(offsets.size() == strides.size());
2624 if (offsets.empty())
2627 int64_t offset = offsets.front();
2628 int64_t size = sizes.front();
2629 int64_t stride = strides.front();
2630 if (offsets.size() == 1) {
2631 for (int64_t i = 0; i < size; ++i, offset += stride)
2632 outValues->push_back(*(values + offset));
2637 for (int64_t i = 0; i < size; ++i, offset += stride) {
2638 auto begin = values + offset * counts.front();
2639 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2640 offsets.drop_front(), sizes.drop_front(),
2641 strides.drop_front(), outValues);
// Folds extract_slice of a dense constant into a smaller constant by
// materializing the sliced elements; gated by a user-supplied control
// function and restricted to fully static offsets/sizes/strides.
// NOTE(review): extraction dropped some original lines (the controlFn
// member/typedef, pattern guards, the final replacement); code kept
// byte-identical to the view.
2648class ConstantOpExtractSliceFolder final
2649 :
public OpRewritePattern<ExtractSliceOp> {
2651 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2653 ConstantOpExtractSliceFolder(MLIRContext *context,
2655 : OpRewritePattern<ExtractSliceOp>(context),
2656 controlFn(std::move(controlFn)) {}
2658 LogicalResult matchAndRewrite(ExtractSliceOp op,
2659 PatternRewriter &rewriter)
const override {
2660 DenseElementsAttr attr;
2669 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2670 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2671 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2678 int64_t count = sourceType.getNumElements();
// All of offsets/sizes/strides must be static to slice at compile time.
2683 auto offsets = op.getStaticOffsets();
2684 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2686 auto sizes = op.getStaticSizes();
2687 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2689 auto strides = op.getStaticStrides();
2690 if (llvm::is_contained(strides, ShapedType::kDynamic))
// Precompute per-dimension linearized strides for the recursive slicer.
2694 SmallVector<int64_t> counts;
2695 ArrayRef<int64_t> shape = sourceType.getShape();
2696 counts.reserve(shape.size());
2697 for (int64_t v : shape) {
2699 counts.push_back(count);
2703 SmallVector<Attribute> outValues;
2704 outValues.reserve(resultType.getNumElements());
2705 sliceElements(attr.
value_begin<Attribute>(), counts, offsets, sizes,
2706 strides, &outValues);
// Registration helper: adds the constant-slice folder with its control fn.
// (Function signature lines are missing from this extraction.)
2723 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
// Return-type canonicalizer for extract_slice: infers the unreduced type,
// then drops the rank-reduced dims to build the target shape.
// NOTE(review): surrounding signature lines and several statements are
// missing from this extraction; code kept byte-identical to the view.
2733 RankedTensorType nonReducedType =
2734 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2738 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2739 if (droppedDims.none())
2740 return nonReducedType;
2744 for (
auto i : llvm::seq<int64_t>(mixedSizes.size()))
2745 if (!droppedDims.test(i))
2746 targetShape.push_back(nonReducedType.getDimSize(i));
2748 return RankedTensorType::get(targetShape, nonReducedType.getEncoding(),
2749 nonReducedType.getEncoding());
// Canonicalizer callback: casts the new slice back to the original result
// type when the types diverge.
2756 ExtractSliceOp newOp) {
2759 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
// Registers extract_slice canonicalizations (constant-argument folding and
// cast folding).
2765void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2766 MLIRContext *context) {
2768 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2769 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2770 ExtractSliceOpCastFolder>(context);
// Helper comparing mixed sizes against a shaped type (fragment).
2776 ShapedType shapedType) {
2783 auto shape = shapedType.getShape();
2784 for (
auto it : llvm::zip(op.getMixedSizes(),
shape))
// Fold helper: extract_slice of a matching insert_slice yields the inserted
// source directly.
2798 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2801 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2802 insertOp.isSameAs(extractOp, isSame))
2803 return insertOp.getSource();
// Folds extract_slice: splat constants fold to reshaped splats; a slice
// whose type equals its source type (fragment: full-cover check missing)
// folds to the source.
2808OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2809 if (OpFoldResult reshapedSource = reshapeConstantSource(
2810 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2812 return reshapedSource;
2813 if (getSourceType() ==
getType() &&
2815 return this->getSource();
2819 return OpFoldResult();
// Builds a canonical rank-reducing extract_slice of `tensor` to targetType.
2824 auto rankedTensorType = llvm::cast<RankedTensorType>(
tensor.getType());
2825 unsigned rank = rankedTensorType.getRank();
2829 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType,
tensor,
2830 offsets, sizes, strides);
// NOTE(review): extraction dropped some original lines (including the
// primary build signature around original line 2845); code kept
// byte-identical to the view.
// Suggests "inserted_slice" as the printed SSA name for InsertSliceOp.
2837void InsertSliceOp::getAsmResultNames(
2839 setNameFn(getResult(),
"inserted_slice");
// Tail of the primary InsertSliceOp builder: attaches attributes and
// forwards static/dynamic offset/size/stride components; result type is the
// destination type.
2853 result.addAttributes(attrs);
2854 build(
b,
result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2855 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2856 b.getDenseI64ArrayAttr(staticSizes),
2857 b.getDenseI64ArrayAttr(staticStrides));
// Convenience builder from Range structs.
2862void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2863 Value dest, ArrayRef<Range> ranges,
2864 ArrayRef<NamedAttribute> attrs) {
2866 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
// ValueRange builder: wraps Value operands as OpFoldResults.
2870void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2872 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2873 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2874 offsets, [](Value v) -> OpFoldResult {
return v; });
2875 SmallVector<OpFoldResult> sizeValues =
2876 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
2877 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2878 strides, [](Value v) -> OpFoldResult {
return v; });
2879 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
// NOTE(review): extraction dropped some original lines; code kept
// byte-identical to the view.
// Shared insert_slice verification helper: the source type must equal (or be
// a rank-reduced version of) the type inferred from the destination and the
// static sizes.
2885 RankedTensorType srcType, RankedTensorType dstType,
2890 RankedTensorType expected =
2891 ExtractSliceOp::inferResultType(dstType, staticSizes);
2893 *expectedType = expected;
// Verifies insert_slice: type compatibility plus in-bounds
// offsets/sizes/strides against the destination shape.
2898LogicalResult InsertSliceOp::verify() {
2900 RankedTensorType expectedType;
2903 getStaticSizes(), getStaticStrides(), &expectedType);
2910 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
2911 getStaticStrides(),
true);
2913 return getOperation()->emitError(boundsResult.
errorMessage);
// Fold helper: an insert over an identical earlier insert can reuse the
// earlier insert's destination (the overwritten values are dead).
2936 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2939 if (!prevInsertOp ||
2940 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2941 !prevInsertOp.isSameAs(insertOp, isSame))
2944 insertOp.getDestMutable().assign(prevInsertOp.getDest());
// Fold helper: inserting a slice extracted from the same destination at the
// same coordinates is a no-op — fold to the destination.
2956 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2959 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2960 !extractOp.isSameAs(insertOp, isSame))
2963 return extractOp.getSource();
// Folds insert_slice when source and destination types are identical and
// static (full-cover check truncated in this extraction).
2966OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2967 if (getSourceType().hasStaticShape() &&
getType().hasStaticShape() &&
2968 getSourceType() ==
getType() &&
2970 return this->getSource();
2977 return OpFoldResult();
// Reifies the result shape (same as the destination's shape).
2980LogicalResult InsertSliceOp::reifyResultShapes(
2982 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
// NOTE(review): extraction dropped some original lines throughout this
// region; code kept byte-identical to the view.
// Folds constant offsets/sizes/strides of insert_slice (and its parallel
// variant), casting the source to the canonical rank-reduced type if needed.
2991template <
typename InsertOpTy>
2992class InsertSliceOpConstantArgumentFolder final
2993 :
public OpRewritePattern<InsertOpTy> {
2995 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2997 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2998 PatternRewriter &rewriter)
const override {
2999 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
3000 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3001 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// Only rewrite when the folded operands still verify in-bounds.
3010 SliceBoundsVerificationResult sliceResult =
3012 mixedOffsets, mixedSizes, mixedStrides);
3017 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3018 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3020 Value toInsert = insertSliceOp.getSource();
3021 if (sourceType != insertSliceOp.getSourceType()) {
3022 OpBuilder::InsertionGuard g(rewriter);
// Inside an in_parallel region the cast must be created outside it.
3026 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3028 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3029 sourceType, toInsert);
3032 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3033 mixedSizes, mixedStrides);
// Folds tensor.cast ops feeding the source/destination of insert_slice,
// re-deriving static sizes from the cast's source where rank reduction
// allows, and casting the result back when the destination type changed.
3058template <
typename InsertOpTy>
3059struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
3060 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3062 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3063 PatternRewriter &rewriter)
const override {
3064 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3065 return matchPattern(operand, matchConstantIndex());
3069 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3072 return std::nullopt;
3073 return castOp.getSource();
3075 std::optional<Value> sourceCastSource =
3076 getSourceOfCastOp(insertSliceOp.getSource());
3077 std::optional<Value> destCastSource =
3078 getSourceOfCastOp(insertSliceOp.getDest());
3079 if (!sourceCastSource && !destCastSource)
3083 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3084 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3085 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
3086 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3087 if (!srcType || !dstType)
3093 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3095 staticSizes, srcType.getShape(),
true);
3096 if (!rankReductionMask.has_value())
// Refine dynamic sizes with static dims known from the cast's source.
3103 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3104 int64_t rankReducedIdx = 0;
3105 for (
auto [idx, size] :
enumerate(staticSizes)) {
3106 if (!rankReductionMask.value().contains(idx) &&
3107 !srcType.isDynamicDim(rankReducedIdx)) {
3109 rewriter.
getContext(), srcType.getDimSize(rankReducedIdx));
3110 size = srcType.getDimSize(rankReducedIdx++);
3116 staticSizes, insertSliceOp.getStaticStrides()) !=
3117 SliceVerificationResult::Success)
3119 SliceBoundsVerificationResult sliceResult =
3121 mixedSizes, insertSliceOp.getMixedStrides());
3126 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3127 insertSliceOp.getMixedOffsets(), mixedSizes,
3128 insertSliceOp.getMixedStrides());
// Parallel inserts have no results, so no result cast is needed there.
3131 bool isParallelInsert =
3132 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3133 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3134 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3135 insertSliceOp.getDestType(),
// Inserts a cast on the insert_slice source when constant size operands
// prove some of its dynamic dims are actually static, tightening the source
// type.
// NOTE(review): extraction dropped some original lines (guards and the
// constant-size lookup); code kept byte-identical to the view.
3164template <
typename InsertOpTy>
3165struct InsertSliceOpSourceCastInserter final
3166 :
public OpRewritePattern<InsertOpTy> {
3167 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3169 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3170 PatternRewriter &rewriter)
const override {
3171 RankedTensorType srcType = insertSliceOp.getSourceType();
3172 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3174 SmallVector<int64_t> newSrcShape(srcType.getShape());
3175 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3176 if (std::optional<int64_t> constInt =
3181 newSrcShape[i] = *constInt;
3187 RankedTensorType newSrcType = RankedTensorType::get(
3188 newSrcShape, srcType.getElementType(), srcType.getEncoding());
// Nothing to tighten, or the cast would be invalid: leave the op alone.
3189 if (srcType == newSrcType ||
3191 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3199 OpBuilder::InsertionGuard g(rewriter);
// Inside a parallel-combining region the cast must be created outside it.
3203 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3205 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3206 newSrcType, insertSliceOp.getSource());
3208 insertSliceOp, cast, insertSliceOp.getDest(),
3209 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3210 insertSliceOp.getMixedStrides());
// Returns dims dropped by rank-reducing insertion (body truncated).
3216llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
// Registers the three insert_slice canonicalization patterns above.
3220void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3221 MLIRContext *context) {
3222 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3223 InsertSliceOpCastFolder<InsertSliceOp>,
3224 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
// Builds a canonical rank-reducing insert_slice of `tensor` into `dest`.
// NOTE(review): the signature and some statements are missing from this
// extraction; code kept byte-identical to the view.
3231 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
3232 unsigned rank = rankedTensorType.getRank();
3236 return b.createOrFold<tensor::InsertSliceOp>(loc,
tensor, dest, offsets,
// Suggests "padded" as the printed SSA name for PadOp results.
3245 setNameFn(getResult(),
"padded");
// Verifies tensor.pad: the result type must match the type inferred from the
// source and the static low/high padding (rank and per-dim static sizes).
3248LogicalResult PadOp::verify() {
3249 auto sourceType = llvm::cast<RankedTensorType>(getSource().
getType());
3250 auto resultType = llvm::cast<RankedTensorType>(getResult().
getType());
3252 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3253 if (!expectedType) {
3254 return emitError(
"failed to infer expectedType from sourceType ")
3255 << sourceType <<
", specified resultType is " << resultType;
3257 if (resultType.getRank() != expectedType.getRank()) {
3259 << resultType <<
" does not match the inferred type "
3262 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
3263 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3265 if (expectedType.isDynamicDim(i))
3268 << resultType <<
" does not match the inferred type "
3275LogicalResult PadOp::verifyRegions() {
3276 auto ®ion = getRegion();
3277 unsigned rank = llvm::cast<RankedTensorType>(getResult().
getType()).getRank();
3278 Block &block = region.front();
3280 return emitError(
"expected the block to have ") << rank <<
" arguments";
3284 if (!en.value().isIndex())
3286 << (en.index() + 1) <<
" to be an index";
3291 if (yieldOp.getValue().getType() !=
3293 return emitOpError(
"expected yield type to match shape element type");
3298RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3299 ArrayRef<int64_t> staticLow,
3300 ArrayRef<int64_t> staticHigh,
3301 ArrayRef<int64_t> resultShape) {
3302 unsigned rank = sourceType.getRank();
3303 if (staticLow.size() != rank)
3304 return RankedTensorType();
3305 if (staticHigh.size() != rank)
3306 return RankedTensorType();
3307 if (!resultShape.empty() && resultShape.size() != rank)
3308 return RankedTensorType();
3310 SmallVector<int64_t, 4> inferredShape;
3311 for (
auto i : llvm::seq<unsigned>(0, rank)) {
3312 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3313 staticHigh[i] == ShapedType::kDynamic) {
3314 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3317 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3318 assert((resultShape.empty() || size == resultShape[i] ||
3319 resultShape[i] == ShapedType::kDynamic) &&
3320 "mismatch between inferred shape and result shape");
3321 inferredShape.push_back(size);
3325 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3328void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3329 Value source, ArrayRef<int64_t> staticLow,
3331 bool nofold, ArrayRef<NamedAttribute> attrs) {
3332 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3334 resultType = inferResultType(sourceType, staticLow, staticHigh);
3335 result.addAttributes(attrs);
3336 build(
b,
result, resultType, source, low, high,
3337 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3338 nofold ?
b.getUnitAttr() : UnitAttr());
3341void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3343 ArrayRef<NamedAttribute> attrs) {
3344 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3345 unsigned rank = sourceType.getRank();
3346 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3347 build(
b,
result, resultType, source, staticVector, staticVector, low, high,
3351void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3352 Value source, ArrayRef<OpFoldResult> low,
3353 ArrayRef<OpFoldResult> high,
bool nofold,
3354 ArrayRef<NamedAttribute> attrs) {
3355 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3356 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3357 SmallVector<int64_t, 4> staticLow, staticHigh;
3365 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3367 assert(llvm::isa<RankedTensorType>(resultType));
3368 result.addAttributes(attrs);
3369 build(
b,
result, resultType, source, dynamicLow, dynamicHigh,
3370 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3371 nofold ?
b.getUnitAttr() : UnitAttr());
3374void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3375 Value source, ArrayRef<OpFoldResult> low,
3376 ArrayRef<OpFoldResult> high, Value constantPadValue,
3377 bool nofold, ArrayRef<NamedAttribute> attrs) {
3378 build(
b,
result, resultType, source, low, high, nofold, attrs);
3381 Region *region =
result.regions[0].get();
3382 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3383 Repeated<Type> blockArgTypes(sourceRank,
b.getIndexType());
3384 SmallVector<Location> blockArgLocs(sourceRank,
result.location);
3388 OpBuilder::InsertionGuard guard(
b);
3389 b.createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3390 tensor::YieldOp::create(
b,
result.location, constantPadValue);
3393llvm::SmallBitVector PadOp::getPaddedDims() {
3394 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3395 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3396 for (
const auto &en :
enumerate(paddingWidths))
3398 paddedDims.set(en.index());
3400 extractPaddedDims(getMixedLowPad());
3401 extractPaddedDims(getMixedHighPad());
3408struct FoldStaticZeroPadding :
public OpRewritePattern<PadOp> {
3409 using OpRewritePattern<PadOp>::OpRewritePattern;
3411 LogicalResult matchAndRewrite(PadOp padTensorOp,
3412 PatternRewriter &rewriter)
const override {
3413 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3415 if (padTensorOp.getNofold())
3418 padTensorOp, padTensorOp.getResult().
getType(),
3419 padTensorOp.getSource());
3425struct FoldSourceTensorCast :
public OpRewritePattern<PadOp> {
3426 using OpRewritePattern<PadOp>::OpRewritePattern;
3428 LogicalResult matchAndRewrite(PadOp padTensorOp,
3429 PatternRewriter &rewriter)
const override {
3430 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3434 auto newResultType = PadOp::inferResultType(
3435 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3436 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3437 padTensorOp.getResultType().getShape());
3439 if (newResultType == padTensorOp.getResultType()) {
3441 padTensorOp.getSourceMutable().assign(castOp.getSource());
3444 auto newOp = PadOp::create(
3445 rewriter, padTensorOp->getLoc(), newResultType,
3446 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3447 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3448 padTensorOp.getHigh(), padTensorOp.getNofold(),
3451 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3454 padTensorOp, padTensorOp.getResultType(), newOp);
3462struct FoldTargetTensorCast :
public OpRewritePattern<PadOp> {
3463 using OpRewritePattern<PadOp>::OpRewritePattern;
3465 LogicalResult matchAndRewrite(PadOp padTensorOp,
3466 PatternRewriter &rewriter)
const override {
3467 if (!padTensorOp.getResult().hasOneUse())
3470 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3474 tensorCastOp.getDest().getType()))
3477 auto replacementOp = PadOp::create(
3478 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3479 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3480 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3481 padTensorOp.getHigh(), padTensorOp.getNofold(),
3483 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3485 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
3486 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3526struct FoldOrthogonalPaddings :
public OpRewritePattern<PadOp> {
3527 using OpRewritePattern<PadOp>::OpRewritePattern;
3529 LogicalResult matchAndRewrite(PadOp padOp,
3530 PatternRewriter &rewriter)
const override {
3531 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3534 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3535 if (!outerPadOp || outerPadOp.getNofold())
3537 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3542 int64_t rank = padOp.getSourceType().getRank();
3543 if (outerSliceOp.getSourceType().getRank() != rank) {
3545 "cannot fold rank-reducing chain");
3549 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3551 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3555 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3557 "cannot fold PadOps with low padding");
3561 Attribute innerAttr, outerAttr;
3562 Value innerValue = padOp.getConstantPaddingValue();
3563 Value outerValue = outerPadOp.getConstantPaddingValue();
3564 if (!innerValue || !outerValue ||
3567 innerAttr != outerAttr) {
3569 padOp,
"cannot fold PadOps with different padding values");
3573 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3574 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3575 if (innerDims.anyCommon(outerDims)) {
3577 padOp,
"cannot fold PadOps with common padding dimensions");
3585 SmallVector<OpFoldResult> newOffsets(rank, rewriter.
getIndexAttr(0));
3587 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3588 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3589 if (!innerDims.test(en.index()) &&
3591 en.value() = outerOffset;
3594 if (!outerDims.test(en.index()) &&
3596 en.value() = innerOffset;
3600 padOp,
"cannot find zero-offset and zero-padding pair");
3608 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3610 if (!outerDims.test(en.index()))
3612 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3613 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3614 assert(ShapedType::isStatic(sourceSize) &&
3615 "expected padded dimension to have a static size");
3618 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3619 "match the size of the outer padding");
3621 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3625 SmallVector<OpFoldResult> newHighPad(rank, rewriter.
getIndexAttr(0));
3627 if (innerDims.test(en.index()))
3628 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3629 if (outerDims.test(en.index()))
3630 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3635 auto newSliceOp = ExtractSliceOp::create(
3636 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3637 newSizes, innerSliceOp.getMixedStrides());
3638 auto newPadOp = PadOp::create(
3639 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3640 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3643 newPadOp.getRegion().begin());
3644 rewriter.
replaceOp(padOp, newPadOp.getResult());
3649struct FoldStaticPadding :
public OpRewritePattern<PadOp> {
3650 using OpRewritePattern<PadOp>::OpRewritePattern;
3652 LogicalResult matchAndRewrite(PadOp padTensorOp,
3653 PatternRewriter &rewriter)
const override {
3654 Value input = padTensorOp.getSource();
3655 if (!llvm::isa<RankedTensorType>(input.
getType()))
3657 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
3658 auto inputRank = inputDims.size();
3660 auto oldResultType =
3661 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3665 auto outputDims = oldResultType.getShape();
3668 SmallVector<int64_t> constOperandsLow;
3669 SmallVector<Value> newLows;
3670 for (
auto operand : padTensorOp.getLow()) {
3673 constOperandsLow.push_back(ShapedType::kDynamic);
3674 newLows.push_back(operand);
3677 constOperandsLow.push_back(intOp.getExtValue());
3679 SmallVector<int64_t> constOperandsHigh;
3680 SmallVector<Value> newHighs;
3681 for (
auto operand : padTensorOp.getHigh()) {
3684 constOperandsHigh.push_back(ShapedType::kDynamic);
3685 newHighs.push_back(operand);
3688 constOperandsHigh.push_back(intOp.getExtValue());
3691 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3692 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3695 if (inputDims.size() != outputDims.size() ||
3696 inputDims.size() != constLow.size() ||
3697 inputDims.size() != constHigh.size())
3702 for (
size_t i = 0; i < inputRank; i++) {
3703 if (constLow[i] == ShapedType::kDynamic)
3704 constLow[i] = constOperandsLow[lowCount++];
3705 if (constHigh[i] == ShapedType::kDynamic)
3706 constHigh[i] = constOperandsHigh[highCount++];
3709 auto staticLow = ArrayRef<int64_t>(constLow);
3710 auto staticHigh = ArrayRef<int64_t>(constHigh);
3713 SmallVector<int64_t> newOutDims;
3714 for (
size_t i = 0; i < inputRank; i++) {
3715 if (outputDims[i] == ShapedType::kDynamic) {
3716 newOutDims.push_back(
3717 (staticLow[i] == ShapedType::kDynamic ||
3718 staticHigh[i] == ShapedType::kDynamic ||
3719 inputDims[i] == ShapedType::kDynamic
3720 ? ShapedType::kDynamic
3721 : inputDims[i] + staticLow[i] + staticHigh[i]));
3723 newOutDims.push_back(outputDims[i]);
3727 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3728 llvm::all_of(newOutDims,
3729 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
3733 auto newResultType = RankedTensorType::get(
3734 newOutDims, padTensorOp.getType().getElementType());
3735 auto newOp = PadOp::create(
3736 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3737 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3741 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3769struct FoldConsecutiveConstantPadding :
public OpRewritePattern<tensor::PadOp> {
3770 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3772 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3773 PatternRewriter &rewriter)
const override {
3774 if (padOp.getNofold()) {
3778 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3779 if (!producerPad || producerPad.getNofold()) {
3781 padOp,
"producer is not a foldable tensor.pad op");
3785 Value consumerPadValue = padOp.getConstantPaddingValue();
3786 Value producerPadValue = producerPad.getConstantPaddingValue();
3787 if (!consumerPadValue || !producerPadValue ||
3788 consumerPadValue != producerPadValue) {
3791 "cannot fold PadOps with different or non-constant padding values");
3794 Location loc = padOp.getLoc();
3799 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3800 ArrayRef<OpFoldResult> producerPaddings) {
3801 SmallVector<OpFoldResult> sumPaddings;
3802 for (
auto [consumerIndex, producerIndex] :
3803 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3805 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3810 SmallVector<OpFoldResult> newHighPad =
3811 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3812 SmallVector<OpFoldResult> newLowPad =
3813 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3815 auto newPadOp = tensor::PadOp::create(
3816 rewriter, padOp.getLoc(), padOp.getResultType(),
3817 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3820 newPadOp.getRegion().begin());
3821 rewriter.
replaceOp(padOp, newPadOp.getResult());
3829PadOp::reifyResultShapes(OpBuilder &
b,
3831 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
3832 SmallVector<OpFoldResult> lp = getMixedLowPad();
3833 SmallVector<OpFoldResult> hp = getMixedHighPad();
3834 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3835 if (!
getType().isDynamicDim(i)) {
3836 reifiedReturnShapes[0][i] =
b.getIndexAttr(
getType().getDimSize(i));
3839 Location loc = getLoc();
3840 Value dim =
b.createOrFold<tensor::DimOp>(
3843 AffineExpr d0, d1, d2;
3846 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3851void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3852 MLIRContext *context) {
3853 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3854 FoldOrthogonalPaddings, FoldStaticPadding,
3855 FoldConsecutiveConstantPadding>(context);
3867Value PadOp::getConstantPaddingValue() {
3868 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3871 Value padValue = yieldOp.getValue();
3882OpFoldResult PadOp::fold(FoldAdaptor) {
3883 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3893OpResult ParallelInsertSliceOp::getTiedOpResult() {
3894 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3895 for (
const auto &it :
3896 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3897 Operation &nextOp = it.value();
3898 if (&nextOp == getOperation())
3899 return parallelCombiningParent.getParentResult(it.index());
3901 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3905void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3906 Value source, Value dest,
3907 ArrayRef<OpFoldResult> offsets,
3908 ArrayRef<OpFoldResult> sizes,
3909 ArrayRef<OpFoldResult> strides,
3910 ArrayRef<NamedAttribute> attrs) {
3911 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3912 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3916 result.addAttributes(attrs);
3917 build(
b,
result, {}, source, dest, dynamicOffsets, dynamicSizes,
3918 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3919 b.getDenseI64ArrayAttr(staticSizes),
3920 b.getDenseI64ArrayAttr(staticStrides));
3925void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3926 Value source, Value dest,
3927 ArrayRef<Range> ranges,
3928 ArrayRef<NamedAttribute> attrs) {
3930 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
3934void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3935 Value source, Value dest,
ValueRange offsets,
3937 ArrayRef<NamedAttribute> attrs) {
3938 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3939 offsets, [](Value v) -> OpFoldResult {
return v; });
3940 SmallVector<OpFoldResult> sizeValues =
3941 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
3942 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3943 strides, [](Value v) -> OpFoldResult {
return v; });
3944 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
3949void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
3950 Value dest, ArrayRef<OpFoldResult> sizes,
3951 ArrayRef<NamedAttribute> attrs) {
3952 Attribute zeroIdxAttr =
b.getIndexAttr(0);
3953 Attribute oneIdxAttr =
b.getIndexAttr(1);
3954 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3955 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3956 build(
b,
result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3959LogicalResult ParallelInsertSliceOp::verify() {
3960 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3961 return this->
emitError(
"expected InParallelOpInterface parent, got:")
3962 << *(getOperation()->getParentOp());
3965 RankedTensorType expectedType;
3968 getStaticSizes(), getStaticStrides(), &expectedType);
3975 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
3976 getStaticStrides(),
true);
3978 return getOperation()->emitError(boundsResult.
errorMessage);
3983void ParallelInsertSliceOp::getCanonicalizationPatterns(
3984 RewritePatternSet &results, MLIRContext *context) {
3985 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3986 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3987 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3990llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3995MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3996 return getDestMutable();
3999Operation *ParallelInsertSliceOp::getIteratingParent() {
4001 if (
auto combiningOp =
4002 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4003 return combiningOp->getParentOp();
4011void ScatterOp::getAsmResultNames(
4013 setNameFn(getResult(),
"scatter");
4016LogicalResult ScatterOp::verify() {
4017 int64_t destRank = getDestType().getRank();
4018 ArrayRef<int64_t> scatterDims = getScatterDims();
4020 getIndicesType().
getShape(), destRank,
4021 "scatter",
"dest")))
4025 return emitOpError(
"requires 'unique' attribute to be set");
4032 RankedTensorType expectedSourceType = GatherOp::inferResultType(
4033 getDestType(), getIndicesType(), scatterDims,
false);
4034 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
4035 getDestType(), getIndicesType(), scatterDims,
true);
4036 if (getSourceType() != expectedSourceType &&
4037 getSourceType() != expectedRankReducedSourceType) {
4041 << expectedSourceType <<
" or its rank-reduced variant "
4042 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
4053void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4054 Type aggregateType,
ValueRange dynamicSizes) {
4055 build(builder,
result, aggregateType, element, dynamicSizes);
4058void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4059 ArrayRef<int64_t> staticShape,
ValueRange dynamicSizes) {
4060 auto aggregateType = RankedTensorType::get(staticShape, element.
getType());
4061 build(builder,
result, aggregateType, element, dynamicSizes);
4064void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4065 ArrayRef<OpFoldResult> sizes) {
4066 SmallVector<int64_t> staticShape;
4067 SmallVector<Value> dynamicSizes;
4069 build(builder,
result, element, staticShape, dynamicSizes);
4072void SplatOp::getAsmResultNames(
4074 setNameFn(getResult(),
"splat");
4077LogicalResult SplatOp::verify() {
4083SplatOp::reifyResultShapes(OpBuilder &builder,
4085 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
4087 for (int64_t i = 0; i <
getType().getRank(); ++i) {
4088 if (
getType().isDynamicDim(i)) {
4097OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4098 auto constOperand = adaptor.getInput();
4099 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4103 if (!
getType().hasStaticShape())
4118 if (isa<InsertSliceOp>(op.getOperation()) ||
4119 isa<LoopLikeOpInterface>(op.getOperation()))
4152 isa<linalg::RelayoutOpInterface>(*op))
4160 auto newOp =
clone(rewriter, op, newResultTypes, newOperands);
4163 replacements.reserve(newOp->getNumResults());
4164 for (
auto [oldResult, newResult] :
4165 llvm::zip(op->getResults(), newOp->getResults())) {
4166 if (newResult.getType() != oldResult.getType()) {
4167 replacements.push_back(tensor::CastOp::create(
4168 rewriter, op->getLoc(), oldResult.
getType(), newResult));
4170 replacements.push_back(newResult);
4183void TensorDialect::getCanonicalizationPatterns(
4184 RewritePatternSet &results)
const {
4192#define GET_OP_CLASSES
4193#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. numDims - 1].
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.