#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  // ...
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");
  // ...
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  // ...
  return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {
    // ...
  }
  for (int64_t sz : tensorType.getShape())
    mixedSizes.push_back(b.getIndexAttr(sz));
  // ...
      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
    if (llvm::isa<TensorType>(opResult.getType())) {
      // ...
      if (failed(destination))
        return failure();
      result.push_back(*destination);
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
                                           ArrayRef<OpFoldResult> mixedSizes) {
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
    // ...
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
    // ...
    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      // ...
    }
    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      // ...
    }
    // Otherwise: the dim is dropped.
    droppedDims.set(idx);
  }
  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // ...
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      // ...
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }
  // ...
  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
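// Illustrative folding (assuming %c4 is an arith.constant 4 : index): a
// dynamic dimension whose size operand is a non-negative constant is promoted
// into the type, e.g.
//   %0 = tensor.empty(%c4) : tensor<?x3xf32>
// can be rewritten to use the refined type
//   %0 = tensor.empty() : tensor<4x3xf32>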
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  // ...
  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};
// ...
  results.add<ChainedTensorBitcast>(context);
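// Illustrative example (types chosen arbitrarily): the pattern above rewrites
// a chain of bitcasts such as
//   %a = tensor.bitcast %t : tensor<4xi32> to tensor<4xui32>
//   %b = tensor.bitcast %a : tensor<4xui32> to tensor<4xsi32>
// into a single bitcast from the chain's original source:
//   %b = tensor.bitcast %t : tensor<4xi32> to tensor<4xsi32>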
  setNameFn(getResult(), "cast");
}

// ...
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
  // ...
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  // ...
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
}
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
// ...
  newOperands.reserve(op->getNumOperands());
  // ...
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    // ...
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
// ...
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && tensor::canFoldIntoConsumerOp(castOp))
      operand.set(castOp.getOperand());
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  // ...
  if (aT.getElementType() != bT.getElementType())
    return false;
  if (rank != two.getRank())
    return {};
  // ...
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());
    // ...
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
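// Illustrative example: two consecutive casts fold when doing so loses no
// static information, e.g.
//   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
//   %3 = tensor.cast %2 : tensor<?x?xf32> to tensor<4x?xf32>
// folds into the single cast
//   %3 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>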
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
    // ...
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();
    // ...
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    // ...
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    // ...
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
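// Illustrative example: a cast that only adds static size information to an
// extract_slice result is absorbed into the slice's sizes, e.g.
//   %1 = tensor.extract_slice %0[0] [%sz] [1] : tensor<128xf32> to tensor<?xf32>
//   %2 = tensor.cast %1 : tensor<?xf32> to tensor<32xf32>
// becomes
//   %2 = tensor.extract_slice %0[0] [32] [1] : tensor<128xf32> to tensor<32xf32>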
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes = llvm::map_to_vector<4>(
      inputTypes, [](Type type) { return llvm::cast<RankedTensorType>(type); });
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  // ...
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    // ...
    for (auto tensorType : tensorTypes) {
      // ...
    }
  }
  // ...
  for (auto tensorType : tensorTypes)
    /*...*/;
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}

// ...
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
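// Illustrative inference: concatenating tensor<3x4xf32> and tensor<3x5xf32>
// along dim 1 yields tensor<3x9xf32>. If any input's concatenated extent is
// dynamic, the summed extent saturates to dynamic, while a non-concatenated
// dimension may still combine a static and a dynamic extent into a static one.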
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType, 4> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  // ...
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    // ...
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize = /*...*/;
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      // ...
    }
  }
  // ...
  for (auto tensorType : inputTypes)
    /*...*/;
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType = /*...*/;

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }
  return success();
}
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  // ...
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    // ...
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    // ...
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    // ...
    inputShapes.emplace_back(std::move(inputShape));
  }
  // ...
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // ...
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
  // ...
  Value init = inputs[0];
  // ...
  for (int64_t i = 0; i < rank; ++i) {
    // ...
    if (!getType().isDynamicDim(i)) {
      // ...
    } else if (!inferredResultType.isDynamicDim(i)) {
      // ...
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // ...
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      // ...
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    // ...
    reifiedReturnShapes[0][dim] =
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

// ...
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getType(0),
                                        concatOp.getInputs()[0]);
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
    // ...
    LogicalResult matched = failure();
    // Keep the inferred shape for each operand, overriding the extent along
    // the concatenated dim with the operand's own extent.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // ...
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());
      // ...
          CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                         concatOp.getOperand(operandIdx));
      // ...
        concatOp->setOperand(operandIdx, castOp->getResult(0));
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
    // ...
        concatOp.getResultType())) {
    // ...
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());
// ...
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
  setNameFn(getResult(), "dim");
}

// ...
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {
  // ...
}

// ...
  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;
// ...
  setResultRange(getResult(), /*...*/);
// ...

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  // ...
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  // ...
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    // ...
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non rank-reduced ops. For the rank-reduced version, rely
    // on `resolve-shaped-type-result-dims` to fold.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
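// Illustrative folds: a tensor.dim over a static dimension becomes a constant
// index, and a tensor.dim over a dynamic tensor.generate result forwards the
// matching extent operand, e.g.
//   %t = tensor.generate %e0, %e1 { ... } : tensor<?x3x?xf32>
//   %d = tensor.dim %t, %c2 : tensor<?x3x?xf32>
// folds %d to %e1, the second dynamic extent.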
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    // ...
    Value newSource = castOp.getOperand();
// ...
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    // ...
    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);
    // ...
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
// ...
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    // ...
    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
// ...
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  // ...

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  // ...
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  // ...
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      // ...

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  // ...
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      // ...

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  // ...
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
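// Illustrative rewrite (assuming %c10 is an arith.constant 10 : index):
//   %0 = tensor.empty(%c10, %sz) : tensor<?x?xf32>
// canonicalizes to
//   %e = tensor.empty(%sz) : tensor<10x?xf32>
//   %0 = tensor.cast %e : tensor<10x?xf32> to tensor<?x?xf32>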
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // ...
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    // ...
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // ...
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the tensor empty dim is
      // dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: Both sides are dynamic; keep the existing size.
      newMixedSizes.push_back(currDim);
    }
    // ...
        resultType.getElementType());
    return success();
  }
};

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    // ...
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    // ...
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
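// Illustrative rewrite (shapes chosen arbitrarily): extracting from a
// collapsed tensor delinearizes the collapsed index back into the source
// basis, e.g.
//   %c = tensor.collapse_shape %src [[0, 1]] : tensor<4x5xf32> into tensor<20xf32>
//   %v = tensor.extract %c[%i] : tensor<20xf32>
// becomes, conceptually, a delinearization of %i with basis (4, 5) followed by
//   %v = tensor.extract %src[%i0, %i1] : tensor<4x5xf32>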
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}

// ...
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) { return a == b; };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // ...
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    // Compute the flattened (row-major) index into the elements list.
    int64_t flatIndex = 0;
    int64_t stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
  // ...
  return {};
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

// ...
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    // ...
    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
    // ...

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

// ...
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
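// Illustrative inference (shapes chosen arbitrarily): with source
// tensor<4x5x6xf32>, indices tensor<2x3x1xindex>, and gather_dims = [1], the
// leading result dims come from the indices batch shape (2x3); dim 1 of the
// source is gathered (kept as 1, or dropped when rankReduced) and dims 0 and
// 2 are preserved, giving tensor<2x3x4x1x6xf32> (or tensor<2x3x4x6xf32> in
// the rank-reduced form).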
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  // ...
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");
  return success();
}

void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *body =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    // ...
    Block *body = &tensorFromElements.getBody().front();
    // ...
      rewriter.clone(op, mapping);
    // ...

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

// ...
  for (auto dim : type.getShape())
    // ...

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  // ...
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      // ...
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of operand 'source' is another 'tensor.reshape' op, use
  // the producer's input instead as the original tensor to reshape.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // A reshape between identical 1-D (or lower) types is a no-op regardless of
  // the shape operand.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }
    if (dynamicNoop)
      return source;
  }
  return {};
}
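// Illustrative no-op detection (names chosen arbitrarily): a reshape whose
// shape operand is built from `tensor.dim %src, %i` values, or constants equal
// to the static extents, in dimension order reproduces the source shape, e.g.
//   %d0 = tensor.dim %src, %c0 : tensor<?x8xf32>
//   %sh = tensor.from_elements %d0, %c8 : tensor<2xindex>
//   %r  = tensor.reshape %src(%sh)
//           : (tensor<?x8xf32>, tensor<2xindex>) -> tensor<?x8xf32>
// folds %r to %src.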
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // ...
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
      getSymbolLessAffineMaps(
          convertReassociationIndicesToExprs(b.getContext(), reassociation)));
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}
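// Illustrative inference: collapsing tensor<2x3x4xf32> with reassociation
// [[0, 1], [2]] multiplies each band's extents, giving tensor<6x4xf32>; any
// dynamic extent inside a band makes that collapsed extent dynamic.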
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // ...
  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  // ...
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    // ...
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};

// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x :
// res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};

template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    // ...
    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      // ...
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // ...
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values; nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
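// Illustrative rewrite (assuming %c2 is an arith.constant 2 : index): when
// the expand's source comes through a cast from a more static type,
//   %c = tensor.cast %src : tensor<4xf32> to tensor<?xf32>
//   %e = tensor.expand_shape %c [[0, 1]] output_shape [%c2, 2]
//          : tensor<?xf32> into tensor<?x2xf32>
// becomes an expand of %src to tensor<2x2xf32> followed by a cast back to
// tensor<?x2xf32>.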
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

// ...
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  // ...
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  // ...
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> sizes) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, sizes));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best effort rank-reducing: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  // ...
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}
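// Illustrative inference: a slice with sizes [1, 4, 1, 8] and
// desiredResultRank = 2 projects away leading unit dims best-effort, giving
// tensor<4x8xf32> instead of the unreduced tensor<1x4x1x8xf32>.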
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  // ...
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// ...
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

// ...
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

// ...
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// ...
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
}
LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();
  // ...
  RankedTensorType expectedType =
      ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
  // ...
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  return success();
}

llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  return ::getDroppedDims(getType().getShape(), getMixedSizes());
}

FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  // ...
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
  return success();
}
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    // ...
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    // ...
    Location loc = sliceOp.getLoc();
    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    // ...
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();
    // ...
    int64_t count = sourceType.getNumElements();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      // ...
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      // ...
    }
    // ...
  }
};

// ...
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedSizes);
// ...
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);
    rewriter.replaceOp(op, replacement);
  }

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  // ...
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    // ...
}

// ...
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
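// Illustrative fold: an extract_slice that reads back exactly what a matching
// insert_slice wrote is replaced by the inserted value, e.g.
//   %1 = tensor.insert_slice %v into %t[0, 0] [2, 4] [1, 1]
//   %2 = tensor.extract_slice %1[0, 0] [2, 4] [1, 1]
// folds %2 to %v.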
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  // ...
  return OpFoldResult();
}

Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

// ...
  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
static SliceVerificationResult
verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType,
                    ArrayRef<int64_t> staticOffsets,
                    ArrayRef<int64_t> staticSizes,
                    ArrayRef<int64_t> staticStrides,
                    RankedTensorType *expectedType = nullptr) {
  // ...
  RankedTensorType expected =
      ExtractSliceOp::inferResultType(dstType, staticSizes);
  if (expectedType)
    *expectedType = expected;
  // ...
}

LogicalResult InsertSliceOp::verify() {
  // ...
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  // ...
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
  return success();
}
2894 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2897 if (!prevInsertOp ||
2898 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2899 !prevInsertOp.isSameAs(insertOp, isSame))
2902 insertOp.getDestMutable().assign(prevInsertOp.getDest());
/// Folds round-trip extract/insert slice op pairs.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}
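// Illustrative example (added commentary, hypothetical IR) of the round-trip
// fold above:
//
//   %0 = tensor.extract_slice %t[0, 0] [2, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<2x4xf32>
//   %1 = tensor.insert_slice %0 into %t[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>
//
// Re-inserting the extracted slice at the same position is a no-op, so %1
// folds to %t.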
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;

  return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        insertSliceOp.getDest().getType().getShape(),
        mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedSizes);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the parallel combining op
      // in the parallel case.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
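// Illustrative example (added commentary, hypothetical IR): with
// %c2 = arith.constant 2 : index, the pattern above rewrites
//
//   %0 = tensor.insert_slice %v into %t[%c2, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>
//
// into the fully static form
//
//   %0 = tensor.insert_slice %v into %t[2, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>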
/// Fold tensor.cast ops into an InsertSliceOp or ParallelInsertSliceOp when
/// the cast merely removes static type information from the source or the
/// destination.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not
    // seen in the insert slice op static sizes, so we ignore dynamic dims
    // when computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dimensions in the insert slice op with the corresponding static
    // dims from the cast source type. If a static size in the op is not
    // static in the tensor.cast source, the dim is left untouched and the
    // pattern fails later in `verifyInsertSliceOp`.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        dstType.getShape(), insertSliceOp.getMixedOffsets(), mixedSizes,
        insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
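// Illustrative example (added commentary, hypothetical IR): casts that merely
// erase static information are absorbed by the pattern above.
//
//   %0 = tensor.cast %v : tensor<2x4xf32> to tensor<?x?xf32>
//   %1 = tensor.insert_slice %0 into %t[0, 0] [2, 4] [1, 1]
//       : tensor<?x?xf32> into tensor<8x8xf32>
//
// becomes
//
//   %1 = tensor.insert_slice %v into %t[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>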
/// If additional static type information can be deduced from an
/// insert_slice's size operands, insert an explicit cast of the op's source
/// operand. This enables other canonicalization patterns that match for
/// tensor.cast ops.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is:
    //   1) Different from srcType.
    //   2) "More static" than srcType.
    //   3) Cast-compatible with srcType.
    // Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the parallel combining op in
    // the parallel case.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
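// Illustrative example (added commentary, hypothetical IR): when the static
// sizes prove the source must be 2x4, the pattern above materializes that
// knowledge as a cast so that other cast-folding patterns can kick in:
//
//   %0 = tensor.insert_slice %v into %t[0, 0] [2, 4] [1, 1]
//       : tensor<?x?xf32> into tensor<8x8xf32>
//
// becomes
//
//   %c = tensor.cast %v : tensor<?x?xf32> to tensor<2x4xf32>
//   %0 = tensor.insert_slice %c into %t[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>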
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

  return success();
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank =
      llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked below.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
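// Worked example (added commentary): padding a tensor<4x8xf32> with
// staticLow = [0, 1] and staticHigh = [2, 3] infers dim sizes 4 + 0 + 2 = 6
// and 8 + 1 + 3 = 12, i.e. tensor<6x12xf32>. Any dynamic source dim or
// dynamic padding amount makes the corresponding result dim dynamic (or
// copies it from `resultShape` when one is provided).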
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        /*nofold=*/false, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (i.e., not a constant), dynamicLow and dynamicHigh will grow with
  // 1 value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType) {
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `createBlock` changes the insertion point within the block. Create a
  // guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  tensor::YieldOp::create(b, result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
// Folds tensor.pad when padding is static zeros and there is no nofold
// attribute.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
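// Illustrative example (added commentary, hypothetical IR): a pad with
// all-zero low and high padding only changes the static-ness of the type, so
// the pattern above reduces it to a cast:
//
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x?xf32>
//
// becomes
//
//   %0 = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>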
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
/// Folds a chain of pad(extract_slice(pad(extract_slice))) ops when the two
/// pad ops pad disjoint ("orthogonal") dimensions, by slicing the outermost
/// source once and padding once.
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding slice/pad pair for every dimension, and
    // use the offset of the other pair.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) && isZeroInteger(innerOffset)) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) && isZeroInteger(outerOffset)) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
    // of the outer slice for the dimensions padded by the outer pad and fail
    // if the inner slice does not cover the full padded dimension.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (const auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new slice/pad pair that performs the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad,
        padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
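// Illustrative example (added commentary, hypothetical IR): the two pads
// below touch disjoint dimensions, so the slice/pad chain
//
//   %a = tensor.extract_slice %input[16, 0] [32, 64] [1, 1]
//   %b = tensor.pad %a low[0, 0] high[8, 0]     // pads dim 0 -> 40x64
//   %c = tensor.extract_slice %b[0, 4] [40, 60] [1, 1]
//   %d = tensor.pad %c low[0, 0] high[0, 4]     // pads dim 1
//
// folds into a single slice of the original source and one combined pad:
//
//   %s = tensor.extract_slice %input[16, 4] [32, 60] [1, 1]
//   %d = tensor.pad %s low[0, 0] high[8, 4]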
// Folds low/high padding operands that are compile-time constants into the
// static padding attributes, inferring a more static result type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
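// Illustrative example (added commentary, hypothetical IR): with
// %c1 = arith.constant 1 : index, the pattern above moves the constant
// padding amounts into the static attributes and tightens the result type:
//
//   %0 = tensor.pad %t low[%c1, 0] high[%c1, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//
// becomes
//
//   %p = tensor.pad %t low[1, 0] high[1, 0] { ... }
//       : tensor<4x8xf32> to tensor<6x8xf32>
//   %0 = tensor.cast %p : tensor<6x8xf32> to tensor<?x8xf32>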
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
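// Illustrative example (added commentary, hypothetical IR): two pads with
// the same constant padding value collapse into one whose low/high amounts
// are the per-dimension sums:
//
//   %0 = tensor.pad %t low[0, 1] high[1, 0] { tensor.yield %cst : f32 }
//   %1 = tensor.pad %0 low[1, 0] high[0, 2] { tensor.yield %cst : f32 }
//
// becomes
//
//   %1 = tensor.pad %t low[1, 1] high[1, 2] { tensor.yield %cst : f32 }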
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is a constant, or a null
/// Value otherwise.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if the yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if the yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else the pad value is defined outside of the PadOp block.
  return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
/// The source operand is inserted at the offsets, sizes, and strides computed
/// from `ranges`.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> sizes,
                          ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
}
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides are in bounds of the destination.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}
Operation *ParallelInsertSliceOp::getIteratingParent() {
  // Return the parent of the InParallelOpInterface parent, if any.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as being
  // a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
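// Illustrative example (added commentary, hypothetical IR): splatting a
// constant scalar into a statically shaped tensor folds to a dense splat
// attribute:
//
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x3xf32>
//
// folds to arith.constant dense<1.0> : tensor<2x3xf32>.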
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // Don't fold in the case of insert_slice ops or loop-like ops: they are
  // handled by dedicated patterns.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject relayout ops: they are handled by dedicated patterns.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the new operands and result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"