34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/Repeated.h"
36#include "llvm/ADT/STLExtras.h"
37#include "llvm/ADT/SmallBitVector.h"
38#include "llvm/ADT/SmallVectorExtras.h"
39#include "llvm/ADT/StringRef.h"
40#include "llvm/Support/Casting.h"
41#include "llvm/Support/MathExtras.h"
52 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
54 if (complex::ConstantOp::isBuildableWith(value, type))
55 return complex::ConstantOp::create(builder, loc, type,
56 llvm::cast<ArrayAttr>(value));
62 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
63 if (tensorType.isDynamicDim(dim))
64 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
71 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
73 for (
int64_t i = 0; i < tensorType.getRank(); ++i)
80 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
81 assert(tensorType &&
"expected tensor type");
85 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
87 return destOp.getTiedOpOperand(opResult)->get();
95 if (!tensorType.hasStaticShape()) {
103 for (
int64_t sz : tensorType.getShape())
104 mixedSizes.push_back(
b.getIndexAttr(sz));
109 if (
auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType))
110 encoding = rankedTensorType.getEncoding();
111 Value emptyTensor = tensor::EmptyOp::create(
112 b, loc, mixedSizes, tensorType.getElementType(), encoding);
120 if (llvm::isa<TensorType>(opResult.getType())) {
122 if (failed(destination))
124 result.push_back(*destination);
131 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
132 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
133 return rtp1.getShape() == rtp2.getShape() &&
134 rtp1.getElementType() == rtp2.getElementType();
144 llvm::SmallBitVector droppedDims(mixedSizes.size());
145 int64_t shapePos = reducedShape.size() - 1;
147 for (
const auto &size : enumerate(llvm::reverse(mixedSizes))) {
148 size_t idx = mixedSizes.size() - size.index() - 1;
150 bool isStaticUnitSize =
151 isa<Attribute>(size.value()) &&
152 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
157 assert(isStaticUnitSize &&
"expected unit dim");
158 droppedDims.set(idx);
163 if (!isStaticUnitSize) {
169 if (reducedShape[shapePos] == 1) {
175 droppedDims.set(idx);
178 assert(shapePos < 0 &&
"dimension mismatch");
185static RankedTensorType
189 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
190 "incorrect number of dynamic sizes");
194 for (
int64_t i = 0, e = type.getRank(); i < e; ++i) {
195 if (type.isDynamicDim(i)) {
196 Value dynamicSize = dynamicSizes[ctr++];
198 if (cst.has_value()) {
200 if (cst.value() < 0) {
201 foldedDynamicSizes.push_back(dynamicSize);
204 staticShape[i] = *cst;
206 foldedDynamicSizes.push_back(dynamicSize);
211 return RankedTensorType::get(staticShape, type.getElementType(),
220 if (inputs.size() != 1 || outputs.size() != 1)
222 Type a = inputs.front(),
b = outputs.front();
223 auto aT = dyn_cast<TensorType>(a);
224 auto bT = dyn_cast<TensorType>(
b);
228 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
239 using OpRewritePattern<BitcastOp>::OpRewritePattern;
241 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
242 PatternRewriter &rewriter)
const final {
243 auto tensorBitcastOperand =
244 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
245 if (!tensorBitcastOperand)
248 auto resultType = cast<TensorType>(tensorBitcast.getType());
249 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
250 tensorBitcastOperand.getOperand());
259 results.
add<ChainedTensorBitcast>(context);
267 setNameFn(getResult(),
"cast");
// NOTE(review): Fragmentary excerpt of the tensor.cast section — interior
// lines elided (original numbering jumps). Not compilable as-is.
//
// Fragment: preservesStaticInformation-style check — a cast loses no static
// information iff element type, rank, and encoding match and no static
// source dim becomes dynamic in the target.
 273 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
 274 auto targetType = llvm::dyn_cast<RankedTensorType>(
target);
 277 if (!sourceType || !targetType)
 281 if (sourceType.getElementType() != targetType.getElementType())
 285 if (sourceType.getRank() != targetType.getRank())
 289 if (sourceType.getEncoding() != targetType.getEncoding())
 293 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
 294 if (ShapedType::isStatic(std::get<0>(t)) &&
 295 ShapedType::isDynamic(std::get<1>(t)))
 331 castOp.getSource().getType());
// Fragment: predicate — operand is foldable when produced by a tensor.cast
// that can fold into the consumer (block arguments excluded).
 364 if (llvm::isa<BlockArgument>(opOperand.get()))
 366 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
 367 return castOp && canFoldIntoConsumerOp(castOp);
// Fragment: computes updated operands/result types after folding producer
// casts into a DPS op; DPS init operands also update the result type.
 374 newOperands.reserve(op->getNumOperands());
 380 for (
OpOperand &opOperand : op->getOpOperands()) {
 381 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
 383 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
 384 if (op.isDpsInit(&opOperand) &&
 385 !llvm::isa<MemRefType>(newOperands.back().getType()))
 386 newResTy[dpsInitIdx++] = newOperands.back().getType();
// Fragment: in-place replacement of a cast operand with the cast's source.
 396 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
 398 operand.set(castOp.getOperand());
// Fragment: likely CastOp::areCastCompatible — single input/output tensor
// types with matching element types.
 406 if (inputs.size() != 1 || outputs.size() != 1)
 408 Type a = inputs.front(),
b = outputs.front();
 409 auto aT = llvm::dyn_cast<TensorType>(a);
 410 auto bT = llvm::dyn_cast<TensorType>(
b);
 414 if (aT.getElementType() != bT.getElementType())
// Fragment: joinShapes — per-dimension meet of two shapes; a dynamic dim
// takes the other side's size, mismatching static sizes fail (elided).
 431 if (rank != two.getRank())
 436 for (
int64_t i = 0; i < rank; ++i) {
 437 if (one.isDynamicDim(i)) {
 438 join.push_back(two.getDimSize(i));
 441 if (two.isDynamicDim(i)) {
 442 join.push_back(one.getDimSize(i));
 445 if (one.getDimSize(i) != two.getDimSize(i))
 447 join.push_back(one.getDimSize(i));
// Fragment: ChainedTensorCast pattern — folds cast(cast(x)) when the direct
// source->result join carries the same information as the chained join.
 457 using OpRewritePattern<CastOp>::OpRewritePattern;
 459 LogicalResult matchAndRewrite(CastOp tensorCast,
 460 PatternRewriter &rewriter)
const final {
 461 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
 463 if (!tensorCastOperand)
 467 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
 468 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
 469 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
 483 auto newJoin =
joinShapes(sourceType, resultType);
 484 if (firstJoin != newJoin)
 487 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
 488 tensorCastOperand.getOperand());
// Fragment: TensorCastExtractSlice pattern — folds a cast of an
// extract_slice by refining the slice's static sizes from the cast result.
 506 using OpRewritePattern<CastOp>::OpRewritePattern;
 508 LogicalResult matchAndRewrite(CastOp tensorCast,
 509 PatternRewriter &rewriter)
const final {
 510 auto extractOperand =
 511 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
 514 auto rankedResultType =
 515 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
 516 if (!rankedResultType)
 520 rankedResultType.getShape() ==
 521 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
 525 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
 527 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
 529 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
// dimMask presumably marks rank-reduced (dropped) dims — skip them when
// mapping result dims back to slice sizes; TODO confirm in full file.
 530 if (dimMask && dimMask->count(i))
 532 int64_t dim = rankedResultType.getShape()[dimIndex++];
 533 if (ShapedType::isDynamic(dim))
 535 sizes[i] = rewriter.getIndexAttr(dim);
 538 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
 539 tensorCast, rankedResultType, extractOperand.getSource(),
 540 extractOperand.getMixedOffsets(), sizes,
 541 extractOperand.getMixedStrides());
// Fragment: CastOp canonicalization registration.
 550 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
// NOTE(review): Fragmentary excerpt of the tensor.concat section — interior
// lines elided (original numbering jumps). Not compilable as-is.
//
// Fragment: ConcatOp::inferResultType — sums sizes along `dim` across the
// inputs (via a saturated-integer accumulator, construction elided) and
// keeps the per-dim meet elsewhere.
 557RankedTensorType ConcatOp::inferResultType(
int64_t dim,
TypeRange inputTypes) {
 558 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
// NOTE(review): `llvm::CastTo` does not match any LLVM helper I can
// confirm; upstream uses a cast lambda here — verify this compiles.
 560 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
 561 int64_t concatRank = tensorTypes[0].getRank();
 564 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
 567 for (
int64_t i = 0, e = concatRank; i < e; ++i) {
 571 for (
auto tensorType : tensorTypes)
 576 for (
auto tensorType : tensorTypes)
 579 sizes[dim] = concatSize.asInteger();
 580 return RankedTensorType::get(sizes, tensorTypes[0].
getElementType());
// Fragment: ConcatOp::build — infers the result type from the inputs.
 585 FailureOr<RankedTensorType> resultType =
 586 inferResultType(dim, inputs.
getTypes());
 587 assert(succeeded(resultType) &&
"failed to infer concatenation result type");
 588 build(builder,
result, *resultType, dim, inputs);
// Fragment: ConcatOp::verify — checks nonempty inputs, matching ranks and
// element types, dim < rank, agreement on non-concat dims, and that the
// declared result shape is compatible with the inferred one.
 591LogicalResult ConcatOp::verify() {
 592 if (getInputs().size() < 1)
 596 for (
auto input : getInputs())
 597 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
 599 RankedTensorType resultType = getResultType();
 600 int64_t resultRank = getRank();
 601 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
 602 return type.getRank() != resultRank;
 604 return emitOpError(
"rank of concatenated inputs must match result rank");
 606 Type resultElementType = resultType.getElementType();
 607 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
 608 return type.getElementType() != resultElementType;
 610 return emitOpError(
"inputs and result element type must match");
 613 if (dim >= resultRank)
 614 return emitOpError(
"concatenation dim must be less than the tensor rank");
 617 for (
int64_t i = 0, e = resultRank; i < e; ++i) {
 621 for (
auto tensorType : inputTypes) {
 622 FailureOr<SaturatedInteger> maybeSize =
 625 return emitOpError(
"static concatenation size mismatch along ")
 626 <<
"non-concatenated dimension " << i;
 632 for (
auto tensorType : inputTypes)
 635 sizes[dim] = concatSize.asInteger();
 636 auto inferredResultType =
 639 for (
auto [inferredSize, actualSize] :
 640 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
 641 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
 642 ShapedType::isDynamic(actualSize);
 643 if (!hasDynamic && inferredSize != actualSize)
 645 << resultType <<
"does not match inferred shape "
 646 << inferredResultType <<
" static sizes";
// Fragment: ConcatOp::decomposeOperation — lowers concat into a chain of
// insert_slice ops, accumulating per-input offsets along the concat dim.
 652FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(
OpBuilder &builder) {
 653 size_t numInputs = getInputs().size();
 654 uint64_t concatDim = getDim();
 657 inputShapes.reserve(numInputs);
 659 concatOffsets.reserve(numInputs);
 666 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
 670 outputShape = inputShape;
 671 concatOffsets.push_back(zero);
 673 concatOffsets.push_back(outputShape[concatDim]);
 675 builder, loc, addExpr,
 676 {outputShape[concatDim], inputShape[concatDim]});
 678 inputShapes.emplace_back(std::move(inputShape));
 688 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
 689 offsets[concatDim] = concatOffsets[
index];
 690 auto insertSlice = tensor::InsertSliceOp::create(
// Fragment: ConcatOp::reifyResultShapes — static dims come from the type or
// the inferred type; the concat dim sums tensor.dim values of the inputs
// when dynamic.
 701ConcatOp::reifyResultShapes(
OpBuilder &builder,
 705 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
 707 Value init = inputs[0];
 715 for (
int64_t i = 0; i < rank; ++i) {
 718 if (!
getType().isDynamicDim(i)) {
 720 }
else if (!inferredResultType.isDynamicDim(i)) {
 723 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
 725 reifiedReturnShapes[0][i] =
 726 tensor::DimOp::create(builder, init.
getLoc(), init, i).getResult();
 730 if (
getType().isDynamicDim(dim)) {
 735 for (
auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
 738 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
 746 reifiedReturnShapes[0][dim] =
// Fragment: result naming + fold (single input of identical type folds to
// the input itself).
 752void ConcatOp::getAsmResultNames(
 754 setNameFn(getResult(),
"concat");
 759 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
// Fragment: SingleInputConcatOp pattern — replaces a one-input concat.
 767 using OpRewritePattern<ConcatOp>::OpRewritePattern;
 769 LogicalResult matchAndRewrite(ConcatOp concatOp,
 770 PatternRewriter &rewriter)
const override {
 771 if (concatOp.getInputs().size() != 1)
 774 concatOp.getInputs()[0]);
// Fragment: InferConcatOperandTypes pattern — casts each operand to the
// type implied by the inferred result shape (non-concat dims refined).
 799 using OpRewritePattern<ConcatOp>::OpRewritePattern;
 801 LogicalResult matchAndRewrite(ConcatOp concatOp,
 802 PatternRewriter &rewriter)
const override {
 803 int64_t dim = concatOp.getDim();
 804 RankedTensorType inferredResultType =
 805 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 808 LogicalResult matched = failure();
 811 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
 812 for (
auto [operandIdx, operandType] :
 813 llvm::enumerate(concatOp->getOperandTypes())) {
 815 inferredOperandShape[dim] =
 816 cast<RankedTensorType>(operandType).getDimSize(dim);
 817 auto inferredOperandType = RankedTensorType::get(
 818 inferredOperandShape, inferredResultType.getElementType());
 826 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
 827 concatOp.getOperand(operandIdx));
 829 concatOp->setOperand(operandIdx, castOp->getResult(0));
// Fragment: InferConcatResultType pattern — rebuilds the concat with the
// (more static) inferred result type when it differs from the declared one.
 853 using OpRewritePattern<ConcatOp>::OpRewritePattern;
 855 LogicalResult matchAndRewrite(ConcatOp concatOp,
 856 PatternRewriter &rewriter)
const override {
 857 int64_t dim = concatOp.getDim();
 858 RankedTensorType inferredResultType =
 859 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 863 concatOp.getResultType())) {
 868 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
 869 concatOp->getOperands());
// Fragment: ConcatOp canonicalization registration.
 881 .
add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
// NOTE(review): Fragmentary excerpt of the tensor.dim and tensor.empty
// sections — interior lines elided (original numbering jumps). Not
// compilable as-is.
//
// Fragment: DimOp result naming + build from an integer index.
 890 setNameFn(getResult(),
"dim");
 895 auto loc =
result.location;
 897 build(builder,
result, source, indexValue);
// Fragment: DimOp::getConstantIndex + verification that a constant index is
// within the source rank.
 900std::optional<int64_t> DimOp::getConstantIndex() {
 909 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().
getType());
 910 if (!rankedSourceType)
 913 if (rankedSourceType.getRank() <= constantIndex)
 921 setResultRange(getResult(),
// Fragment: DimOp::fold — requires a constant index into a ranked source;
// static dims fold to constants; dims of tensor.generate fold to the
// matching dynamic extent; dims of a non-rank-reducing extract_slice fold
// to the slice's dynamic size.
 927 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
 932 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().
getType());
 939 if (indexVal < 0 || indexVal >= tensorType.getRank())
 943 if (!tensorType.isDynamicDim(
index.getInt())) {
 948 Operation *definingOp = getSource().getDefiningOp();
 951 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
 953 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
 956 assert(ShapedType::isDynamic(resultType.getShape()[
index.getInt()]));
 959 auto dynExtents = fromElements.getDynamicExtents().begin();
 960 for (
auto dim : resultType.getShape().take_front(
index.getInt()))
 961 if (ShapedType::isDynamic(dim))
 964 return Value{*dynExtents};
 968 unsigned unsignedIndex =
index.getValue().getZExtValue();
 970 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
 973 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
 974 sliceOp.isDynamicSize(unsignedIndex)) {
 975 return {sliceOp.getDynamicSize(unsignedIndex)};
// Fragment: DimOfCastOp pattern — dim(cast(x)) -> dim(x).
 989 using OpRewritePattern<DimOp>::OpRewritePattern;
 991 LogicalResult matchAndRewrite(DimOp dimOp,
 992 PatternRewriter &rewriter)
const override {
 993 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
 996 Value newSource = castOp.getOperand();
// Fragment: DimOfDestStyleOp pattern — dim of a DPS result is redirected to
// the tied init operand.
 1005 using OpRewritePattern<DimOp>::OpRewritePattern;
 1007 LogicalResult matchAndRewrite(DimOp dimOp,
 1008 PatternRewriter &rewriter)
const override {
 1009 auto source = dimOp.getSource();
 1010 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
 1014 auto resultIndex = cast<OpResult>(source).getResultNumber();
 1015 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
 1018 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
// Fragment: DimOfReshapeOp pattern — dim(reshape(x, shape)) becomes an
// extract from the shape operand (index-cast if element types differ).
 1026 using OpRewritePattern<DimOp>::OpRewritePattern;
 1028 LogicalResult matchAndRewrite(DimOp dim,
 1029 PatternRewriter &rewriter)
const override {
 1030 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 1038 Location loc = dim.getLoc();
 1040 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
 1041 if (extract.
getType() != dim.getType())
 1043 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
// Fragment: DimOp canonicalization registration.
 1052 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
// Fragment: EmptyOp builders — from a fully static shape, from static shape
// plus dynamic sizes and encoding, and from mixed OpFoldResult sizes.
 1062 assert(none_of(staticShape, ShapedType::isDynamic) &&
 1063 "expected only static sizes");
 1067void EmptyOp::build(OpBuilder &builder, OperationState &
result,
 1068 ArrayRef<int64_t> staticShape, Type elementType,
 1069 ValueRange dynamicSizes, Attribute encoding) {
 1070 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
 1071 build(builder,
result, tensorType, dynamicSizes);
 1074void EmptyOp::build(OpBuilder &builder, OperationState &
result,
 1075 ArrayRef<OpFoldResult> sizes, Type elementType,
 1076 Attribute encoding) {
 1077 SmallVector<int64_t> staticShape;
 1078 SmallVector<Value> dynamicSizes;
 1080 build(builder,
result, staticShape, elementType, dynamicSizes, encoding);
// Fragment: EmptyOp::verify + reifyResultShapes (dynamic dims map to the
// op's dynamic size operands) + getDynamicSize + getMixedSizes.
 1083LogicalResult EmptyOp::verify() {
 1089EmptyOp::reifyResultShapes(OpBuilder &builder,
 1091 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
 1093 for (int64_t i = 0; i <
getType().getRank(); ++i) {
 1094 if (
getType().isDynamicDim(i)) {
 1103Value EmptyOp::getDynamicSize(
unsigned idx) {
 1104 assert(
getType().isDynamicDim(idx) &&
"expected dynamic dim");
 1106 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
 1107 if (
getType().isDynamicDim(i))
 1112SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
 1113 SmallVector<OpFoldResult>
result;
 1117 if (ShapedType::isDynamic(dim)) {
 1120 result.push_back(
b.getIndexAttr(dim));
// Fragment: ReplaceEmptyTensorStaticShapeDims pattern — folds constant
// dynamic sizes into the empty op's static type.
 1138struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
 1139 using OpRewritePattern<EmptyOp>::OpRewritePattern;
 1141 LogicalResult matchAndRewrite(EmptyOp op,
 1142 PatternRewriter &rewriter)
const override {
 1143 SmallVector<Value> foldedDynamicSizes;
 1145 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
 1148 if (foldedTensorType == op.getType())
 1151 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
 1152 foldedDynamicSizes);
// Fragment: FoldEmptyTensorWithDimOp pattern — dim(empty, c) folds to the
// empty op's dynamic size operand for dim c.
 1158struct FoldEmptyTensorWithDimOp :
public OpRewritePattern<DimOp> {
 1159 using OpRewritePattern<DimOp>::OpRewritePattern;
 1161 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
 1162 PatternRewriter &rewriter)
const override {
 1163 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
 1164 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
 1165 if (!emptyTensorOp || !maybeConstantIndex)
 1167 auto emptyTensorType = emptyTensorOp.getType();
 1168 if (*maybeConstantIndex < 0 ||
 1169 *maybeConstantIndex >= emptyTensorType.getRank() ||
 1170 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
 1173 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
// Fragment: FoldEmptyTensorWithCastOp pattern — cast(empty) becomes a new
// empty with the cast's result sizes, keeping dynamic sizes where the cast
// result is dynamic.
 1193struct FoldEmptyTensorWithCastOp :
public OpRewritePattern<CastOp> {
 1194 using OpRewritePattern<CastOp>::OpRewritePattern;
 1196 LogicalResult matchAndRewrite(CastOp castOp,
 1197 PatternRewriter &rewriter)
const override {
 1200 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
 1205 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
 1206 ArrayRef<int64_t> resultShape = resultType.getShape();
 1207 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
 1208 SmallVector<OpFoldResult> newMixedSizes;
 1209 newMixedSizes.reserve(currMixedSizes.size());
 1210 assert(resultShape.size() == currMixedSizes.size() &&
 1211 "mismatch in result shape and sizes of empty op");
 1212 for (
auto [newDim, currDim] : llvm::zip(resultShape, currMixedSizes)) {
 1215 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
 1216 if (ShapedType::isDynamic(newDim) ||
 1217 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
 1222 producer,
"mismatch in static value of shape of empty tensor "
 1223 "result and cast result");
 1225 newMixedSizes.push_back(attr);
 1231 if (ShapedType::isStatic(newDim)) {
 1232 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
 1238 newMixedSizes.push_back(currDim);
 1242 resultType.getElementType(),
 1243 resultType.getEncoding());
// Fragment: EmptyOp canonicalization registration.
 1250void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 1251 MLIRContext *context) {
 1252 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
 1253 ReplaceEmptyTensorStaticShapeDims>(context);
// NOTE(review): Fragmentary excerpt of the tensor.extract,
// tensor.from_elements, tensor.gather, and tensor.insert sections —
// interior lines elided (original numbering jumps). Not compilable as-is.
//
// Fragment: ExtractFromTensorCast pattern — extract(cast(x)) ->
// extract(x) when the cast source is ranked.
 1270struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
 1271 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
 1273 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
 1274 PatternRewriter &rewriter)
const final {
 1275 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
 1278 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
 1281 extract, tensorCast.getSource(), extract.getIndices());
// Fragment: ExtractFromCollapseShape pattern — delinearizes the collapsed
// indices back into source indices (static source shape required).
 1296struct ExtractFromCollapseShape :
public OpRewritePattern<tensor::ExtractOp> {
 1297 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
 1299 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
 1300 PatternRewriter &rewriter)
const final {
 1302 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
 1305 if (!collapseOp.getSrcType().hasStaticShape())
 1308 auto sourceSizes = collapseOp.getSrcType().getShape();
 1310 SmallVector<Value>
indices(extractOp.getIndices().begin(),
 1311 extractOp.getIndices().end());
 1312 SmallVector<Value> sourceIndices;
 1313 for (
auto [index, group] :
 1314 llvm::zip(
indices, collapseOp.getReassociationIndices())) {
 1315 assert(!group.empty() &&
"association indices groups cannot be empty");
 1316 auto groupSize = group.size();
 1318 if (groupSize == 1) {
 1319 sourceIndices.push_back(index);
 1323 SmallVector<int64_t> basis =
 1324 llvm::map_to_vector(group, [&](int64_t d) {
return sourceSizes[d]; });
 1325 auto delinearize = affine::AffineDelinearizeIndexOp::create(
 1326 rewriter, extractOp.getLoc(), index, basis,
true);
 1327 llvm::append_range(sourceIndices,
delinearize.getResults());
// Rank-0-source case: all source indices are the folded affine zero.
 1329 if (collapseOp.getReassociationIndices().empty()) {
 1332 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
 1333 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
 1334 rewriter, extractOp.getLoc(), zeroAffineMap,
 1335 ArrayRef<OpFoldResult>{});
 1336 for (int64_t i = 0; i < srcRank; i++) {
 1337 sourceIndices.push_back(
 1343 extractOp, collapseOp.getSrc(), sourceIndices);
// Fragment: ExtractOp naming + verify (index count must equal rank).
 1350void ExtractOp::getAsmResultNames(
 1352 setNameFn(getResult(),
"extracted");
 1355LogicalResult ExtractOp::verify() {
 1357 auto tensorType = llvm::cast<RankedTensorType>(getTensor().
getType());
 1358 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
 1359 return emitOpError(
"incorrect number of indices for extract_element");
// Fragment: extract(insert(scalar, ..., indices), indices) -> scalar when
// types and indices match.
 1368 auto insertOp = extractOp.getTensor().
getDefiningOp<InsertOp>();
 1373 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
 1374 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
 1375 return insertOp.getScalar();
// Fragment: ExtractOp::fold — splat constants fold to their splat value;
// dense resources are skipped; constant indices fold through
// from_elements (row-major flattening) or any ElementsAttr.
 1380OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 1381 if (Attribute tensor = adaptor.getTensor()) {
 1384 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
 1385 return splatTensor.getSplatValue<Attribute>();
 1388 if (isa<DenseResourceElementsAttr>(tensor))
 1393 SmallVector<uint64_t, 8>
indices;
 1394 for (Attribute indice : adaptor.getIndices()) {
 1395 if (!indice || !llvm::isa<IntegerAttr>(indice))
 1397 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
 1401 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
 1402 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
 1403 auto rank = tensorType.getRank();
 1404 assert(
static_cast<int64_t
>(
indices.size()) == tensorType.getRank() &&
 1408 for (
int i = rank - 1; i >= 0; --i) {
 1409 flatIndex +=
indices[i] * stride;
 1410 stride *= tensorType.getDimSize(i);
 1414 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
 1417 return fromElementsOp.getElements()[flatIndex];
 1421 if (Attribute tensor = adaptor.getTensor()) {
 1422 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
 1423 if (elementsAttr && elementsAttr.isValidIndex(
indices))
 1424 return elementsAttr.getValues<Attribute>()[
indices];
// Fragment: ExtractOp canonicalization registrations (collapse-shape
// pattern registered separately via a populate function).
 1433void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
 1434 MLIRContext *context) {
 1435 results.
add<ExtractFromTensorCast>(context);
 1440 patterns.
add<ExtractFromCollapseShape>(patterns.
getContext());
// Fragment: FromElementsOp naming + 1-D builder + fold of all-constant
// elements.
 1447void FromElementsOp::getAsmResultNames(
 1449 setNameFn(getResult(),
"from_elements");
 1454 assert(!elements.empty() &&
"expected at least one element");
 1455 Type resultType = RankedTensorType::get(
 1456 {
static_cast<int64_t>(elements.size())}, elements.front().
getType());
 1457 build(builder,
result, resultType, elements);
 1460OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
 1465 Type eltType =
getType().getElementType();
 1468 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
// Fragment: ExtractElementFromIndexCast pattern — moves the index_cast
// after the extract so the extract reads the original tensor.
 1491struct ExtractElementFromIndexCast
 1492 :
public OpRewritePattern<tensor::ExtractOp> {
 1493 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
 1495 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
 1496 PatternRewriter &rewriter)
const final {
 1497 Location loc = extract.getLoc();
 1498 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
 1504 auto newExtract = tensor::ExtractOp::create(
 1505 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
 1516void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
 1517 MLIRContext *context) {
 1518 results.
add<ExtractElementFromIndexCast>(context);
// Fragment: GatherOp naming + inferResultType — leading dims come from the
// indices (minus the trailing coordinate dim); gathered source dims become
// unit (or are dropped in the rank-reduced variant, handled in elided code).
 1525void GatherOp::getAsmResultNames(
 1527 setNameFn(getResult(),
"gather");
 1542RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
 1543 RankedTensorType indicesType,
 1544 ArrayRef<int64_t> gatherDims,
 1546 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
 1547 resultShape.reserve(resultShape.size() + sourceType.getRank());
 1548 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
 1549 if (llvm::binary_search(gatherDims, idx)) {
 1551 resultShape.push_back(1);
 1554 resultShape.push_back(sourceType.getDimSize(idx));
 1556 return RankedTensorType::Builder(sourceType).setShape(resultShape);
// Fragment: shared gather/scatter dims verifier — non-empty, within rank,
// matching the trailing indices dim, non-negative, strictly increasing.
 1562 StringRef gatherOrScatter, StringRef sourceOrDest) {
 1564 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
 1566 int64_t numGatherDims = dims.size();
 1567 if (numGatherDims > rank)
 1569 <<
"_dims overflow " << sourceOrDest <<
" rank";
 1572 <<
"_dims length must match the size of last dimension of indices";
 1576 <<
"_dims value must be non-negative";
 1579 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
 1581 for (
int64_t i = 1; i < numGatherDims; ++i) {
 1582 if (dims[i - 1] >= dims[i])
 1584 <<
"_dims values must be strictly increasing";
// Fragment: GatherOp::verify — result must equal the inferred type or its
// rank-reduced variant; + fold through constant-source reshape.
 1589LogicalResult GatherOp::verify() {
 1590 int64_t sourceRank = getSourceType().getRank();
 1591 ArrayRef<int64_t> gatherDims = getGatherDims();
 1593 getIndicesType().
getShape(), sourceRank,
 1594 "gather",
"source")))
 1597 RankedTensorType expectedResultType = GatherOp::inferResultType(
 1598 getSourceType(), getIndicesType(), gatherDims,
false);
 1599 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
 1600 getSourceType(), getIndicesType(), gatherDims,
true);
 1601 if (getResultType() != expectedResultType &&
 1602 getResultType() != expectedRankReducedResultType) {
 1606 << expectedResultType <<
" or its rank-reduced variant "
 1607 << expectedRankReducedResultType <<
" (got: " << getResultType()
 1614OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
 1615 if (OpFoldResult reshapedSource = reshapeConstantSource(
 1616 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
 1618 return reshapedSource;
// Fragment: InsertOp naming + verify (index count) + fold (inserting the
// splat value of a splat destination is a no-op).
 1626void InsertOp::getAsmResultNames(
 1628 setNameFn(getResult(),
"inserted");
 1631LogicalResult InsertOp::verify() {
 1633 auto destType = llvm::cast<RankedTensorType>(getDest().
getType());
 1634 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
 1635 return emitOpError(
"incorrect number of indices");
 1639OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
 1640 Attribute scalar = adaptor.getScalar();
 1641 Attribute dest = adaptor.getDest();
 1643 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
 1644 if (scalar == splatDest.getSplatValue<Attribute>())
1653void GenerateOp::getAsmResultNames(
1655 setNameFn(getResult(),
"generated");
1658LogicalResult GenerateOp::reifyResultShapes(
1660 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
1662 for (
auto dim : llvm::seq<int64_t>(0,
getType().getRank())) {
1663 if (
getType().isDynamicDim(dim)) {
1664 reifiedReturnShapes[0][dim] = getOperand(idx++);
1666 reifiedReturnShapes[0][dim] =
1673LogicalResult GenerateOp::verify() {
1676 RankedTensorType resultType = llvm::cast<RankedTensorType>(
getType());
1683LogicalResult GenerateOp::verifyRegions() {
1684 RankedTensorType resultTy = llvm::cast<RankedTensorType>(
getType());
1686 if (!llvm::all_of(getBody().getArgumentTypes(),
1687 [](Type ty) {
return ty.
isIndex(); }))
1688 return emitError(
"all body arguments must be index");
1689 if (getBody().getNumArguments() != resultTy.getRank())
1690 return emitError(
"must have one body argument per input dimension");
1693 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1695 if (yieldOp.getValue().getType() != resultTy.getElementType())
1697 "body must be terminated with a `yield` operation of the tensor "
1703void GenerateOp::build(
1704 OpBuilder &
b, OperationState &
result, Type resultTy,
1707 build(
b,
result, resultTy, dynamicExtents);
1710 OpBuilder::InsertionGuard guard(
b);
1711 Region *bodyRegion =
result.regions.front().get();
1712 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1713 SmallVector<Type, 2> argumentTypes(rank,
b.getIndexType());
1714 SmallVector<Location, 2> argumentLocs(rank,
result.location);
1716 b.createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1726struct StaticTensorGenerate :
public OpRewritePattern<GenerateOp> {
1727 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1729 LogicalResult matchAndRewrite(GenerateOp generateOp,
1730 PatternRewriter &rewriter)
const final {
1731 SmallVector<Value> foldedDynamicSizes;
1733 generateOp.getType(), generateOp.getDynamicExtents(),
1734 foldedDynamicSizes);
1737 if (foldedTensorType == generateOp.getType())
1740 auto loc = generateOp.getLoc();
1742 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1744 newOp.getBody().begin());
1746 generateOp.getType(), newOp);
1762struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1763 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1765 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1766 PatternRewriter &rewriter)
const final {
1767 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1772 Block *body = &tensorFromElements.getBody().front();
1775 rewriter.
clone(op, mapping);
1786void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1787 MLIRContext *context) {
1789 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1796void RankOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1797 setNameFn(getResult(),
"rank");
1800OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1802 auto type = getOperand().getType();
1803 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1804 if (shapedType && shapedType.hasRank())
1805 return IntegerAttr::get(IndexType::get(
getContext()), shapedType.getRank());
1806 return IntegerAttr();
1813void ReshapeOp::getAsmResultNames(
1815 setNameFn(getResult(),
"reshape");
1820 for (
auto dim : type.getShape())
1825LogicalResult ReshapeOp::verify() {
1826 TensorType operandType = llvm::cast<TensorType>(getSource().
getType());
1827 TensorType resultType = llvm::cast<TensorType>(getResult().
getType());
1830 return emitOpError(
"element types of source and destination tensor "
1831 "types should be the same");
1835 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1836 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1838 if (resultRankedType) {
1839 if (operandRankedType && resultRankedType.hasStaticShape() &&
1840 operandRankedType.hasStaticShape()) {
1842 return emitOpError(
"source and destination tensor should have the "
1843 "same number of elements");
1845 if (ShapedType::isDynamic(shapeSize))
1846 return emitOpError(
"cannot use shape operand with dynamic length to "
1847 "reshape to statically-ranked tensor type");
1848 if (shapeSize != resultRankedType.getRank())
1850 "length of shape operand differs from the result's tensor rank");
// Folds tensor.reshape: (1) reshape of a constant source folds to a
// reshaped constant; (2) reshape-of-reshape folds to a single reshape of
// the original source; (3) a reshape whose shape operand provably matches
// the (identical) source/result type is a no-op and folds to the source.
// NOTE(review): extraction elided some lines; code is kept byte-identical.
1855OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
 1856 if (OpFoldResult reshapedSource = reshapeConstantSource(
 1857 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
 1859 return reshapedSource;
 1864 if (
auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
 1865 getSourceMutable().assign(reshapeOpProducer.getSource());
 1869 auto source = getSource();
 1870 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
 1871 auto resultTy = dyn_cast<RankedTensorType>(
getType());
 1872 if (!sourceTy || !resultTy || sourceTy != resultTy)
 1877 if (sourceTy.getRank() <= 1)
 1880 if (
auto fromElements =
getShape().getDefiningOp<tensor::FromElementsOp>()) {
 1881 auto elements = fromElements.getElements();
 1883 sourceTy.getRank() ==
static_cast<int64_t
>(elements.size());
 1884 for (
int id = 0, s = elements.size();
id < s && dynamicNoop; ++
id) {
 1885 auto element = elements[id];
 // A constant shape element must equal the corresponding source dim size.
 1888 dynamicNoop &= cst.value() == sourceTy.getDimSize(
id);
 1892 if (
auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
 1893 dynamicNoop &= dimOp.getSource() == source;
 // A tensor.dim shape element must query this source at position `id`.
 1897 cst.has_value() && cst.value() ==
static_cast<int64_t
>(id);
 1901 dynamicNoop =
false;
// Suggests "collapsed" as the SSA result name prefix.
1916void CollapseShapeOp::getAsmResultNames(
 1918 setNameFn(getResult(),
"collapsed");
// Suggests "expanded" as the SSA result name prefix.
1921void ExpandShapeOp::getAsmResultNames(
 1923 setNameFn(getResult(),
"expanded");
// Returns the source dimension whose reassociation group contains
// `resultDim`. Asserts `resultDim` is a valid result dimension; every
// result dim must belong to exactly one group, hence the unreachable.
1926int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
 1927 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
 1928 "invalid resultDim");
 1929 for (
const auto &it : llvm::enumerate(getReassociationIndices()))
 1930 if (llvm::is_contained(it.value(), resultDim))
 1932 llvm_unreachable(
"could not find reassociation group");
// Infers the mixed (static + dynamic) output shape of an expand_shape from
// the expanded type, the reassociation, and the input shape. Returns
// failure when the shape cannot be inferred (elided lines handle that).
1935FailureOr<SmallVector<OpFoldResult>>
 1936ExpandShapeOp::inferOutputShape(OpBuilder &
b, Location loc,
 1937 RankedTensorType expandedType,
 1938 ArrayRef<ReassociationIndices> reassociation,
 1939 ArrayRef<OpFoldResult> inputShape) {
 1940 std::optional<SmallVector<OpFoldResult>> outputShape =
 1945 return *outputShape;
1948SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
// Builder taking an explicit mixed output shape: decomposes it into static
// and dynamic components and delegates to the primary builder.
1952void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 1953 Type resultType, Value src,
 1954 ArrayRef<ReassociationIndices> reassociation,
 1955 ArrayRef<OpFoldResult> outputShape) {
 1956 auto [staticOutputShape, dynamicOutputShape] =
 1958 build(builder,
result, cast<RankedTensorType>(resultType), src,
 1960 dynamicOutputShape, staticOutputShape);
// Builder that infers the output shape from the source's shape and the
// reassociation; falls back to an empty output shape if inference fails.
1963void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
 1964 Type resultType, Value src,
 1965 ArrayRef<ReassociationIndices> reassociation) {
 1966 SmallVector<OpFoldResult> inputShape =
 1968 auto tensorResultTy = cast<RankedTensorType>(resultType);
 1969 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
 1970 builder,
result.location, tensorResultTy, reassociation, inputShape);
 1971 SmallVector<OpFoldResult> outputShapeOrEmpty;
 1972 if (succeeded(outputShape)) {
 1973 outputShapeOrEmpty = *outputShape;
 1975 build(builder,
result, tensorResultTy, src, reassociation,
 1976 outputShapeOrEmpty);
1979SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
// Converts the reassociation indices to affine-expression form.
1982SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
 1984 getReassociationIndices());
1987SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
// Converts the reassociation indices to affine-expression form.
1990SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 1992 getReassociationIndices());
// Overload taking reassociation indices: converts them (to affine maps,
// presumably — TODO confirm against the elided call) and delegates to the
// AffineMap-based overload below.
1995RankedTensorType CollapseShapeOp::inferCollapsedType(
 1996 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
 1997 return inferCollapsedType(
 1999 type.getContext(), reassociation)));
// Computes the collapsed tensor type: each reassociation map contributes
// one result dim whose size is the product of the grouped source dims, or
// kDynamic if any grouped dim is dynamic.
2005CollapseShapeOp::inferCollapsedType(RankedTensorType type,
 2006 ArrayRef<AffineMap> reassociation) {
 2007 auto shape = type.getShape();
 2008 SmallVector<int64_t, 4> newShape;
 2009 newShape.reserve(reassociation.size());
 2014 unsigned currentDim = 0;
 2015 for (AffineMap m : reassociation) {
 2016 unsigned dim = m.getNumResults();
 2017 auto band = shape.slice(currentDim, dim);
 // Any dynamic dim in the group makes the collapsed dim dynamic.
 2019 if (llvm::is_contained(band, ShapedType::kDynamic))
 2020 size = ShapedType::kDynamic;
 2022 for (
unsigned d = 0; d < dim; ++d)
 2023 size *= shape[currentDim + d];
 2024 newShape.push_back(size);
 2028 return RankedTensorType::get(newShape, type.getElementType());
// Builder that infers the collapsed result type (preserving the source's
// encoding) and records the reassociation attribute.
2031void CollapseShapeOp::build(OpBuilder &
b, OperationState &
result, Value src,
 2032 ArrayRef<ReassociationIndices> reassociation,
 2033 ArrayRef<NamedAttribute> attrs) {
 2034 auto srcType = llvm::cast<RankedTensorType>(src.
getType());
 2035 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
 // Re-attach the source encoding to the inferred shape.
 2037 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
 2038 srcType.getEncoding());
 2039 result.addAttribute(getReassociationAttrStrName(),
 2041 build(
b,
result, resultType, src, attrs);
// Shared verifier for expand/collapse shape ops (isExpansion selects the
// direction): checks reshape-like type compatibility, element-count
// preservation for static shapes, and that the collapsed type matches the
// type inferred from the reassociation maps.
// NOTE(review): the function signature line was elided by extraction.
2044template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
 2045 TensorReshapeOp, ExpandShapeOp>::value>
 2047 RankedTensorType expandedType,
 2048 RankedTensorType collapsedType) {
 2050 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
 2054 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
 2055 int64_t expandedNumElements = expandedType.getNumElements();
 2056 int64_t collapsedNumElements = collapsedType.getNumElements();
 2057 if (expandedNumElements != collapsedNumElements) {
 2058 return op.emitOpError(
"number of elements must be preserved: ")
 2059 << expandedNumElements <<
" != " << collapsedNumElements;
 2063 auto maps = op.getReassociationMaps();
 2064 RankedTensorType expectedType =
 2065 CollapseShapeOp::inferCollapsedType(expandedType, maps);
 2067 return op.emitOpError(
"expected collapsed type to be ")
 2068 << expectedType <<
", but got " << collapsedType;
// Verifies tensor.expand_shape: static_output_shape length must equal the
// result rank, the dynamic output_shape operand count must match the number
// of kDynamic entries, and static result dims must agree with the provided
// static output shape.
2072LogicalResult ExpandShapeOp::verify() {
 2073 RankedTensorType srcType = getSrc().getType();
 2074 RankedTensorType resultType = getResult().getType();
 2076 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
 2077 return emitOpError(
"expected number of static shape dims to be equal to "
 2078 "the output rank (")
 2079 << resultType.getRank() <<
") but found "
 2080 << getStaticOutputShape().size() <<
" inputs instead";
 2082 if ((int64_t)getOutputShape().size() !=
 2083 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
 2084 return emitOpError(
"mismatch in dynamic dims in output_shape and "
 2085 "static_output_shape: static_output_shape has ")
 2086 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
 2087 <<
" dynamic dims while output_shape has " << getOutputShape().size()
 2098 ArrayRef<int64_t> resShape = getResult().getType().getShape();
 2099 for (
auto [pos, shape] : llvm::enumerate(resShape))
 2100 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
 2101 return emitOpError(
"invalid output shape provided at pos ") << pos;
// Verifies tensor.collapse_shape: reassociation groups must be non-empty;
// remaining checks (elided) delegate to the shared reshape verifier.
2106LogicalResult CollapseShapeOp::verify() {
 2107 CollapseShapeOp op = *
this;
 2108 if (llvm::any_of(op.getReassociationIndices(),
 2110 return op.emitOpError(
"reassociation indices must not be empty");
 2112 RankedTensorType srcType = op.getSrc().getType();
 2113 RankedTensorType resultType = op.getResult().getType();
// Pattern: folds an expand/collapse of a dense constant into a constant of
// the (static) result type, reusing the constant's raw data.
2121template <
typename TensorReshapeOp>
 2122struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
 2123 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 2124 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
 2125 PatternRewriter &rewriter)
const override {
 2126 DenseElementsAttr attr;
 // Only static result shapes can host a reshaped constant.
 2133 if (!reshapeOp.getResultType().hasStaticShape())
 2136 reshapeOp.getResultType(), attr.
getRawData());
// Pattern: folds a reshape of a tensor.splat into a splat of the result
// type, provided the splat's aggregate type is statically shaped.
2143template <
typename TensorReshapeOp>
 2144class FoldReshapeWithSplat :
public OpRewritePattern<TensorReshapeOp> {
 2146 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 2148 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
 2149 PatternRewriter &rewriter)
const override {
 2150 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
 2151 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
 2155 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
// Pattern: folds a reshape of tensor.from_elements into a from_elements of
// the (static) result type with the same scalar operands.
2162template <
typename TensorReshapeOp>
 2163struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
 2164 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 2165 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
 2166 PatternRewriter &rewriter)
const override {
 2168 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
 2172 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
 2174 if (!shapedTy.hasStaticShape())
 2178 fromElements.getElements());
// Pattern: folds collapse_shape(cast(x)) by collapsing the cast's source
// directly; if the inferred type differs from the original result type, a
// new collapse is created (with a compensating cast, per the elided lines).
2184struct FoldCollapseOfCastOp :
public OpRewritePattern<CollapseShapeOp> {
 2185 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
 2187 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
 2188 PatternRewriter &rewriter)
const override {
 2189 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
 2193 RankedTensorType srcType =
 2194 llvm::cast<RankedTensorType>(castOp.getSource().getType());
 2195 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
 2196 srcType, collapseShapeOp.getReassociationMaps());
 // Same result type: just rewire the source in place.
 2198 if (newResultType == collapseShapeOp.getResultType()) {
 2200 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
 2203 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
 2204 newResultType, castOp.getSource(),
 2205 collapseShapeOp.getReassociation());
 2207 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
// Pattern: when an expand_shape consumes a cast from a more-static source,
// propagate the static sizes into the expand's output shape, rebuild the
// expand on a re-cast input with a more static input/output type, and cast
// the result back (elided) to the original type.
// NOTE(review): several lines were elided by extraction; code kept
// byte-identical.
2217struct ConvertToStaticExpandShape :
public OpRewritePattern<ExpandShapeOp> {
 2218 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
 2220 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
 2221 PatternRewriter &rewriter)
const override {
 2222 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
 2226 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
 2227 SmallVector<ReassociationIndices, 4> reassoc =
 2228 expandOp.getReassociationIndices();
 2230 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
 2231 SmallVector<Value> dynamicOutputShape;
 2232 auto outputIt = expandOp.getOutputShape().begin();
 2234 for (
const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
 2235 for (uint64_t outDim : innerReassoc) {
 2236 if (ShapedType::isStatic(newOutputShape[outDim]))
 2243 Value val = *outputIt;
 // Dynamic cast-source dims keep their dynamic output-shape value.
 2245 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
 2246 dynamicOutputShape.push_back(val);
 2252 newOutputShape[outDim] = cst.getSExtValue();
 2254 dynamicOutputShape.push_back(val);
 // Nothing became static: pattern does not apply.
 2260 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
 // Recompute the input shape as products of the (new) output dims.
 2264 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
 2265 for (
auto inDim : llvm::seq<int>(0, newInputShape.size())) {
 2266 for (
auto outDim : reassoc[inDim]) {
 2267 auto ofr = newOutputShape[outDim];
 2268 if (ShapedType::isDynamic(ofr)) {
 2269 newInputShape[inDim] = ShapedType::kDynamic;
 2272 newInputShape[inDim] *= ofr;
 2276 SmallVector<OpFoldResult> outputOfr =
 2278 auto inputType = RankedTensorType::get(
 2279 newInputShape, expandOp.getSrcType().getElementType());
 2280 auto outputType = RankedTensorType::get(
 2281 newOutputShape, expandOp.getSrcType().getElementType());
 2282 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
 2284 auto newExpand = ExpandShapeOp::create(
 2285 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
 2286 expandOp.getReassociationIndices(), outputOfr);
 2288 newExpand.getResult());
// Registers expand_shape canonicalizations: reshape composition, static
// conversion, and constant/splat/from_elements folding.
2294void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2295 MLIRContext *context) {
 2297 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
 2298 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
 2299 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
 2300 FoldReshapeWithSplat<ExpandShapeOp>,
 2301 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
// Registers collapse_shape canonicalizations: reshape composition and
// constant/splat/from_elements/cast folding.
2304void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2305 MLIRContext *context) {
 2307 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
 2308 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
 2309 tensor::DimOp, RankedTensorType>,
 2310 FoldReshapeWithConstant<CollapseShapeOp>,
 2311 FoldReshapeWithSplat<CollapseShapeOp>,
 2312 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
// Delegates folding to the shared reshape-op folder (call target elided).
2316OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
 2318 adaptor.getOperands());
// Delegates folding to the shared reshape-op folder (call target elided).
2321OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
 2323 adaptor.getOperands());
// Suggests "extracted_slice" as the SSA result name prefix.
2330void ExtractSliceOp::getAsmResultNames(
 2332 setNameFn(getResult(),
"extracted_slice");
// Infers the (non-rank-reduced) result type from static sizes: same rank
// as the source, with the source's element type and encoding preserved.
2339ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
 2340 ArrayRef<int64_t> staticSizes) {
 2344 assert(
static_cast<int64_t
>(staticSizes.size()) ==
 2345 sourceTensorType.getRank() &&
 2346 "unexpected staticSizes not equal to rank of source");
 2347 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
 2348 sourceTensorType.getEncoding());
// OpFoldResult overload: decomposes mixed sizes into static sizes (elided
// call) and builds the result type as in the static overload above.
2353ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
 2354 ArrayRef<OpFoldResult> sizes) {
 2355 SmallVector<int64_t> staticSizes;
 2358 assert(
static_cast<int64_t
>(staticSizes.size()) ==
 2359 sourceTensorType.getRank() &&
 2360 "unexpected staticSizes not equal to rank of source");
 2361 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
 2362 sourceTensorType.getEncoding());
// Infers a rank-reduced result type: starts from the full inferred type,
// then projects out unit dims until the desired rank is reached.
2373RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
 2374 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
 2375 ArrayRef<int64_t> sizes) {
 2377 auto inferredType = llvm::cast<RankedTensorType>(
 2378 inferResultType(sourceRankedTensorType, sizes));
 2379 int rankDiff = inferredType.getRank() - desiredResultRank;
 2381 auto shape = inferredType.getShape();
 2382 llvm::SmallBitVector dimsToProject =
 2384 SmallVector<int64_t> projectedShape;
 2386 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
 2387 if (!dimsToProject.test(pos))
 2388 projectedShape.push_back(shape[pos]);
 2390 RankedTensorType::get(projectedShape, inferredType.getElementType());
 2392 return inferredType;
// OpFoldResult overload: decomposes mixed sizes (elided call) and forwards
// to the static-size overload.
2395RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
 2396 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
 2397 ArrayRef<OpFoldResult> sizes) {
 2398 SmallVector<int64_t> staticSizes;
 2399 SmallVector<Value> dynamicSizes;
 2401 return ExtractSliceOp::inferCanonicalRankReducedResultType(
 2402 desiredResultRank, sourceRankedTensorType, staticSizes);
// Primary builder from mixed offsets/sizes/strides; infers the result type
// from the source when none is supplied.
2407void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
 2408 RankedTensorType resultType, Value source,
 2409 ArrayRef<OpFoldResult> offsets,
 2410 ArrayRef<OpFoldResult> sizes,
 2411 ArrayRef<OpFoldResult> strides,
 2412 ArrayRef<NamedAttribute> attrs) {
 2413 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
 2414 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
 2418 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
 // Null result type requested: infer it from the source and sizes.
 2421 resultType = llvm::cast<RankedTensorType>(
 2422 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
 2424 result.addAttributes(attrs);
 2425 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
 2426 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
 2427 b.getDenseI64ArrayAttr(staticSizes),
 2428 b.getDenseI64ArrayAttr(staticStrides));
// Convenience builder: passes a null result type so it gets inferred.
2433void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
 2434 ArrayRef<OpFoldResult> offsets,
 2435 ArrayRef<OpFoldResult> sizes,
 2436 ArrayRef<OpFoldResult> strides,
 2437 ArrayRef<NamedAttribute> attrs) {
 2438 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// Convenience builder from Range triples (decomposition elided).
2443void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
 2444 ArrayRef<Range> ranges,
 2445 ArrayRef<NamedAttribute> attrs) {
 2447 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// Convenience builder from ValueRange operands: wraps each Value as an
// OpFoldResult and forwards to the mixed builder.
2452void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
 2453 RankedTensorType resultType, Value source,
 2455 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
 2456 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
 2457 offsets, [](Value v) -> OpFoldResult {
return v; });
 2458 SmallVector<OpFoldResult> sizeValues =
 2459 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
 2460 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
 2461 strides, [](Value v) -> OpFoldResult {
return v; });
 2462 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
// ValueRange builder with inferred result type.
2466void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
 2468 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
 2469 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2474 RankedTensorType expectedType) {
2479 return op->
emitError(
"expected rank to be smaller or equal to ")
2480 <<
"the other rank. ";
2482 return op->
emitError(
"expected type to be ")
2483 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2485 return op->
emitError(
"expected element type to be ")
2486 << expectedType.getElementType();
2488 llvm_unreachable(
"unexpected extract_slice op verification result");
// Builder taking only sizes: uses zero offsets and unit strides.
2494void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
 2495 RankedTensorType resultType, Value source,
 2496 ArrayRef<OpFoldResult> sizes,
 2497 ArrayRef<NamedAttribute> attrs) {
 2498 Attribute zeroIdxAttr =
b.getIndexAttr(0);
 2499 Attribute oneIdxAttr =
b.getIndexAttr(1);
 2500 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
 2501 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
 2502 build(
b,
result, resultType, source, readOffsets, sizes, readStrides, attrs);
// Verifies extract_slice: the result type must be the inferred type or a
// rank-reduced version of it (check elided), and the static offsets/sizes/
// strides must be in bounds of the source shape.
2506LogicalResult ExtractSliceOp::verify() {
 2507 RankedTensorType sourceType = getSourceType();
 2510 RankedTensorType expectedType =
 2511 ExtractSliceOp::inferResultType(sourceType,
getMixedSizes());
 2519 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
 2520 getStaticStrides(),
true);
 2522 return getOperation()->emitError(boundsResult.
errorMessage);
2527llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
// Rank-reduces `value` to `desiredShape` when possible: a no-op if shapes
// already match; fails if no unit-dim rank-reduction mask exists.
2532ExtractSliceOp::rankReduceIfNeeded(OpBuilder &
b, Location loc, Value value,
 2533 ArrayRef<int64_t> desiredShape) {
 2534 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
 2535 assert(sourceTensorType &&
"not a ranked tensor type");
 2536 auto sourceShape = sourceTensorType.getShape();
 2537 if (sourceShape.equals(desiredShape))
 2539 auto maybeRankReductionMask =
 2541 if (!maybeRankReductionMask)
 2545 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
// Reifies the result shape from the op's mixed sizes, skipping dropped
// (rank-reduced) dimensions.
2548LogicalResult ExtractSliceOp::reifyResultShapes(
 2550 reifiedReturnShapes.resize(1);
 2551 reifiedReturnShapes[0].reserve(
getType().getRank());
 2554 for (
const auto &size :
enumerate(mixedSizes)) {
 2555 if (droppedDims.test(size.index()))
 2557 reifiedReturnShapes[0].push_back(size.value());
// Pattern: folds extract_slice(cast(x)) to extract_slice(x) when no operand
// is a constant index (those are handled by other folders) and the slice
// stays in bounds of the cast's source shape.
2578class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
 2580 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
 2582 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
 2583 PatternRewriter &rewriter)
const override {
 2585 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
 2586 return matchPattern(operand, matchConstantIndex());
 2590 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
 2599 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
 2600 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
 2601 sliceOp.getStaticStrides());
 2606 Location loc = sliceOp.getLoc();
 2607 Value newResult = ExtractSliceOp::create(
 2608 rewriter, loc, sliceOp.getType(), castOp.getSource(),
 2609 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
 2610 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
 2611 sliceOp.getStaticStrides());
// Recursively copies a strided hyper-rectangular slice of a flattened
// element range into `outValues`. `counts` holds the per-dimension element
// strides of the source; the innermost dimension is handled by the direct
// loop, outer dimensions recurse with the front of each array dropped.
2620template <
typename IterTy,
typename ElemTy>
 2621static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
 2622 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
 2623 ArrayRef<int64_t> strides,
 2624 llvm::SmallVectorImpl<ElemTy> *outValues) {
 2625 assert(offsets.size() == sizes.size());
 2626 assert(offsets.size() == strides.size());
 2627 if (offsets.empty())
 2630 int64_t offset = offsets.front();
 2631 int64_t size = sizes.front();
 2632 int64_t stride = strides.front();
 // Base case: last dimension — copy elements directly.
 2633 if (offsets.size() == 1) {
 2634 for (int64_t i = 0; i < size; ++i, offset += stride)
 2635 outValues->push_back(*(values + offset));
 // Recursive case: advance by whole sub-blocks of counts.front() elements.
 2640 for (int64_t i = 0; i < size; ++i, offset += stride) {
 2641 auto begin = values + offset * counts.front();
 2642 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
 2643 offsets.drop_front(), sizes.drop_front(),
 2644 strides.drop_front(), outValues);
// Pattern: folds extract_slice of a dense constant into a smaller constant
// by slicing the attribute's elements. Guarded by a user-supplied controlFn
// and restricted to fully static source/result shapes and static
// offsets/sizes/strides.
2651class ConstantOpExtractSliceFolder final
 2652 :
public OpRewritePattern<ExtractSliceOp> {
 2654 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
 2656 ConstantOpExtractSliceFolder(MLIRContext *context,
 2658 : OpRewritePattern<ExtractSliceOp>(context),
 2659 controlFn(std::move(controlFn)) {}
 2661 LogicalResult matchAndRewrite(ExtractSliceOp op,
 2662 PatternRewriter &rewriter)
const override {
 2663 DenseElementsAttr attr;
 2672 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
 2673 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
 2674 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
 2681 int64_t count = sourceType.getNumElements();
 // Bail out on any dynamic offset, size, or stride.
 2686 auto offsets = op.getStaticOffsets();
 2687 if (llvm::is_contained(offsets, ShapedType::kDynamic))
 2689 auto sizes = op.getStaticSizes();
 2690 if (llvm::is_contained(sizes, ShapedType::kDynamic))
 2692 auto strides = op.getStaticStrides();
 2693 if (llvm::is_contained(strides, ShapedType::kDynamic))
 // Per-dimension element strides for the flattened source buffer.
 2697 SmallVector<int64_t> counts;
 2698 ArrayRef<int64_t> shape = sourceType.getShape();
 2699 counts.reserve(shape.size());
 2700 for (int64_t v : shape) {
 2702 counts.push_back(count);
 2706 SmallVector<Attribute> outValues;
 2707 outValues.reserve(resultType.getNumElements());
 2708 sliceElements(attr.
value_begin<Attribute>(), counts, offsets, sizes,
 2709 strides, &outValues);
2726 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
2736 RankedTensorType nonReducedType =
2737 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2741 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2742 if (droppedDims.none())
2743 return nonReducedType;
2747 for (
auto i : llvm::seq<int64_t>(mixedSizes.size()))
2748 if (!droppedDims.test(i))
2749 targetShape.push_back(nonReducedType.getDimSize(i));
2751 return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
2752 nonReducedType.getEncoding());
2759 ExtractSliceOp newOp) {
2762 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
// Registers extract_slice canonicalizations: constant offset/size/stride
// folding (with rank-reducing type canonicalization) and cast folding.
2768void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 2769 MLIRContext *context) {
 2771 OpWithOffsetSizesAndStridesConstantArgumentFolder<
 2772 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
 2773 ExtractSliceOpCastFolder>(context);
2779 ShapedType shapedType) {
2786 auto shape = shapedType.getShape();
2787 for (
auto it : llvm::zip(op.getMixedSizes(),
shape))
2801 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2804 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2805 insertOp.isSameAs(extractOp, isSame))
2806 return insertOp.getSource();
// Folds extract_slice: a splat-constant source folds to a reshaped splat;
// an identity slice (same source/result type, trivial offsets/strides —
// condition partially elided) folds to the source.
2811OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
 2812 if (OpFoldResult reshapedSource = reshapeConstantSource(
 2813 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
 2815 return reshapedSource;
 2816 if (getSourceType() ==
getType() &&
 2818 return this->getSource();
 2822 return OpFoldResult();
2827 auto rankedTensorType = llvm::cast<RankedTensorType>(
tensor.getType());
2828 unsigned rank = rankedTensorType.getRank();
2832 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType,
tensor,
2833 offsets, sizes, strides);
// Suggests "inserted_slice" as the SSA result name prefix.
2840void InsertSliceOp::getAsmResultNames(
 2842 setNameFn(getResult(),
"inserted_slice");
2856 result.addAttributes(attrs);
2857 build(
b,
result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2858 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2859 b.getDenseI64ArrayAttr(staticSizes),
2860 b.getDenseI64ArrayAttr(staticStrides));
// Convenience builder from Range triples (decomposition elided).
2865void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
 2866 Value dest, ArrayRef<Range> ranges,
 2867 ArrayRef<NamedAttribute> attrs) {
 2869 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
// Convenience builder from ValueRange operands: wraps each Value as an
// OpFoldResult and forwards to the mixed builder.
2873void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
 2875 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
 2876 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
 2877 offsets, [](Value v) -> OpFoldResult {
return v; });
 2878 SmallVector<OpFoldResult> sizeValues =
 2879 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
 2880 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
 2881 strides, [](Value v) -> OpFoldResult {
return v; });
 2882 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
2888 RankedTensorType srcType, RankedTensorType dstType,
2893 RankedTensorType expected =
2894 ExtractSliceOp::inferResultType(dstType, staticSizes);
2896 *expectedType = expected;
// Verifies insert_slice: the source must match the type inferred from the
// dest and sizes (or a rank-reduced version — check elided), and the static
// offsets/sizes/strides must be in bounds of the destination shape.
2901LogicalResult InsertSliceOp::verify() {
 2903 RankedTensorType expectedType;
 2906 getStaticSizes(), getStaticStrides(), &expectedType);
 2913 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
 2914 getStaticStrides(),
true);
 2916 return getOperation()->emitError(boundsResult.
errorMessage);
2939 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2942 if (!prevInsertOp ||
2943 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2944 !prevInsertOp.isSameAs(insertOp, isSame))
2947 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2959 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2962 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2963 !extractOp.isSameAs(insertOp, isSame))
2966 return extractOp.getSource();
// Folds an identity insert_slice (static, equal source/dest types — the
// trivial offsets/strides condition is elided) to its source.
2969OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
 2970 if (getSourceType().hasStaticShape() &&
getType().hasStaticShape() &&
 2971 getSourceType() ==
getType() &&
 2973 return this->getSource();
 2980 return OpFoldResult();
// The result shape of insert_slice is the destination's shape (the shape
// values are filled in by elided lines).
2983LogicalResult InsertSliceOp::reifyResultShapes(
 2985 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
// Pattern (shared by insert_slice and parallel_insert_slice): canonicalizes
// constant offsets/sizes/strides, rank-reducing the source via a cast when
// the canonical source type differs.
2994template <
typename InsertOpTy>
 2995class InsertSliceOpConstantArgumentFolder final
 2996 :
public OpRewritePattern<InsertOpTy> {
 2998 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 3000 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
 3001 PatternRewriter &rewriter)
const override {
 3002 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
 3003 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
 3004 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
 3013 SliceBoundsVerificationResult sliceResult =
 3015 mixedOffsets, mixedSizes, mixedStrides);
 3020 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
 3021 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
 3023 Value toInsert = insertSliceOp.getSource();
 3024 if (sourceType != insertSliceOp.getSourceType()) {
 3025 OpBuilder::InsertionGuard g(rewriter);
 // For ops inside an in-parallel region the cast must be created outside
 // of it (insertion point adjusted by elided lines).
 3029 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
 3031 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
 3032 sourceType, toInsert);
 3035 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
 3036 mixedSizes, mixedStrides);
// Pattern (shared by insert_slice and parallel_insert_slice): folds casts
// on the source and/or dest into the insert op, recomputing static sizes
// from the cast's source type and re-casting the result back to the
// original dest type when needed (not for parallel inserts, which have no
// result of their own).
3061template <
typename InsertOpTy>
 3062struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
 3063 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 3065 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
 3066 PatternRewriter &rewriter)
const override {
 3067 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
 3068 return matchPattern(operand, matchConstantIndex());
 3072 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
 3075 return std::nullopt;
 3076 return castOp.getSource();
 3078 std::optional<Value> sourceCastSource =
 3079 getSourceOfCastOp(insertSliceOp.getSource());
 3080 std::optional<Value> destCastSource =
 3081 getSourceOfCastOp(insertSliceOp.getDest());
 // Nothing to fold if neither operand comes from a cast.
 3082 if (!sourceCastSource && !destCastSource)
 3086 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
 3087 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
 3088 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
 3089 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
 3090 if (!srcType || !dstType)
 3096 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
 3098 staticSizes, srcType.getShape(),
true);
 3099 if (!rankReductionMask.has_value())
 // Tighten sizes that became static through the source cast.
 3106 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
 3107 int64_t rankReducedIdx = 0;
 3108 for (
auto [idx, size] :
enumerate(staticSizes)) {
 3109 if (!rankReductionMask.value().contains(idx) &&
 3110 !srcType.isDynamicDim(rankReducedIdx)) {
 3112 rewriter.
getContext(), srcType.getDimSize(rankReducedIdx));
 3113 size = srcType.getDimSize(rankReducedIdx++);
 3119 staticSizes, insertSliceOp.getStaticStrides()) !=
 3120 SliceVerificationResult::Success)
 3122 SliceBoundsVerificationResult sliceResult =
 3124 mixedSizes, insertSliceOp.getMixedStrides());
 3129 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
 3130 insertSliceOp.getMixedOffsets(), mixedSizes,
 3131 insertSliceOp.getMixedStrides());
 3134 bool isParallelInsert =
 3135 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
 3136 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
 3137 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
 3138 insertSliceOp.getDestType(),
// Pattern (shared by insert_slice and parallel_insert_slice): when size
// operands are known constants, casts the source to a more static type and
// rebuilds the insert with that source.
3167template <
typename InsertOpTy>
 3168struct InsertSliceOpSourceCastInserter final
 3169 :
public OpRewritePattern<InsertOpTy> {
 3170 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 3172 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
 3173 PatternRewriter &rewriter)
const override {
 3174 RankedTensorType srcType = insertSliceOp.getSourceType();
 // Only non-rank-reducing inserts are handled here.
 3175 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
 3177 SmallVector<int64_t> newSrcShape(srcType.getShape());
 3178 for (int64_t i = 0; i < srcType.getRank(); ++i) {
 3179 if (std::optional<int64_t> constInt =
 3184 newSrcShape[i] = *constInt;
 3190 RankedTensorType newSrcType = RankedTensorType::get(
 3191 newSrcShape, srcType.getElementType(), srcType.getEncoding());
 // Nothing gained, or the cast would be invalid: give up.
 3192 if (srcType == newSrcType ||
 3194 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
 3202 OpBuilder::InsertionGuard g(rewriter);
 // Inside a parallel-combining region the cast must be created outside
 // of it (insertion point adjusted by elided lines).
 3206 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
 3208 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
 3209 newSrcType, insertSliceOp.getSource());
 3211 insertSliceOp, cast, insertSliceOp.getDest(),
 3212 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
 3213 insertSliceOp.getMixedStrides());
3219llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
// Registers insert_slice canonicalizations: constant-argument folding,
// cast folding, and source cast insertion.
3223void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 3224 MLIRContext *context) {
 3225 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
 3226 InsertSliceOpCastFolder<InsertSliceOp>,
 3227 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3234 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
3235 unsigned rank = rankedTensorType.getRank();
3239 return b.createOrFold<tensor::InsertSliceOp>(loc,
tensor, dest, offsets,
3248 setNameFn(getResult(),
"padded");
// Verifies tensor.pad: the result type must have the rank inferred from
// the source plus low/high padding, and each static result dim must match
// the inferred dim unless the inferred dim is dynamic.
3251LogicalResult PadOp::verify() {
 3252 auto sourceType = llvm::cast<RankedTensorType>(getSource().
getType());
 3253 auto resultType = llvm::cast<RankedTensorType>(getResult().
getType());
 3255 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
 3256 if (!expectedType) {
 3257 return emitError(
"failed to infer expectedType from sourceType ")
 3258 << sourceType <<
", specified resultType is " << resultType;
 3260 if (resultType.getRank() != expectedType.getRank()) {
 3262 << resultType <<
" does not match the inferred type "
 3265 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
 3266 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
 3268 if (expectedType.isDynamicDim(i))
 3271 << resultType <<
" does not match the inferred type "
// PadOp::verifyRegions(): checks the pad body block has `rank` index
// arguments and that its yield produces the result element type.
// NOTE(review): `®ion` below is a mis-decoded `&region` (HTML entity
// corruption), and the argument-count / argument-type loop headers and the
// yieldOp lookup were dropped by extraction.
3278LogicalResult PadOp::verifyRegions() {
 3279 auto &region = getRegion();
 3280 unsigned rank = llvm::cast<RankedTensorType>(getResult().
getType()).getRank();
 3281 Block &block = region.front();
 3283 return emitError(
"expected the block to have ") << rank <<
" arguments";
// Every block argument must be of index type.
 3287 if (!en.value().isIndex())
 3289 << (en.index() + 1) <<
" to be an index";
// The yielded value must match the result tensor's element type.
 3294 if (yieldOp.getValue().getType() !=
 3296 return emitOpError(
"expected yield type to match shape element type");
3301RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3302 ArrayRef<int64_t> staticLow,
3303 ArrayRef<int64_t> staticHigh,
3304 ArrayRef<int64_t> resultShape) {
3305 unsigned rank = sourceType.getRank();
3306 if (staticLow.size() != rank)
3307 return RankedTensorType();
3308 if (staticHigh.size() != rank)
3309 return RankedTensorType();
3310 if (!resultShape.empty() && resultShape.size() != rank)
3311 return RankedTensorType();
3313 SmallVector<int64_t, 4> inferredShape;
3314 for (
auto i : llvm::seq<unsigned>(0, rank)) {
3315 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3316 staticHigh[i] == ShapedType::kDynamic) {
3317 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3320 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3321 assert((resultShape.empty() || size == resultShape[i] ||
3322 resultShape[i] == ShapedType::kDynamic) &&
3323 "mismatch between inferred shape and result shape");
3324 inferredShape.push_back(size);
3328 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3331void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3332 Value source, ArrayRef<int64_t> staticLow,
3334 bool nofold, ArrayRef<NamedAttribute> attrs) {
3335 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3337 resultType = inferResultType(sourceType, staticLow, staticHigh);
3338 result.addAttributes(attrs);
3339 build(
b,
result, resultType, source, low, high,
3340 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3341 nofold ?
b.getUnitAttr() : UnitAttr());
3344void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3346 ArrayRef<NamedAttribute> attrs) {
3347 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3348 unsigned rank = sourceType.getRank();
3349 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3350 build(
b,
result, resultType, source, staticVector, staticVector, low, high,
3354void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3355 Value source, ArrayRef<OpFoldResult> low,
3356 ArrayRef<OpFoldResult> high,
bool nofold,
3357 ArrayRef<NamedAttribute> attrs) {
3358 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3359 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3360 SmallVector<int64_t, 4> staticLow, staticHigh;
3368 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3370 assert(llvm::isa<RankedTensorType>(resultType));
3371 result.addAttributes(attrs);
3372 build(
b,
result, resultType, source, dynamicLow, dynamicHigh,
3373 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3374 nofold ?
b.getUnitAttr() : UnitAttr());
3377void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3378 Value source, ArrayRef<OpFoldResult> low,
3379 ArrayRef<OpFoldResult> high, Value constantPadValue,
3380 bool nofold, ArrayRef<NamedAttribute> attrs) {
3381 build(
b,
result, resultType, source, low, high, nofold, attrs);
3384 Region *region =
result.regions[0].get();
3385 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3386 Repeated<Type> blockArgTypes(sourceRank,
b.getIndexType());
3387 SmallVector<Location> blockArgLocs(sourceRank,
result.location);
3391 OpBuilder::InsertionGuard guard(
b);
3392 b.createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3393 tensor::YieldOp::create(
b,
result.location, constantPadValue);
3396llvm::SmallBitVector PadOp::getPaddedDims() {
3397 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3398 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3399 for (
const auto &en :
enumerate(paddingWidths))
3401 paddedDims.set(en.index());
3403 extractPaddedDims(getMixedLowPad());
3404 extractPaddedDims(getMixedHighPad());
3411struct FoldStaticZeroPadding :
public OpRewritePattern<PadOp> {
3412 using OpRewritePattern<PadOp>::OpRewritePattern;
3414 LogicalResult matchAndRewrite(PadOp padTensorOp,
3415 PatternRewriter &rewriter)
const override {
3416 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3418 if (padTensorOp.getNofold())
3421 padTensorOp, padTensorOp.getResult().
getType(),
3422 padTensorOp.getSource());
// Fold a tensor.cast feeding the pad source into the pad itself when the
// cast can be folded into its consumer (i.e. it adds static information).
// If the inferred result type is unchanged the source operand is swapped in
// place; otherwise a new PadOp is built and cast back to the old type.
// NOTE(review): extraction dropped the canFoldIntoConsumerOp guard, the
// modifyOpInPlace wrapper around the in-place assign, the IRMapping
// declaration, the replaceOpWithNewOp<tensor::CastOp> prefix, and the final
// `return success();`.
3428struct FoldSourceTensorCast :
public OpRewritePattern<PadOp> {
 3429 using OpRewritePattern<PadOp>::OpRewritePattern;
 3431 LogicalResult matchAndRewrite(PadOp padTensorOp,
 3432 PatternRewriter &rewriter)
const override {
 3433 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
 3437 auto newResultType = PadOp::inferResultType(
 3438 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
 3439 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
 3440 padTensorOp.getResultType().getShape());
 3442 if (newResultType == padTensorOp.getResultType()) {
 3444 padTensorOp.getSourceMutable().assign(castOp.getSource());
 3447 auto newOp = PadOp::create(
 3448 rewriter, padTensorOp->getLoc(), newResultType,
 3449 padTensorOp.getSource(), padTensorOp.getStaticLow(),
 3450 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
 3451 padTensorOp.getHigh(), padTensorOp.getNofold(),
 3454 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
 3457 padTensorOp, padTensorOp.getResultType(), newOp);
// Fold the pad with a single tensor.cast user into a PadOp producing the
// cast's destination type directly, when the cast folds into its producer
// (adds static information). Both the pad and the cast are then replaced by
// the new op's result.
// NOTE(review): extraction dropped the cast lookup's guard
// (`if (!tensorCastOp || !canFoldIntoProducerOp(...))`), the attribute list
// argument to PadOp::create, and the final `return success();`.
3465struct FoldTargetTensorCast :
public OpRewritePattern<PadOp> {
 3466 using OpRewritePattern<PadOp>::OpRewritePattern;
 3468 LogicalResult matchAndRewrite(PadOp padTensorOp,
 3469 PatternRewriter &rewriter)
const override {
 3470 if (!padTensorOp.getResult().hasOneUse())
 3473 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
 3477 tensorCastOp.getDest().getType()))
 3480 auto replacementOp = PadOp::create(
 3481 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
 3482 padTensorOp.getSource(), padTensorOp.getStaticLow(),
 3483 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
 3484 padTensorOp.getHigh(), padTensorOp.getNofold(),
// The pad body (yield of the padding value) moves to the replacement op.
 3486 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
 3488 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
 3489 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
// Fold a chain pad(extract_slice(pad(extract_slice(x)))) where the two pads
// pad disjoint ("orthogonal") sets of dimensions, into a single
// extract_slice + pad. Applies only to unit-stride slices, zero low pads,
// matching constant padding values, and non-rank-reducing chains.
// NOTE(review): extraction dropped many lines — the null checks on the
// slice/pad producers, the notifyMatchFailure prefixes, the enumerate loop
// headers over newOffsets/newSizes/newHighPad, the zero-offset checks, the
// size-equality check against the outer pad, the attribute list passed to
// PadOp::create, the region clone, and the final `return success();`.
3529struct FoldOrthogonalPaddings :
public OpRewritePattern<PadOp> {
 3530 using OpRewritePattern<PadOp>::OpRewritePattern;
 3532 LogicalResult matchAndRewrite(PadOp padOp,
 3533 PatternRewriter &rewriter)
const override {
 3534 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
 3537 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
 3538 if (!outerPadOp || outerPadOp.getNofold())
 3540 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
 3545 int64_t rank = padOp.getSourceType().getRank();
 3546 if (outerSliceOp.getSourceType().getRank() != rank) {
 3548 "cannot fold rank-reducing chain");
 3552 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
 3554 padOp,
"cannot fold non-unit stride ExtractSliceOps");
 3558 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
 3560 "cannot fold PadOps with low padding");
// Both pads must use the same constant padding value.
 3564 Attribute innerAttr, outerAttr;
 3565 Value innerValue = padOp.getConstantPaddingValue();
 3566 Value outerValue = outerPadOp.getConstantPaddingValue();
 3567 if (!innerValue || !outerValue ||
 3570 innerAttr != outerAttr) {
 3572 padOp,
"cannot fold PadOps with different padding values");
// The padded dimension sets must be disjoint for the fold to be valid.
 3576 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
 3577 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
 3578 if (innerDims.anyCommon(outerDims)) {
 3580 padOp,
"cannot fold PadOps with common padding dimensions");
// Combine the offsets: per dimension, take the offset of whichever slice is
// on the non-padded side; the other side must have a zero offset.
 3588 SmallVector<OpFoldResult> newOffsets(rank, rewriter.
getIndexAttr(0));
 3590 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
 3591 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
 3592 if (!innerDims.test(en.index()) &&
 3594 en.value() = outerOffset;
 3597 if (!outerDims.test(en.index()) &&
 3599 en.value() = innerOffset;
 3603 padOp,
"cannot find zero-offset and zero-padding pair");
// Combine the sizes: outer-padded dims take the outer slice size; the inner
// slice must cover the full (static) padded extent there.
 3611 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
 3613 if (!outerDims.test(en.index()))
 3615 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
 3616 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
 3617 assert(ShapedType::isStatic(sourceSize) &&
 3618 "expected padded dimension to have a static size");
 3621 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
 3622 "match the size of the outer padding");
 3624 en.value() = outerSliceOp.getMixedSizes()[en.index()];
// Combine the high pads: each dim takes the pad of whichever op padded it.
 3628 SmallVector<OpFoldResult> newHighPad(rank, rewriter.
getIndexAttr(0));
 3630 if (innerDims.test(en.index()))
 3631 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
 3632 if (outerDims.test(en.index()))
 3633 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
// Materialize the fused extract_slice + pad and replace the inner pad.
 3638 auto newSliceOp = ExtractSliceOp::create(
 3639 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
 3640 newSizes, innerSliceOp.getMixedStrides());
 3641 auto newPadOp = PadOp::create(
 3642 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
 3643 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
 3646 newPadOp.getRegion().begin());
 3647 rewriter.
replaceOp(padOp, newPadOp.getResult());
// Replace dynamic low/high pad operands that are defined by constants with
// static pad attributes, recomputing a more static result type when
// possible; the new pad is then cast back to the old result type.
// NOTE(review): extraction dropped several lines — the early `return
// failure()` guards, the constant-int matching inside both operand loops
// (the `intOp` definitions), the lowCount/highCount declarations, the
// attribute list passed to PadOp::create, the IRMapping declaration, the
// replaceOpWithNewOp<tensor::CastOp> epilogue, and `return success();`.
3652struct FoldStaticPadding :
public OpRewritePattern<PadOp> {
 3653 using OpRewritePattern<PadOp>::OpRewritePattern;
 3655 LogicalResult matchAndRewrite(PadOp padTensorOp,
 3656 PatternRewriter &rewriter)
const override {
 3657 Value input = padTensorOp.getSource();
 3658 if (!llvm::isa<RankedTensorType>(input.
getType()))
 3660 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
 3661 auto inputRank = inputDims.size();
 3663 auto oldResultType =
 3664 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
 3668 auto outputDims = oldResultType.getShape();
// Collect constant values from the dynamic low operands; non-constant
// operands stay dynamic and are kept as SSA values in newLows.
 3671 SmallVector<int64_t> constOperandsLow;
 3672 SmallVector<Value> newLows;
 3673 for (
auto operand : padTensorOp.getLow()) {
 3676 constOperandsLow.push_back(ShapedType::kDynamic);
 3677 newLows.push_back(operand);
 3680 constOperandsLow.push_back(intOp.getExtValue());
// Same for the dynamic high operands.
 3682 SmallVector<int64_t> constOperandsHigh;
 3683 SmallVector<Value> newHighs;
 3684 for (
auto operand : padTensorOp.getHigh()) {
 3687 constOperandsHigh.push_back(ShapedType::kDynamic);
 3688 newHighs.push_back(operand);
 3691 constOperandsHigh.push_back(intOp.getExtValue());
 3694 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
 3695 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
 3698 if (inputDims.size() != outputDims.size() ||
 3699 inputDims.size() != constLow.size() ||
 3700 inputDims.size() != constHigh.size())
// Substitute the newly discovered constants into the static pad arrays.
 3705 for (
size_t i = 0; i < inputRank; i++) {
 3706 if (constLow[i] == ShapedType::kDynamic)
 3707 constLow[i] = constOperandsLow[lowCount++];
 3708 if (constHigh[i] == ShapedType::kDynamic)
 3709 constHigh[i] = constOperandsHigh[highCount++];
 3712 auto staticLow = ArrayRef<int64_t>(constLow);
 3713 auto staticHigh = ArrayRef<int64_t>(constHigh);
// Recompute output dims: a dim becomes static only when input dim and both
// pads are static.
 3716 SmallVector<int64_t> newOutDims;
 3717 for (
size_t i = 0; i < inputRank; i++) {
 3718 if (outputDims[i] == ShapedType::kDynamic) {
 3719 newOutDims.push_back(
 3720 (staticLow[i] == ShapedType::kDynamic ||
 3721 staticHigh[i] == ShapedType::kDynamic ||
 3722 inputDims[i] == ShapedType::kDynamic
 3723 ? ShapedType::kDynamic
 3724 : inputDims[i] + staticLow[i] + staticHigh[i]));
 3726 newOutDims.push_back(outputDims[i]);
// Bail out when nothing was made more static.
 3730 if (SmallVector<int64_t>(outputDims) == newOutDims ||
 3731 llvm::all_of(newOutDims,
 3732 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
 3736 auto newResultType = RankedTensorType::get(
 3737 newOutDims, padTensorOp.getType().getElementType());
 3738 auto newOp = PadOp::create(
 3739 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
 3740 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
 3744 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
// Merge pad(pad(x)) chains that use the same constant padding value into a
// single pad whose low/high amounts are the per-dimension sums (computed as
// folded affine applies d0 + d1).
// NOTE(review): extraction dropped the nofold notifyMatchFailure body, the
// AffineExpr d0/d1 bindings, the `return sumPaddings;` of the lambda, the
// region inlineRegionBefore epilogue, and the final `return success();`.
3772struct FoldConsecutiveConstantPadding :
public OpRewritePattern<tensor::PadOp> {
 3773 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
 3775 LogicalResult matchAndRewrite(tensor::PadOp padOp,
 3776 PatternRewriter &rewriter)
const override {
 3777 if (padOp.getNofold()) {
 3781 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
 3782 if (!producerPad || producerPad.getNofold()) {
 3784 padOp,
"producer is not a foldable tensor.pad op");
// Both pads must yield the same constant padding value.
 3788 Value consumerPadValue = padOp.getConstantPaddingValue();
 3789 Value producerPadValue = producerPad.getConstantPaddingValue();
 3790 if (!consumerPadValue || !producerPadValue ||
 3791 consumerPadValue != producerPadValue) {
 3794 "cannot fold PadOps with different or non-constant padding values");
 3797 Location loc = padOp.getLoc();
// Per-dimension sum of consumer and producer pad amounts.
 3802 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
 3803 ArrayRef<OpFoldResult> producerPaddings) {
 3804 SmallVector<OpFoldResult> sumPaddings;
 3805 for (
auto [consumerIndex, producerIndex] :
 3806 llvm::zip_equal(consumerPaddings, producerPaddings)) {
 3807 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
 3808 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
 3813 SmallVector<OpFoldResult> newHighPad =
 3814 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
 3815 SmallVector<OpFoldResult> newLowPad =
 3816 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
 3818 auto newPadOp = tensor::PadOp::create(
 3819 rewriter, padOp.getLoc(), padOp.getResultType(),
 3820 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
 3823 newPadOp.getRegion().begin());
 3824 rewriter.
replaceOp(padOp, newPadOp.getResult());
// PadOp::reifyResultShapes: static result dims become index attributes; for
// dynamic dims the size is computed as source-dim + low-pad + high-pad via a
// folded affine apply.
// NOTE(review): extraction dropped the `LogicalResult` return type / second
// parameter line, the dynamic-dim `else` branch plumbing (the DimOp operand
// and the makeComposedFoldedAffineApply assignment into
// reifiedReturnShapes), and the final `return success();`.
3832PadOp::reifyResultShapes(OpBuilder &
b,
 3834 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
 3835 SmallVector<OpFoldResult> lp = getMixedLowPad();
 3836 SmallVector<OpFoldResult> hp = getMixedHighPad();
 3837 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
 3838 if (!
getType().isDynamicDim(i)) {
 3839 reifiedReturnShapes[0][i] =
b.getIndexAttr(
getType().getDimSize(i));
 3842 Location loc = getLoc();
 3843 Value dim =
b.createOrFold<tensor::DimOp>(
 3846 AffineExpr d0, d1, d2;
 3849 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3854void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3855 MLIRContext *context) {
3856 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3857 FoldOrthogonalPaddings, FoldStaticPadding,
3858 FoldConsecutiveConstantPadding>(context);
// PadOp::getConstantPaddingValue: returns the value yielded by the pad body
// when it is a region-invariant constant, otherwise a null Value.
// NOTE(review): extraction dropped the null-yield guard and the checks that
// `padValue` is defined outside the region (or by a constant) before it is
// returned.
3870Value PadOp::getConstantPaddingValue() {
 3871 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
 3874 Value padValue = yieldOp.getValue();
// PadOp::fold: folds to the source when the op is a no-op pad — static shape,
// result type equal to source type (and, per the truncated condition,
// presumably zero low/high pads and no nofold — TODO confirm upstream).
// NOTE(review): the rest of the condition and the return statements were
// dropped by extraction.
3885OpFoldResult PadOp::fold(FoldAdaptor) {
 3886 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3896OpResult ParallelInsertSliceOp::getTiedOpResult() {
3897 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3898 for (
const auto &it :
3899 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3900 Operation &nextOp = it.value();
3901 if (&nextOp == getOperation())
3902 return parallelCombiningParent.getParentResult(it.index());
3904 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3908void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3909 Value source, Value dest,
3910 ArrayRef<OpFoldResult> offsets,
3911 ArrayRef<OpFoldResult> sizes,
3912 ArrayRef<OpFoldResult> strides,
3913 ArrayRef<NamedAttribute> attrs) {
3914 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3915 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3919 result.addAttributes(attrs);
3920 build(
b,
result, {}, source, dest, dynamicOffsets, dynamicSizes,
3921 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3922 b.getDenseI64ArrayAttr(staticSizes),
3923 b.getDenseI64ArrayAttr(staticStrides));
3928void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3929 Value source, Value dest,
3930 ArrayRef<Range> ranges,
3931 ArrayRef<NamedAttribute> attrs) {
3933 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
3937void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3938 Value source, Value dest,
ValueRange offsets,
3940 ArrayRef<NamedAttribute> attrs) {
3941 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3942 offsets, [](Value v) -> OpFoldResult {
return v; });
3943 SmallVector<OpFoldResult> sizeValues =
3944 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
3945 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3946 strides, [](Value v) -> OpFoldResult {
return v; });
3947 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
3952void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
3953 Value dest, ArrayRef<OpFoldResult> sizes,
3954 ArrayRef<NamedAttribute> attrs) {
3955 Attribute zeroIdxAttr =
b.getIndexAttr(0);
3956 Attribute oneIdxAttr =
b.getIndexAttr(1);
3957 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3958 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3959 build(
b,
result, source, dest, writeOffsets, sizes, writeStrides, attrs);
// ParallelInsertSliceOp::verify: requires an InParallelOpInterface parent,
// then runs the shared rank-reducing insert_slice type verification and an
// in-bounds check of offsets/sizes/strides against the dest shape.
// NOTE(review): extraction dropped the verifyInsertSliceOp call prefix and
// its produceSliceErrorMsg handling, the bounds-check call prefix
// (presumably verifyInBoundsSlice — TODO confirm), the `if (!boundsResult
// .isValid)` guard, and the final `return success();`.
3962LogicalResult ParallelInsertSliceOp::verify() {
 3963 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
 3964 return this->
emitError(
"expected InParallelOpInterface parent, got:")
 3965 << *(getOperation()->getParentOp());
 3968 RankedTensorType expectedType;
 3971 getStaticSizes(), getStaticStrides(), &expectedType);
 3978 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
 3979 getStaticStrides(),
true);
 3981 return getOperation()->emitError(boundsResult.
errorMessage);
3986void ParallelInsertSliceOp::getCanonicalizationPatterns(
3987 RewritePatternSet &results, MLIRContext *context) {
3988 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3989 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3990 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3993llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3998MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3999 return getDestMutable();
4002Operation *ParallelInsertSliceOp::getIteratingParent() {
4004 if (
auto combiningOp =
4005 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4006 return combiningOp->getParentOp();
4014void ScatterOp::getAsmResultNames(
4016 setNameFn(getResult(),
"scatter");
// ScatterOp::verify: validates scatter_dims against the dest rank/indices
// shape, requires the 'unique' attribute, and checks the source type against
// the gather-inferred expected type (full or rank-reduced variant).
// NOTE(review): extraction dropped the verifyGatherOrScatterDims call
// prefix, the `if (!getUnique())` guard, the emitOpError prefix for the
// type-mismatch diagnostic, and the final `return success();`.
4019LogicalResult ScatterOp::verify() {
 4020 int64_t destRank = getDestType().getRank();
 4021 ArrayRef<int64_t> scatterDims = getScatterDims();
 4023 getIndicesType().
getShape(), destRank,
 4024 "scatter",
"dest")))
 4028 return emitOpError(
"requires 'unique' attribute to be set");
// Both the exact and the rank-reduced source type are acceptable.
 4035 RankedTensorType expectedSourceType = GatherOp::inferResultType(
 4036 getDestType(), getIndicesType(), scatterDims,
false);
 4037 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
 4038 getDestType(), getIndicesType(), scatterDims,
true);
 4039 if (getSourceType() != expectedSourceType &&
 4040 getSourceType() != expectedRankReducedSourceType) {
 4044 << expectedSourceType <<
" or its rank-reduced variant "
 4045 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
4056void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4057 Type aggregateType,
ValueRange dynamicSizes) {
4058 build(builder,
result, aggregateType, element, dynamicSizes);
4061void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4062 ArrayRef<int64_t> staticShape,
ValueRange dynamicSizes) {
4063 auto aggregateType = RankedTensorType::get(staticShape, element.
getType());
4064 build(builder,
result, aggregateType, element, dynamicSizes);
4067void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4068 ArrayRef<OpFoldResult> sizes) {
4069 SmallVector<int64_t> staticShape;
4070 SmallVector<Value> dynamicSizes;
4072 build(builder,
result, element, staticShape, dynamicSizes);
4075void SplatOp::getAsmResultNames(
4077 setNameFn(getResult(),
"splat");
// SplatOp::verify: body dropped by extraction — presumably checks that the
// number of dynamic size operands matches the result type's number of
// dynamic dimensions — TODO confirm against upstream.
4080LogicalResult SplatOp::verify() {
// SplatOp::reifyResultShapes: one OpFoldResult per result dim — dynamic dims
// come from the dynamic size operands, static dims become index attributes.
// NOTE(review): extraction dropped the return type / second parameter line,
// both branch bodies of the loop, and the final `return success();`.
4086SplatOp::reifyResultShapes(OpBuilder &builder,
 4088 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
 4090 for (int64_t i = 0; i <
getType().getRank(); ++i) {
 4091 if (
getType().isDynamicDim(i)) {
// SplatOp::fold: folds a splat of a constant int/float into a splat
// DenseElementsAttr, but only for fully static result shapes.
// NOTE(review): extraction dropped the failure returns and the final
// DenseElementsAttr construction.
4100OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
 4101 auto constOperand = adaptor.getInput();
 4102 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
 4106 if (!
getType().hasStaticShape())
// NOTE(review): this span fuses fragments of two different functions — the
// tail of foldTensorCastPrecondition (the op-kind exclusions below) and the
// body of the dialect-wide DPS cast-folding rewrite (clone the op with
// updated operand/result types, then cast each changed result back to its
// old type). The enclosing signatures and surrounding control flow were
// dropped by extraction.
4121 if (isa<InsertSliceOp>(op.getOperation()) ||
 4122 isa<LoopLikeOpInterface>(op.getOperation()))
 4155 isa<linalg::RelayoutOpInterface>(*op))
 4163 auto newOp =
clone(rewriter, op, newResultTypes, newOperands);
// Replace each result, inserting a tensor.cast wherever the cloned op's
// result type differs from the original's.
 4166 replacements.reserve(newOp->getNumResults());
 4167 for (
auto [oldResult, newResult] :
 4168 llvm::zip(op->getResults(), newOp->getResults())) {
 4169 if (newResult.getType() != oldResult.getType()) {
 4170 replacements.push_back(tensor::CastOp::create(
 4171 rewriter, op->getLoc(), oldResult.
getType(), newResult));
 4173 replacements.push_back(newResult);
// TensorDialect::getCanonicalizationPatterns: body dropped by extraction —
// presumably registers the dialect-wide DPS tensor.cast folding pattern —
// TODO confirm against upstream.
4186void TensorDialect::getCanonicalizationPatterns(
 4187 RewritePatternSet &results)
const {
4195#define GET_OP_CLASSES
4196#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
If copies could not be generated due to yet-unimplemented cases, `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be placed. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by the corresponding dynamicValues.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that would prevent erasing.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.