#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  auto tensorType = llvm::cast<RankedTensorType>(value.getType());

  for (int64_t i = 0; i < tensorType.getRank(); ++i)

  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {

    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));

      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());

    if (llvm::isa<TensorType>(opResult.getType())) {

      if (failed(destination))
        return failure();

      result.push_back(*destination);

  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
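
// Illustrative sketch (not from the original source): the helper below walks
// mixedSizes and reducedShape from the back and marks rank-reduced unit dims.
// E.g. for mixedSizes = [1, %d, 1, 4] and reducedShape = [?, 4], the two
// static unit sizes are the dropped dims, so droppedDims = {0, 2}.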
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;

    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);

    if (!isStaticUnitSize) {

    if (reducedShape[shapePos] == 1) {

      droppedDims.set(idx);

  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType

  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];

      if (cst.has_value()) {

        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);

        staticShape[i] = *cst;

        foldedDynamicSizes.push_back(dynamicSize);

  return RankedTensorType::get(staticShape, type.getElementType(),
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
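
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern below folds two chained bitcasts into one:
//
//   %a = tensor.bitcast %src : tensor<4xi32> to tensor<4xui32>
//   %b = tensor.bitcast %a : tensor<4xui32> to tensor<4xf32>
//
// becomes
//
//   %b = tensor.bitcast %src : tensor<4xi32> to tensor<4xf32>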
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
  results.add<ChainedTensorBitcast>(context);

  setNameFn(getResult(), "cast");

  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  if (!sourceType || !targetType)
    return false;

  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  if (sourceType.getRank() != targetType.getRank())
    return false;

  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;

                                    castOp.getSource().getType());

  if (llvm::isa<BlockArgument>(opOperand.get()))
    return false;
  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
  return castOp && canFoldIntoConsumerOp(castOp);

  newOperands.reserve(op->getNumOperands());

  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();

    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();

  auto castOp = operand.get().getDefiningOp<tensor::CastOp>();

  operand.set(castOp.getOperand());

  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);

  if (aT.getElementType() != bT.getElementType())
    return false;
  if (rank != two.getRank())
    return {};

  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
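
// Illustrative sketch (assumed example, not from the original source): the
// shape join above keeps the most static size of each dimension pair, e.g.
// joining tensor<?x4xf32> with tensor<8x?xf32> yields tensor<8x4xf32>, and a
// static mismatch such as 4 vs. 8 produces a null type.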
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
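
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern above folds chained casts when no static information is lost:
//
//   %a = tensor.cast %t : tensor<8x4xf32> to tensor<?x4xf32>
//   %b = tensor.cast %a : tensor<?x4xf32> to tensor<8x?xf32>
//
// becomes
//
//   %b = tensor.cast %t : tensor<8x4xf32> to tensor<8x?xf32>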
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());

    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());

  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
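
// Illustrative IR sketch (assumed example, not from the original source):
// TensorCastExtractSlice above folds a shape-refining cast into the slice's
// static sizes:
//
//   %s = tensor.extract_slice %t[0, 0] [%sz, 4] [1, 1]
//          : tensor<?x4xf32> to tensor<?x4xf32>
//   %c = tensor.cast %s : tensor<?x4xf32> to tensor<2x4xf32>
//
// becomes a single extract_slice with the static size 2 in place of %sz.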
RankedTensorType ConcatOp::inferResultType(int64_t dim,
                                           TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  for (int64_t i = 0, e = concatRank; i < e; ++i) {

    for (auto tensorType : tensorTypes)

  for (auto tensorType : tensorTypes)

  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());

  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
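
// Illustrative sketch (assumed example, not from the original source): the
// inference above adds up sizes along the concatenation dim and keeps the
// others, e.g. concatenating tensor<3x4xf32> and tensor<5x4xf32> along dim 0
// yields tensor<8x4xf32>; a dynamic concat-dim size in any input makes the
// result's concat dim dynamic.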
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)

  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  for (int64_t i = 0, e = resultRank; i < e; ++i) {

    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =

        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;

  for (auto tensorType : inputTypes)

  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << "does not match inferred shape "
             << inferredResultType << " static sizes";
FailureOr<SmallVector<Value>>
ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);

  concatOffsets.reserve(numInputs);

  for (auto [index, input] : llvm::enumerate(getInputs())) {

      outputShape = inputShape;
      concatOffsets.push_back(zero);

      concatOffsets.push_back(outputShape[concatDim]);

          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});

    inputShapes.emplace_back(std::move(inputShape));

  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
ConcatOp::reifyResultShapes(OpBuilder &builder,

  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];

  for (int64_t i = 0; i < rank; ++i) {

    if (!getType().isDynamicDim(i)) {

    } else if (!inferredResultType.isDynamicDim(i)) {

          builder.getIndexAttr(inferredResultType.getDimSize(i)));

      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();

  if (getType().isDynamicDim(dim)) {

    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {

          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));

    reifiedReturnShapes[0][dim] =

void ConcatOp::getAsmResultNames(

  setNameFn(getResult(), "concat");

  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();

                                        concatOp.getInputs()[0]);

  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    LogicalResult matched = failure();

    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {

      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

          CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                         concatOp.getOperand(operandIdx));

      concatOp->setOperand(operandIdx, castOp->getResult(0));

  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

                                   concatOp.getResultType())) {

        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());

      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
  setNameFn(getResult(), "dim");

  auto loc = result.location;

  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)

  if (rankedSourceType.getRank() <= constantIndex)

  setResultRange(getResult(),

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

  if (indexVal < 0 || indexVal >= tensorType.getRank())

  if (!tensorType.isDynamicDim(index.getInt())) {

  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());

    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        ++dynExtents;

    return Value{*dynExtents};

  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {

    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
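
// Note (illustrative, not from the original source): for
//   %s = tensor.extract_slice %t[0] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
//   %d = tensor.dim %s, %c0
// the fold above returns %sz directly, since the queried dimension is exactly
// the slice's dynamic size operand.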
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();

    Value newSource = castOp.getOperand();

  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
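
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern below folds a dim of a reshape into an extract from the shape
// operand:
//
//   %r = tensor.reshape %t(%shape) : (tensor<?xf32>, tensor<2xindex>)
//          -> tensor<?x?xf32>
//   %d = tensor.dim %r, %c1
//
// becomes
//
//   %d = tensor.extract %shape[%c1] : tensor<2xindex>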
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    Location loc = dim.getLoc();

        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())

          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);

  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;

  build(builder, result, staticShape, elementType, dynamicSizes, encoding);

LogicalResult EmptyOp::verify() {

    return emitOpError("incorrect number of dynamic sizes, has ")

           << getType().getNumDynamicDims();

EmptyOp::reifyResultShapes(OpBuilder &builder,

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");

  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
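
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern below folds constant dynamic sizes into the result type:
//
//   %c4 = arith.constant 4 : index
//   %e = tensor.empty(%c4, %d) : tensor<?x?xf32>
//
// becomes
//
//   %e = tensor.empty(%d) : tensor<4x?xf32>
//
// (followed by a tensor.cast back to the original type where needed).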
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;

        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();

    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();

    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {

    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);

      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {

          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      newMixedSizes.push_back(currDim);

                                    resultType.getElementType());

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
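
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern below bypasses a cast feeding an element extraction:
//
//   %c = tensor.cast %t : tensor<8xf32> to tensor<?xf32>
//   %e = tensor.extract %c[%i] : tensor<?xf32>
//
// becomes
//
//   %e = tensor.extract %t[%i] : tensor<8xf32>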
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();

    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
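
// Illustrative sketch (assumed example, not from the original source): the
// pattern below delinearizes an index into a collapsed tensor back into
// source indices, e.g. for tensor<3x4xf32> collapsed into tensor<12xf32>,
// extracting at %i becomes an extract at (%i floordiv 4, %i mod 4) from the
// uncollapsed source.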
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {

        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();

    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {

          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();

          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(

        extractOp, collapseOp.getSrc(), sourceIndices);
void ExtractOp::getAsmResultNames(

  setNameFn(getResult(), "extracted");

LogicalResult ExtractOp::verify() {

  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");

  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {

    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    if (isa<DenseResourceElementsAttr>(tensor))

  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&

    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }

    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||

    return fromElementsOp.getElements()[flatIndex];

  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
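
// Note (illustrative, not from the original source): the from_elements fold
// above computes a row-major flat index, e.g. indices [1, 2] into a
// tensor<3x4> give 1 * 4 + 2 = 6, selecting the seventh element operand.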
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);

void FromElementsOp::getAsmResultNames(

  setNameFn(getResult(), "from_elements");

  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
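
// Illustrative IR sketch (assumed example, not from the original source):
// the pattern below moves an arith.index_cast from the tensor to the
// extracted scalar:
//
//   %c = arith.index_cast %t : tensor<4xindex> to tensor<4xi64>
//   %e = tensor.extract %c[%i] : tensor<4xi64>
//
// becomes
//
//   %x = tensor.extract %t[%i] : tensor<4xindex>
//   %e = arith.index_cast %x : index to i64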
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);

void GatherOp::getAsmResultNames(

  setNameFn(getResult(), "gather");
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,

  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {

      resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);

    StringRef gatherOrScatter, StringRef sourceOrDest) {

    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)

           << "_dims overflow " << sourceOrDest << " rank";

           << "_dims length must match the size of last dimension of indices";

           << "_dims value must be non-negative";

           << "_dims value must be smaller than " << sourceOrDest << " rank";

  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])

             << "_dims values must be strictly increasing";
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();

                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {

           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),

    return reshapedSource;

void InsertOp::getAsmResultNames(

  setNameFn(getResult(), "inserted");
LogicalResult InsertOp::verify() {

  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();

  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())

void GenerateOp::getAsmResultNames(

  setNameFn(getResult(), "generated");

LogicalResult GenerateOp::reifyResultShapes(

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =

LogicalResult GenerateOp::verify() {

  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());

  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())

        "body must be terminated with a `yield` operation of the tensor "
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,

  build(b, result, resultTy, dynamicExtents);

  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);

  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);

struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;

        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();

        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);

        newOp.getBody().begin());

        generateOp.getType(), newOp);

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();

    Block *body = &tensorFromElements.getBody().front();

      rewriter.clone(op, mapping);

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {

  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()),
                            shapedType.getRank());
  return IntegerAttr();

void ReshapeOp::getAsmResultNames(

  setNameFn(getResult(), "reshape");

  for (auto dim : type.getShape())

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {

        return emitOpError("source and destination tensor should have the "
                           "same number of elements");

    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),

    return reshapedSource;

  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)

  if (sourceTy.getRank() <= 1)

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

            cst.has_value() && cst.value() == static_cast<int64_t>(id);

      dynamicNoop = false;
void CollapseShapeOp::getAsmResultNames(

  setNameFn(getResult(), "collapsed");

void ExpandShapeOp::getAsmResultNames(

  setNameFn(getResult(), "expanded");

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");

FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =

  return *outputShape;
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =

  build(builder, result, cast<RankedTensorType>(resultType), src,

        dynamicOutputShape, staticOutputShape);

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =

  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;

  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {

      getReassociationIndices());

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {

      getReassociationIndices());
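
// Illustrative sketch (assumed example, not from the original source): the
// inference below multiplies the sizes within each reassociation group, e.g.
// collapsing tensor<2x3x4xf32> with reassociation [[0, 1], [2]] yields
// tensor<6x4xf32>; any dynamic size in a group makes that group's result
// size dynamic.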
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(

          type.getContext(), reassociation)));

CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);

    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);

  return RankedTensorType::get(newShape, type.getElementType());

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),

  result.addAttribute(getReassociationAttrStrName(),

  build(b, result, resultType, src, attrs);
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>

                                     RankedTensorType expandedType,
                                     RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);

    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;

LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

LogicalResult CollapseShapeOp::verify() {
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;

        reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {

  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {

        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

        fromElements.getElements());
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {

      collapseShapeOp.getSrcMutable().assign(castOp.getSource());

      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());

          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        Value val = *outputIt;

        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        newOutputShape[outDim] = cst.getSExtValue();

          dynamicOutputShape.push_back(val);

    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =

    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,

    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);

        newExpand.getResult());
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {

      adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {

      adaptor.getOperands());
void ExtractSliceOp::getAsmResultNames(

  setNameFn(getResult(), "extracted_slice");

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticSizes;

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {

  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;

    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =

    SmallVector<int64_t> projectedShape;

    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);

        RankedTensorType::get(projectedShape, inferredType.getElementType());

  return inferredType;

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());

    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));

  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
                                   RankedTensorType expectedType) {

    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";

    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";

    return op->emitError("expected element type to be ")
           << expectedType.getElementType();

    llvm_unreachable("unexpected extract_slice op verification result");

LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();

  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());

      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);

    return getOperation()->emitError(boundsResult.errorMessage);
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {

ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))

  auto maybeRankReductionMask =

  if (!maybeRankReductionMask)

      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));

LogicalResult ExtractSliceOp::reifyResultShapes(

  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());

  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
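
// Illustrative IR sketch (assumed example, not from the original source):
// the folder below bypasses a cast feeding an extract_slice:
//
//   %c = tensor.cast %t : tensor<8x16xf32> to tensor<?x16xf32>
//   %s = tensor.extract_slice %c[0, 0] [4, 4] [1, 1]
//          : tensor<?x16xf32> to tensor<4x4xf32>
//
// becomes an extract_slice taken directly from %t : tensor<8x16xf32>.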
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {

  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {

    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());

    Location loc = sliceOp.getLoc();
    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
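
// Illustrative sketch (assumed example, not from the original source): the
// recursive copy below handles one dimension per call, with `counts` holding
// the row-major suffix products of the source shape. For a 3x4 source,
// counts = [4, 1]; slicing offsets [1, 0], sizes [2, 2], strides [1, 2]
// visits rows 1 and 2 and, within each row, elements 0 and 2.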
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {

  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,

      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;

    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    int64_t count = sourceType.getNumElements();

    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {

      counts.push_back(count);

    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);

  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,

                   ExtractSliceOp newOp) {

      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);

                                 ShapedType shapedType) {

  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))

  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),

    return reshapedSource;
  if (getSourceType() == getType() &&

    return this->getSource();

  return OpFoldResult();

  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();

  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(

  setNameFn(getResult(), "inserted_slice");

  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {

  build(b, result, source, dest, offsets, sizes, strides, attrs);

void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2835 RankedTensorType srcType, RankedTensorType dstType,
2840 RankedTensorType expected = ExtractSliceOp::inferResultType(
2841 dstType, staticOffsets, staticSizes, staticStrides);
2843 *expectedType = expected;
2848LogicalResult InsertSliceOp::verify() {
2850 RankedTensorType expectedType;
2853 getStaticSizes(), getStaticStrides(), &expectedType);
2860 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
2861 getStaticStrides(),
true);
2863 return getOperation()->emitError(boundsResult.
errorMessage);
2886 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2889 if (!prevInsertOp ||
2890 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2891 !prevInsertOp.isSameAs(insertOp, isSame))
2894 insertOp.getDestMutable().assign(prevInsertOp.getDest());
/// Folds round-trip extract/insert slice op pairs.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}
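// Illustrative sketch (not from the source): a round-trip pair such as
//
//   %0 = tensor.extract_slice %t[0, 0] [2, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<2x4xf32>
//   %1 = tensor.insert_slice %0 into %t[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<8x8xf32>
//
// folds %1 to %t, since the insert writes back exactly the slice that was
// read.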
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  return OpFoldResult();
}

LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
                            mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // For the parallel variant, the cast must be created outside of the
      // enclosing combining op.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
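// Illustrative sketch (not from the source): with %c0 = arith.constant 0,
//
//   %r = tensor.insert_slice %src into %dst[%c0, %c0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
//
// canonicalizes to the same op with inline static offsets [0, 0], making the
// slice parameters fully static and amenable to further folding.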
/// Fold tensor.cast ops into insert_slice operations. If the source or
/// destination tensor is a tensor.cast that removes static type information,
/// the cast is folded into the insert_slice operation. When folding a cast on
/// the destination tensor, the result of the insert_slice operation is cast
/// back so that the result type does not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The cast source could carry static information not present in the
    // insert_slice static sizes, so ignore dynamic dims when computing the
    // rank-reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dynamic sizes in the insert_slice op with static dims known
    // from the cast source type.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // If the size of the source tensor does not match the sizes given in the
    // insert_slice op, the pattern does not apply.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    // Verify that the slice stays within the bounds of the destination.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        dstType.getShape(), insertSliceOp.getMixedOffsets(), mixedSizes,
        insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
/// If additional static type information can be deduced from an
/// insert_slice's size operands, insert an explicit cast of the op's source
/// operand. This enables other patterns that match for tensor.cast ops.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is 1) different from srcType, 2) "more static" than srcType,
    // and 3) cast-compatible with srcType: insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // For the parallel variant, the cast must be created outside of the
    // enclosing combining op.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace
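// Illustrative sketch (not from the source): if the size operands are the
// constants [4, 4] but the source is dynamically shaped,
//
//   %r = tensor.insert_slice %src into %dst[0, 0] [4, 4] [1, 1]
//       : tensor<?x?xf32> into tensor<8x8xf32>
//
// is rewritten so that %src is first cast to tensor<4x4xf32>, exposing the
// static shape to later patterns.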
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}

Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}

LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank =
      llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }
  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
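// Illustrative sketch (not from the source): for a source tensor<4x?xf32>
// with staticLow = [1, 2] and staticHigh = [2, 3], the inferred result type
// is tensor<7x?xf32>: dim 0 is 4 + 1 + 2 = 7, and dim 1 stays dynamic
// because the source size is unknown.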
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}

void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh carry the full padding configuration. Each
  // grows by one value per dimension; if the config is dynamic (i.e., not a
  // constant), dynamicLow and dynamicHigh grow by one value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType)
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}

void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `createBlock` changes the insertion point within the block, so use a
  // guard to reset the insertion point of the builder afterwards.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  tensor::YieldOp::create(b, result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
namespace {
// Folds tensor.pad when padding is static zeros and there is no nofold
// attribute.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
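// Illustrative sketch (not from the source): a pad with all-zero widths, e.g.
//
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<4x4xf32>
//
// is replaced by tensor.cast %t; a cast rather than %t itself, so result
// types that differ only in static/dynamic dim information are still handled.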
// Fold a tensor.cast on the source into the pad op when it adds static
// information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
// Fold a tensor.cast of the pad result back into the pad op when it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!preservesStaticInformation(padTensorOp.getResult().getType(),
                                    tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
// Fold a chain of pad / extract_slice / pad ops into a single extract_slice
// and a single pad, provided the two pads write disjoint ("orthogonal")
// dimensions, have zero low padding, and use the same constant padding value.
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps: for every
    // dimension, find the zero-offset and zero-padding op and use the offset
    // of the other one.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes: take the outer slice size for padded dimensions
    // and the inner slice size for all others.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (const auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (const auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new ExtractSliceOp / PadOp pair that performs the two paddings
    // in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad,
        padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
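// Illustrative sketch (not from the source): with orthogonal padded dims,
//
//   %e0 = tensor.extract_slice %s[0, 0] [4, 5] [1, 1] : tensor<6x5xf32> to tensor<4x5xf32>
//   %p0 = tensor.pad %e0 low[0, 0] high[0, 3] { yield %cst } : ... to tensor<4x8xf32>
//   %e1 = tensor.extract_slice %p0[0, 0] [2, 8] [1, 1] : ... to tensor<2x8xf32>
//   %p1 = tensor.pad %e1 low[0, 0] high[6, 0] { yield %cst } : ... to tensor<8x8xf32>
//
// folds to
//
//   %e = tensor.extract_slice %s[0, 0] [2, 5] [1, 1] : tensor<6x5xf32> to tensor<2x5xf32>
//   %p = tensor.pad %e low[0, 0] high[6, 3] { yield %cst } : ... to tensor<8x8xf32>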
// Rewrites a tensor.pad whose low/high operands are constants into an
// equivalent op with those values folded into the static attributes and a
// result type that is as static as possible.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information folded in.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
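// Illustrative sketch (not from the source): with %c2 = arith.constant 2,
//
//   %0 = tensor.pad %t low[%c2] high[%c2] { ... } : tensor<4xf32> to tensor<?xf32>
//
// is rewritten to a pad with static attributes low[2] high[2] producing
// tensor<8xf32>, followed by a tensor.cast back to the original tensor<?xf32>
// type.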
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
} // namespace
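// Illustrative sketch (not from the source): two pads yielding the same %cst
//
//   %0 = tensor.pad %t low[1] high[2] { yield %cst } : tensor<4xf32> to tensor<7xf32>
//   %1 = tensor.pad %0 low[0] high[3] { yield %cst } : tensor<7xf32> to tensor<10xf32>
//
// combine into a single pad of %t with low[1] and high[5]
// (4 + 1 + 5 = 10, matching the consumer's result type).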
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}

/// Return the pad value if it is a constant or defined outside of the pad
/// body; return a null value otherwise.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else padValue is defined outside of the PadOp block.
  return padValue;
}

OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
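// Illustrative sketch (not from the source, syntax approximate): a
// tensor.parallel_insert_slice is tied, by its position among the yielding
// ops, to one result of the surrounding parallel loop:
//
//   %r = scf.forall (%i) in (8) shared_outs(%o = %dst) -> (tensor<8xf32>) {
//     ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %v into %o[%i] [1] [1]
//           : tensor<1xf32> into tensor<8xf32>
//     }
//   }
//
// For that op, getTiedOpResult() below returns %r.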
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed
// into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out-of-bounds with
  // respect to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}

llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}

Operation *ParallelInsertSliceOp::getIteratingParent() {
  // The iterating parent is the parent of the enclosing combining op.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
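// Illustrative sketch (not from the source) for the scatter verifier below:
// with a dest of type tensor<4x5xf32>, indices of type tensor<3x1xindex>, and
// scatter_dims = [0], the expected source type is tensor<3x1x5xf32> (the
// leading index dims followed by the dest dims, with the scattered dim
// collapsed to 1); its rank-reduced variant is tensor<3x5xf32>.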
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] =
          builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value as a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
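// Illustrative sketch (not from the source):
//
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x3xf32>
//
// folds to a splat elements attribute, equivalent to
// arith.constant dense<1.0> : tensor<2x3xf32>.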
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // 1. InsertSliceOp has its own logic for folding tensor.cast ops.
  // 2. Exclude DPS ops that are also LoopLike, as they may need special
  //    handling of attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  return hasFoldableTensorCastOperand(op);
}

namespace {
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the cast source is more static than the consumed operand.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject RelayoutOps, which have dedicated folding patterns.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the new operands and result types, then cast any
    // result whose type changed back to the old type.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
} // namespace

void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"