#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // Destination-style ops expose the tied init operand for each result.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();
  SmallVector<OpFoldResult> mixedSizes;
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: query ReifyRankedShapedTypeOpInterface (elided here).
  } else {
    // Static shape: take the sizes directly from the result type.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create an empty tensor with the computed sizes as the destination.
  Value emptyTensor =
      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // No more dims in the reduced shape; all remaining sizes must be
      // rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }
  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};
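// Example of the rewrite above (illustrative, not from the original source):
//   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
//   %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xsi32>
// becomes a single bitcast from %0 directly to the final type.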
  results.add<ChainedTensorBitcast>(context);
  setNameFn(getResult(), "cast");
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires ranked tensor types with matching element type, rank, and
  // encoding.
  if (!sourceType || !targetType)
    return false;
  if (sourceType.getElementType() != targetType.getElementType())
    return false;
  if (sourceType.getRank() != targetType.getRank())
    return false;
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If the cast makes any dimension more dynamic, static info is lost.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
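  // Illustrative behavior (assumed, not from the original source):
  //   preservesStaticInformation(tensor<4xf32>, tensor<?xf32>) == false
  // because the target drops the static size 4, while
  //   preservesStaticInformation(tensor<?xf32>, tensor<4xf32>) == true
  // since every static dim of the source stays static in the target.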
                                    castOp.getSource().getType());
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  newOperands.reserve(op->getNumOperands());

  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && canFoldIntoConsumerOp(castOp))
      operand.set(castOp.getOperand());
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
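// Worked example (illustrative): joinShapes(tensor<?x5xf32>, tensor<4x?xf32>)
// yields tensor<4x5xf32>, taking the static extent from whichever side has
// one; tensor<4xf32> and tensor<5xf32> have no join and return {}.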
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // Dropping the intermediate cast is valid only if it loses no information:
    // the join of all three types must equal the join of source and result.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);
    if (!firstJoin)
      return failure();
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
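// Example (illustrative): the pair
//   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
//   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
// collapses to a single cast %0 : tensor<4x4xf32> to tensor<4x?xf32>, because
// removing the intermediate type loses no shape information.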
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold a cast to an unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
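// Example (illustrative): a cast that only adds static sizes folds into the
// producing slice:
//   %0 = tensor.extract_slice %t[0, 0] [%s, 2] [1, 1]
//          : tensor<8x8xf32> to tensor<?x2xf32>
//   %1 = tensor.cast %0 : tensor<?x2xf32> to tensor<2x2xf32>
// becomes a single extract_slice with static sizes [2, 2].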
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
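// Worked example (illustrative): concatenating tensor<3x4xf32> and
// tensor<5x4xf32> along dim 0 infers tensor<8x4xf32>; a dynamic extent along
// the concatenated dim saturates the running sum to dynamic.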
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType, 4> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");
  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      getAffineSymbolExpr(0, builder.getContext()) +
      getAffineSymbolExpr(1, builder.getContext());
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice =
        tensor::InsertSliceOp::create(builder, loc, input, replacement, offsets,
                                      inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
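  // The decomposition materializes the concat as a destination tensor plus one
  // insert_slice per input, e.g. (illustrative, not from the original source):
  //   %s0 = tensor.insert_slice %a into %empty[0, 0] [3, 4] [1, 1]
  //   %s1 = tensor.insert_slice %b into %s0[3, 0] [5, 4] [1, 1]
  // where the offsets along the concat dim are the running sums computed
  // above.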
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
  Value init = inputs[0];
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, init.getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }
  if (getType().isDynamicDim(dim)) {
    // Sum the sizes of all inputs along the concatenated dim.
    SmallVector<OpFoldResult> sizes;
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] =
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    LogicalResult matched = failure();
    // Inferred operand shapes equal the inferred result shape except along the
    // concatenated dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // If the inferred operand type is more static, cast the operand to it.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [&, idx = operandIdx] {
          concatOp->setOperand(idx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type is already at least as static as the inferred type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType()))
      return failure();

    auto newConcatOp =
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);
    return success();
  }
};
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
  setNameFn(getResult(), "dim");
  auto loc = result.location;
  auto indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {
  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return success();

  if (rankedSourceType.getRank() <= constantIndex)
    return emitOpError("index is out of range");
  setResultRange(getResult(),
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
  Operation *definingOp = getSource().getDefiningOp();
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the generate op that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }
  unsigned unsignedIndex = index.getValue().getZExtValue();
  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);
    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return failure();

    // Since tensors are immutable, the extract can be placed anywhere.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
LogicalResult EmptyOp::verify() {
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}
SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(getType().getShape()[i]));
    }
  }
  return result;
}
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
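// Example (illustrative): with %e = tensor.empty(%d) : tensor<?x4xf32>,
// "tensor.dim %e, %c0" folds to %d directly, since the empty op carries the
// dynamic extent as an operand.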
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // The cast result shape cannot be more dynamic than the empty tensor
          // result shape (enforced by `canFoldIntoProducerOp`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The current dim is dynamic but was resolved by the cast.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: The current dim is dynamic and stays dynamic.
      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};
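// Example (illustrative): extracting through a cast reads from the source
// tensor instead:
//   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
//   %v = tensor.extract %c[%i]
// becomes %v = tensor.extract %t[%i].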
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};
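// Worked example (illustrative): for a collapse of tensor<3x4xf32> into
// tensor<12xf32>, an extract at linear index %i is rewritten to delinearize
// %i with basis (3, 4) and extract directly from the source; e.g. index 7
// maps to source indices (1, 3).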
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // All elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // Dense resource elements attributes are not folded.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses; such IR is invalid but must not crash.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }
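  // Illustrative arithmetic (not from the original source): for a 3x4
  // from_elements tensor, indices (1, 2) flatten row-major to
  // 1 * 4 + 2 = 6, so the fold returns the seventh element.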
  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);
    return success();
  }
};
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
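// Worked example (illustrative): source tensor<4x5x6xf32>, indices
// tensor<Nx1xindex>, gather_dims = [1]: the leading indices dims give [N],
// dims 0 and 2 are kept and dim 1 becomes 1, yielding tensor<Nx4x1x6xf32>
// (or tensor<Nx4x6xf32> in the rank-reduced variant).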
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}
LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}
LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}
LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements)
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant-fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}
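// Example (illustrative): "tensor.rank %t : tensor<4x?xf32>" folds to the
// index constant 2; an unranked operand returns a null attribute and is left
// alone.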
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}
LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  // If the producer of the source is another reshape, reshape the original
  // tensor directly.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // A reshape between identical rank <= 1 tensor types is a no-op, even with
  // dynamic shapes.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}
void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that the reassociation is valid: only each map's number of
  // results matters.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
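// Worked example (illustrative): reassociation [[0, 1], [2]] applied to
// tensor<3x4x5xf32> multiplies each band, giving tensor<12x5xf32>; a dynamic
// extent anywhere in a band makes that collapsed dim dynamic.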
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<RankedTensorType>(src.getType());
  RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
  auto resultType =
      RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
                            srcType.getEncoding());
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}
LogicalResult CollapseShapeOp::verify() {
  CollapseShapeOp op = *this;
  if (llvm::any_of(op.getReassociationIndices(),
                   [](const ReassociationIndices &g) { return g.empty(); }))
    return op.emitOpError("reassociation indices must not be empty");
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's source dim is dynamic, the expanded dims cannot be
        // inferred and stay dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  // An extract_slice op may specify only a leading subset of offsets/sizes/
  // strides, completed with offset=0, sizes from the source, and strides=1.
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> sizes) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, sizes));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best-effort rank reduction: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}
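// Example (illustrative): desiredResultRank = 2 with inferred sizes
// [1, 4, 5] projects away the leading unit dim, producing tensor<4x5x...>.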
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
/// Build an ExtractSliceOp with mixed static and dynamic sizes, zero offsets,
/// and unit strides.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
}
LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();

  // Verify result type against inferred type.
  RankedTensorType expectedType =
      ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides stay in bounds of the source tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingExtractSliceOp(
      b, loc, value,
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}
LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
  return success();
}
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    // Create the folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
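// Illustrative recursion (not from the original source): slicing a 2x3
// constant with offsets [0, 1], sizes [2, 2], strides [1, 1] visits rows 0
// and 1, copying columns 1 and 2 of each, i.e. elements (0,1), (0,2), (1,1),
// (1,2) in row-major order.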
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      op.getType().getRank(), op.getSourceType(), mixedSizes);
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
                                           ShapedType shapedType) {
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  return OpFoldResult();
}
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
    RankedTensorType srcType, RankedTensorType dstType,
    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
    ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice; use the same type inference.
  RankedTensorType expected =
      ExtractSliceOp::inferResultType(dstType, staticSizes);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
LogicalResult InsertSliceOp::verify() {
  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides stay in bounds of the destination.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());
  return success();
/// Folds round-trip extract/insert slice op pairs.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
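// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the pattern below: dynamic offsets/sizes/strides defined
// by constants are folded into the static operand segments, e.g.
//
//   %c0 = arith.constant 0 : index
//   %0 = tensor.insert_slice %src into %dst[%c0, %c0] [2, 2] [1, 1]
//
// becomes
//
//   %0 = tensor.insert_slice %src into %dst[0, 0] [2, 2] [1, 1]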
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
                            mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the InParallelOp in the
      // parallel case.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
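// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the cast-folding pattern below:
//
//   %0 = tensor.cast %src : tensor<8x16xf32> to tensor<?x?xf32>
//   %1 = tensor.insert_slice %0 into %dst[0, 0] [8, 16] [1, 1]
//       : tensor<?x?xf32> into tensor<8x16xf32>
//
// folds into
//
//   %1 = tensor.insert_slice %src into %dst[0, 0] [8, 16] [1, 1]
//       : tensor<8x16xf32> into tensor<8x16xf32>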
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // If any operand is a constant index, let the constant-argument folder
    // canonicalize the op first.
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not
    // seen in the insert_slice op's static sizes, so match dynamic dims when
    // computing the rank-reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dimensions in the insert_slice op with the corresponding static
    // dims from the tensor.cast source.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(),
                            insertSliceOp.getMixedOffsets(), mixedSizes,
                            insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
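// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the pattern below: when the static sizes of the insertion
// imply a more static source type, an explicit cast is inserted so later
// patterns can use the extra information, e.g.
//
//   %r = tensor.insert_slice %src into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<?x?xf32> into tensor<4x4xf32>
//
// becomes
//
//   %c = tensor.cast %src : tensor<?x?xf32> to tensor<2x2xf32>
//   %r = tensor.insert_slice %c into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>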
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is (1) different from srcType, (2) "more static" than
    // srcType, and (3) cast-compatible with srcType. Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the parallel combining op in
    // the parallel case.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank =
      llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
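// Worked example of the inference above (illustrative, added for exposition):
// for a source tensor<4x?xf32> with staticLow = [1, 2] and
// staticHigh = [2, kDynamic], dimension 0 is fully static and yields
// 4 + 1 + 2 = 7, while dimension 1 stays dynamic, giving tensor<7x?xf32>.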
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh carry the full padding configuration; each
  // dynamic (non-constant) entry additionally grows dynamicLow/dynamicHigh.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType)
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `createBlock` moves the insertion point inside the block; use a guard to
  // restore the builder's insertion point afterwards.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  tensor::YieldOp::create(b, result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (!isZeroInteger(en.value()))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
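// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the pattern below: a pad with all-zero low and high
// padding and no nofold attribute is a no-op up to type, so
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<4x4xf32>
//
// is replaced by a tensor.cast of %src to the result type.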
// Folds tensor.pad when padding is static zeros and there is no nofold
// attribute.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
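// Fold chains of tensor.extract_slice / tensor.pad pairs that pad disjoint
// ("orthogonal") dimensions into a single slice-and-pad. Illustrative sketch
// (added for exposition, not verbatim upstream documentation):
//
//   %0 = tensor.extract_slice %src[...]
//   %1 = tensor.pad %0 low[0, 0] high[0, 4]      // pads dim 1
//   %2 = tensor.extract_slice %1[...]
//   %3 = tensor.pad %2 low[0, 0] high[4, 0]      // pads dim 0
//
// becomes a single extract_slice of %src followed by one pad whose high
// padding combines [0, 4] and [4, 0], provided both pads use the same
// constant padding value and no dimension is padded twice.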
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (const auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad,
        padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
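// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the pattern below: when the low/high padding operands are
// constants, the result type can be made more static, e.g.
//
//   %c2 = arith.constant 2 : index
//   %0 = tensor.pad %src low[%c2, 0] high[%c2, 0] { ... }
//       : tensor<4x8xf32> to tensor<?x8xf32>
//
// is rewritten to pad with static widths to tensor<8x8xf32>, followed by a
// tensor.cast back to the original tensor<?x8xf32> type.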
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    // Fill the dynamic low/high entries with the constant operand values.
    size_t lowCount = 0;
    size_t highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
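// Illustrative example (a sketch added for exposition, not verbatim upstream
// documentation) of the pattern below, which merges a chain of two pads with
// the same constant padding value by adding their padding widths:
//
//   %1 = tensor.pad %0 low[0, 1] high[0, 2] { tensor.yield %val }
//   %2 = tensor.pad %1 low[0, 2] high[3, 0] { tensor.yield %val }
//
// folds into
//
//   %2 = tensor.pad %0 low[0, 3] high[3, 2] { tensor.yield %val }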
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is constant. In this context,
/// "constant" means an actual constant or "defined outside of the block".
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: yield value is defined outside of the PadOp block.
  return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed
// into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> sizes,
                          ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
}
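// Illustrative usage of the convenience builder above (a sketch with
// hypothetical values added for exposition): only sizes are supplied; offsets
// default to 0 and strides to 1, so
//
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4), b.getIndexAttr(8)};
//   InsertSliceOp::build(b, state, source, dest, sizes);
//
// creates `tensor.insert_slice %source into %dest[0, 0] [4, 8] [1, 1]`.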
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}
Operation *ParallelInsertSliceOp::getIteratingParent() {
  // Return the parent InParallelOpInterface's parent.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  // Use GatherOp::inferResultType on the `dest` type and verify that the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}
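// Illustrative usage of the OpFoldResult-based builder above (a sketch with
// hypothetical values added for exposition; `dynSize` is assumed to be an
// index-typed Value):
//
//   SmallVector<OpFoldResult> sizes = {b.getIndexAttr(4),
//                                      OpFoldResult(dynSize)};
//   auto splat = tensor::SplatOp::create(b, loc, element, sizes);
//
// which produces `tensor.splat %element[%dynSize] : tensor<4x?xf32>`.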
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp has its own logic about folding tensor.cast ops.
  // LoopLikeOpInterface ops need special handling of their attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject RelayoutOps: there are dedicated patterns for those instead.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the new result types and operands.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Cast results back to the original types where they changed.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.