34#include "llvm/ADT/DenseSet.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallBitVector.h"
37#include "llvm/ADT/SmallVectorExtras.h"
38#include "llvm/ADT/StringRef.h"
39#include "llvm/Support/Casting.h"
40#include "llvm/Support/MathExtras.h"
51 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
53 if (complex::ConstantOp::isBuildableWith(value, type))
54 return complex::ConstantOp::create(builder, loc, type,
55 llvm::cast<ArrayAttr>(value));
61 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
62 if (tensorType.isDynamicDim(dim))
63 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
70 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
72 for (
int64_t i = 0; i < tensorType.getRank(); ++i)
79 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
80 assert(tensorType &&
"expected tensor type");
84 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
86 return destOp.getTiedOpOperand(opResult)->get();
94 if (!tensorType.hasStaticShape()) {
102 for (
int64_t sz : tensorType.getShape())
103 mixedSizes.push_back(
b.getIndexAttr(sz));
108 tensor::EmptyOp::create(
b, loc, mixedSizes, tensorType.getElementType());
116 if (llvm::isa<TensorType>(opResult.getType())) {
118 if (failed(destination))
120 result.push_back(*destination);
127 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
128 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
129 return rtp1.getShape() == rtp2.getShape() &&
130 rtp1.getElementType() == rtp2.getElementType();
140 llvm::SmallBitVector droppedDims(mixedSizes.size());
141 int64_t shapePos = reducedShape.size() - 1;
143 for (
const auto &size : enumerate(llvm::reverse(mixedSizes))) {
144 size_t idx = mixedSizes.size() - size.index() - 1;
146 bool isStaticUnitSize =
147 isa<Attribute>(size.value()) &&
148 llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
153 assert(isStaticUnitSize &&
"expected unit dim");
154 droppedDims.set(idx);
159 if (!isStaticUnitSize) {
165 if (reducedShape[shapePos] == 1) {
171 droppedDims.set(idx);
174 assert(shapePos < 0 &&
"dimension mismatch");
181static RankedTensorType
185 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
186 "incorrect number of dynamic sizes");
190 for (
int64_t i = 0, e = type.getRank(); i < e; ++i) {
191 if (type.isDynamicDim(i)) {
192 Value dynamicSize = dynamicSizes[ctr++];
194 if (cst.has_value()) {
196 if (cst.value() < 0) {
197 foldedDynamicSizes.push_back(dynamicSize);
200 staticShape[i] = *cst;
202 foldedDynamicSizes.push_back(dynamicSize);
207 return RankedTensorType::get(staticShape, type.getElementType(),
216 if (inputs.size() != 1 || outputs.size() != 1)
218 Type a = inputs.front(),
b = outputs.front();
219 auto aT = dyn_cast<TensorType>(a);
220 auto bT = dyn_cast<TensorType>(
b);
224 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
235 using OpRewritePattern<BitcastOp>::OpRewritePattern;
237 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
238 PatternRewriter &rewriter)
const final {
239 auto tensorBitcastOperand =
240 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
241 if (!tensorBitcastOperand)
244 auto resultType = cast<TensorType>(tensorBitcast.getType());
245 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
246 tensorBitcastOperand.getOperand());
255 results.
add<ChainedTensorBitcast>(context);
263 setNameFn(getResult(),
"cast");
269 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
270 auto targetType = llvm::dyn_cast<RankedTensorType>(
target);
273 if (!sourceType || !targetType)
277 if (sourceType.getElementType() != targetType.getElementType())
281 if (sourceType.getRank() != targetType.getRank())
285 if (sourceType.getEncoding() != targetType.getEncoding())
289 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
290 if (ShapedType::isStatic(std::get<0>(t)) &&
291 ShapedType::isDynamic(std::get<1>(t)))
327 castOp.getSource().getType());
360 if (llvm::isa<BlockArgument>(opOperand.get()))
362 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
363 return castOp && canFoldIntoConsumerOp(castOp);
370 newOperands.reserve(op->getNumOperands());
376 for (
OpOperand &opOperand : op->getOpOperands()) {
377 auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
379 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
380 if (op.isDpsInit(&opOperand) &&
381 !llvm::isa<MemRefType>(newOperands.back().getType()))
382 newResTy[dpsInitIdx++] = newOperands.back().getType();
392 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
394 operand.set(castOp.getOperand());
402 if (inputs.size() != 1 || outputs.size() != 1)
404 Type a = inputs.front(),
b = outputs.front();
405 auto aT = llvm::dyn_cast<TensorType>(a);
406 auto bT = llvm::dyn_cast<TensorType>(
b);
410 if (aT.getElementType() != bT.getElementType())
427 if (rank != two.getRank())
432 for (
int64_t i = 0; i < rank; ++i) {
433 if (one.isDynamicDim(i)) {
434 join.push_back(two.getDimSize(i));
437 if (two.isDynamicDim(i)) {
438 join.push_back(one.getDimSize(i));
441 if (one.getDimSize(i) != two.getDimSize(i))
443 join.push_back(one.getDimSize(i));
453 using OpRewritePattern<CastOp>::OpRewritePattern;
455 LogicalResult matchAndRewrite(CastOp tensorCast,
456 PatternRewriter &rewriter)
const final {
457 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
459 if (!tensorCastOperand)
463 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
464 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
465 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
479 auto newJoin =
joinShapes(sourceType, resultType);
480 if (firstJoin != newJoin)
483 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
484 tensorCastOperand.getOperand());
502 using OpRewritePattern<CastOp>::OpRewritePattern;
504 LogicalResult matchAndRewrite(CastOp tensorCast,
505 PatternRewriter &rewriter)
const final {
506 auto extractOperand =
507 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
510 auto rankedResultType =
511 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
512 if (!rankedResultType)
516 rankedResultType.getShape() ==
517 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
521 SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
523 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
525 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
526 if (dimMask && dimMask->count(i))
528 int64_t dim = rankedResultType.getShape()[dimIndex++];
529 if (ShapedType::isDynamic(dim))
531 sizes[i] = rewriter.getIndexAttr(dim);
534 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
535 tensorCast, rankedResultType, extractOperand.getSource(),
536 extractOperand.getMixedOffsets(), sizes,
537 extractOperand.getMixedStrides());
546 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
553RankedTensorType ConcatOp::inferResultType(
int64_t dim,
TypeRange inputTypes) {
554 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
556 llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
557 int64_t concatRank = tensorTypes[0].getRank();
560 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
563 for (
int64_t i = 0, e = concatRank; i < e; ++i) {
567 for (
auto tensorType : tensorTypes)
572 for (
auto tensorType : tensorTypes)
575 sizes[dim] = concatSize.asInteger();
576 return RankedTensorType::get(sizes, tensorTypes[0].
getElementType());
581 FailureOr<RankedTensorType> resultType =
582 inferResultType(dim, inputs.
getTypes());
583 assert(succeeded(resultType) &&
"failed to infer concatenation result type");
584 build(builder,
result, *resultType, dim, inputs);
587LogicalResult ConcatOp::verify() {
588 if (getInputs().size() < 1)
592 for (
auto input : getInputs())
593 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
595 RankedTensorType resultType = getResultType();
596 int64_t resultRank = getRank();
597 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
598 return type.getRank() != resultRank;
600 return emitOpError(
"rank of concatenated inputs must match result rank");
602 Type resultElementType = resultType.getElementType();
603 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
604 return type.getElementType() != resultElementType;
606 return emitOpError(
"inputs and result element type must match");
609 if (dim >= resultRank)
610 return emitOpError(
"concatenation dim must be less than the tensor rank");
613 for (
int64_t i = 0, e = resultRank; i < e; ++i) {
617 for (
auto tensorType : inputTypes) {
618 FailureOr<SaturatedInteger> maybeSize =
621 return emitOpError(
"static concatenation size mismatch along ")
622 <<
"non-concatenated dimension " << i;
628 for (
auto tensorType : inputTypes)
631 sizes[dim] = concatSize.asInteger();
632 auto inferredResultType =
635 for (
auto [inferredSize, actualSize] :
636 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
637 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
638 ShapedType::isDynamic(actualSize);
639 if (!hasDynamic && inferredSize != actualSize)
641 << resultType <<
"does not match inferred shape "
642 << inferredResultType <<
" static sizes";
648FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(
OpBuilder &builder) {
649 size_t numInputs = getInputs().size();
650 uint64_t concatDim = getDim();
653 inputShapes.reserve(numInputs);
655 concatOffsets.reserve(numInputs);
662 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
666 outputShape = inputShape;
667 concatOffsets.push_back(zero);
669 concatOffsets.push_back(outputShape[concatDim]);
671 builder, loc, addExpr,
672 {outputShape[concatDim], inputShape[concatDim]});
674 inputShapes.emplace_back(std::move(inputShape));
684 for (
auto [
index, input] : llvm::enumerate(getInputs())) {
685 offsets[concatDim] = concatOffsets[
index];
686 auto insertSlice = tensor::InsertSliceOp::create(
697ConcatOp::reifyResultShapes(
OpBuilder &builder,
701 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
703 Value init = inputs[0];
711 for (
int64_t i = 0; i < rank; ++i) {
714 if (!
getType().isDynamicDim(i)) {
716 }
else if (!inferredResultType.isDynamicDim(i)) {
719 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
721 reifiedReturnShapes[0][i] =
722 tensor::DimOp::create(builder, init.
getLoc(), init, i).getResult();
726 if (
getType().isDynamicDim(dim)) {
731 for (
auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
734 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
742 reifiedReturnShapes[0][dim] =
748void ConcatOp::getAsmResultNames(
750 setNameFn(getResult(),
"concat");
755 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
763 using OpRewritePattern<ConcatOp>::OpRewritePattern;
765 LogicalResult matchAndRewrite(ConcatOp concatOp,
766 PatternRewriter &rewriter)
const override {
767 if (concatOp.getInputs().size() != 1)
770 concatOp.getInputs()[0]);
795 using OpRewritePattern<ConcatOp>::OpRewritePattern;
797 LogicalResult matchAndRewrite(ConcatOp concatOp,
798 PatternRewriter &rewriter)
const override {
799 int64_t dim = concatOp.getDim();
800 RankedTensorType inferredResultType =
801 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
804 LogicalResult matched = failure();
807 SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
808 for (
auto [operandIdx, operandType] :
809 llvm::enumerate(concatOp->getOperandTypes())) {
811 inferredOperandShape[dim] =
812 cast<RankedTensorType>(operandType).getDimSize(dim);
813 auto inferredOperandType = RankedTensorType::get(
814 inferredOperandShape, inferredResultType.getElementType());
822 CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
823 concatOp.getOperand(operandIdx));
825 concatOp->setOperand(operandIdx, castOp->getResult(0));
849 using OpRewritePattern<ConcatOp>::OpRewritePattern;
851 LogicalResult matchAndRewrite(ConcatOp concatOp,
852 PatternRewriter &rewriter)
const override {
853 int64_t dim = concatOp.getDim();
854 RankedTensorType inferredResultType =
855 ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
859 concatOp.getResultType())) {
864 ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
865 concatOp->getOperands());
877 .
add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
886 setNameFn(getResult(),
"dim");
891 auto loc =
result.location;
893 build(builder,
result, source, indexValue);
896std::optional<int64_t> DimOp::getConstantIndex() {
905 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().
getType());
906 if (!rankedSourceType)
909 if (rankedSourceType.getRank() <= constantIndex)
917 setResultRange(getResult(),
923 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
928 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().
getType());
935 if (indexVal < 0 || indexVal >= tensorType.getRank())
939 if (!tensorType.isDynamicDim(
index.getInt())) {
944 Operation *definingOp = getSource().getDefiningOp();
947 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
949 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
952 assert(ShapedType::isDynamic(resultType.getShape()[
index.getInt()]));
955 auto dynExtents = fromElements.getDynamicExtents().begin();
956 for (
auto dim : resultType.getShape().take_front(
index.getInt()))
957 if (ShapedType::isDynamic(dim))
960 return Value{*dynExtents};
964 unsigned unsignedIndex =
index.getValue().getZExtValue();
966 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
969 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
970 sliceOp.isDynamicSize(unsignedIndex)) {
971 return {sliceOp.getDynamicSize(unsignedIndex)};
985 using OpRewritePattern<DimOp>::OpRewritePattern;
987 LogicalResult matchAndRewrite(DimOp dimOp,
988 PatternRewriter &rewriter)
const override {
989 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
992 Value newSource = castOp.getOperand();
1001 using OpRewritePattern<DimOp>::OpRewritePattern;
1003 LogicalResult matchAndRewrite(DimOp dimOp,
1004 PatternRewriter &rewriter)
const override {
1005 auto source = dimOp.getSource();
1006 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
1010 auto resultIndex = cast<OpResult>(source).getResultNumber();
1011 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
1014 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
1022 using OpRewritePattern<DimOp>::OpRewritePattern;
1024 LogicalResult matchAndRewrite(DimOp dim,
1025 PatternRewriter &rewriter)
const override {
1026 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1034 Location loc = dim.getLoc();
1036 ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1037 if (extract.
getType() != dim.getType())
1039 arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
1048 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
1058 assert(none_of(staticShape, ShapedType::isDynamic) &&
1059 "expected only static sizes");
1063void EmptyOp::build(OpBuilder &builder, OperationState &
result,
1064 ArrayRef<int64_t> staticShape, Type elementType,
1065 ValueRange dynamicSizes, Attribute encoding) {
1066 auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
1067 build(builder,
result, tensorType, dynamicSizes);
1070void EmptyOp::build(OpBuilder &builder, OperationState &
result,
1071 ArrayRef<OpFoldResult> sizes, Type elementType,
1072 Attribute encoding) {
1073 SmallVector<int64_t> staticShape;
1074 SmallVector<Value> dynamicSizes;
1076 build(builder,
result, staticShape, elementType, dynamicSizes, encoding);
1079LogicalResult EmptyOp::verify() {
1085EmptyOp::reifyResultShapes(OpBuilder &builder,
1087 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
1089 for (int64_t i = 0; i <
getType().getRank(); ++i) {
1090 if (
getType().isDynamicDim(i)) {
1099Value EmptyOp::getDynamicSize(
unsigned idx) {
1100 assert(
getType().isDynamicDim(idx) &&
"expected dynamic dim");
1102 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
1103 if (
getType().isDynamicDim(i))
1108SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
1109 SmallVector<OpFoldResult>
result;
1113 if (ShapedType::isDynamic(dim)) {
1116 result.push_back(
b.getIndexAttr(dim));
1134struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
1135 using OpRewritePattern<EmptyOp>::OpRewritePattern;
1137 LogicalResult matchAndRewrite(EmptyOp op,
1138 PatternRewriter &rewriter)
const override {
1139 SmallVector<Value> foldedDynamicSizes;
1141 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
1144 if (foldedTensorType == op.getType())
1147 auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
1148 foldedDynamicSizes);
1154struct FoldEmptyTensorWithDimOp :
public OpRewritePattern<DimOp> {
1155 using OpRewritePattern<DimOp>::OpRewritePattern;
1157 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
1158 PatternRewriter &rewriter)
const override {
1159 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
1160 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
1161 if (!emptyTensorOp || !maybeConstantIndex)
1163 auto emptyTensorType = emptyTensorOp.getType();
1164 if (*maybeConstantIndex < 0 ||
1165 *maybeConstantIndex >= emptyTensorType.getRank() ||
1166 !emptyTensorType.isDynamicDim(*maybeConstantIndex))
1169 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1189struct FoldEmptyTensorWithCastOp :
public OpRewritePattern<CastOp> {
1190 using OpRewritePattern<CastOp>::OpRewritePattern;
1192 LogicalResult matchAndRewrite(CastOp castOp,
1193 PatternRewriter &rewriter)
const override {
1196 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1201 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1202 ArrayRef<int64_t> resultShape = resultType.getShape();
1203 SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
1204 SmallVector<OpFoldResult> newMixedSizes;
1205 newMixedSizes.reserve(currMixedSizes.size());
1206 assert(resultShape.size() == currMixedSizes.size() &&
1207 "mismatch in result shape and sizes of empty op");
1208 for (
auto it : llvm::zip(resultShape, currMixedSizes)) {
1209 int64_t newDim = std::get<0>(it);
1210 OpFoldResult currDim = std::get<1>(it);
1213 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1214 if (ShapedType::isDynamic(newDim) ||
1215 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1220 producer,
"mismatch in static value of shape of empty tensor "
1221 "result and cast result");
1223 newMixedSizes.push_back(attr);
1229 if (ShapedType::isStatic(newDim)) {
1230 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
1236 newMixedSizes.push_back(currDim);
1241 resultType.getElementType());
1248void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249 MLIRContext *context) {
1250 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1251 ReplaceEmptyTensorStaticShapeDims>(context);
1268struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
1269 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1271 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1272 PatternRewriter &rewriter)
const final {
1273 auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
1276 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1279 extract, tensorCast.getSource(), extract.getIndices());
1294struct ExtractFromCollapseShape :
public OpRewritePattern<tensor::ExtractOp> {
1295 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1297 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1298 PatternRewriter &rewriter)
const final {
1300 extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1303 if (!collapseOp.getSrcType().hasStaticShape())
1306 auto sourceSizes = collapseOp.getSrcType().getShape();
1308 SmallVector<Value>
indices(extractOp.getIndices().begin(),
1309 extractOp.getIndices().end());
1310 SmallVector<Value> sourceIndices;
1311 for (
auto [index, group] :
1312 llvm::zip(
indices, collapseOp.getReassociationIndices())) {
1313 assert(!group.empty() &&
"association indices groups cannot be empty");
1314 auto groupSize = group.size();
1316 if (groupSize == 1) {
1317 sourceIndices.push_back(index);
1321 SmallVector<int64_t> basis =
1322 llvm::map_to_vector(group, [&](int64_t d) {
return sourceSizes[d]; });
1323 auto delinearize = affine::AffineDelinearizeIndexOp::create(
1324 rewriter, extractOp.getLoc(), index, basis,
true);
1325 llvm::append_range(sourceIndices,
delinearize.getResults());
1327 if (collapseOp.getReassociationIndices().empty()) {
1330 cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1331 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1332 rewriter, extractOp.getLoc(), zeroAffineMap,
1333 ArrayRef<OpFoldResult>{});
1334 for (int64_t i = 0; i < srcRank; i++) {
1335 sourceIndices.push_back(
1341 extractOp, collapseOp.getSrc(), sourceIndices);
1348void ExtractOp::getAsmResultNames(
1350 setNameFn(getResult(),
"extracted");
1353LogicalResult ExtractOp::verify() {
1355 auto tensorType = llvm::cast<RankedTensorType>(getTensor().
getType());
1356 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1357 return emitOpError(
"incorrect number of indices for extract_element");
1366 auto insertOp = extractOp.getTensor().
getDefiningOp<InsertOp>();
1371 if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1372 llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1373 return insertOp.getScalar();
1378OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1379 if (Attribute tensor = adaptor.getTensor()) {
1382 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1383 return splatTensor.getSplatValue<Attribute>();
1386 if (isa<DenseResourceElementsAttr>(tensor))
1391 SmallVector<uint64_t, 8>
indices;
1392 for (Attribute indice : adaptor.getIndices()) {
1393 if (!indice || !llvm::isa<IntegerAttr>(indice))
1395 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1399 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1400 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1401 auto rank = tensorType.getRank();
1402 assert(
static_cast<int64_t
>(
indices.size()) == tensorType.getRank() &&
1406 for (
int i = rank - 1; i >= 0; --i) {
1407 flatIndex +=
indices[i] * stride;
1408 stride *= tensorType.getDimSize(i);
1412 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1415 return fromElementsOp.getElements()[flatIndex];
1419 if (Attribute tensor = adaptor.getTensor()) {
1420 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1421 if (elementsAttr && elementsAttr.isValidIndex(
indices))
1422 return elementsAttr.getValues<Attribute>()[
indices];
1431void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1432 MLIRContext *context) {
1433 results.
add<ExtractFromTensorCast>(context);
1438 patterns.
add<ExtractFromCollapseShape>(patterns.
getContext());
1445void FromElementsOp::getAsmResultNames(
1447 setNameFn(getResult(),
"from_elements");
1452 assert(!elements.empty() &&
"expected at least one element");
1453 Type resultType = RankedTensorType::get(
1454 {
static_cast<int64_t>(elements.size())}, elements.front().
getType());
1455 build(builder,
result, resultType, elements);
1458OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1463 Type eltType =
getType().getElementType();
1466 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
1489struct ExtractElementFromIndexCast
1490 :
public OpRewritePattern<tensor::ExtractOp> {
1491 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1493 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1494 PatternRewriter &rewriter)
const final {
1495 Location loc = extract.getLoc();
1496 auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
1502 auto newExtract = tensor::ExtractOp::create(
1503 rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());
1514void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
1515 MLIRContext *context) {
1516 results.
add<ExtractElementFromIndexCast>(context);
1523void GatherOp::getAsmResultNames(
1525 setNameFn(getResult(),
"gather");
1540RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1541 RankedTensorType indicesType,
1542 ArrayRef<int64_t> gatherDims,
1544 SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
1545 resultShape.reserve(resultShape.size() + sourceType.getRank());
1546 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1547 if (llvm::binary_search(gatherDims, idx)) {
1549 resultShape.push_back(1);
1552 resultShape.push_back(sourceType.getDimSize(idx));
1554 return RankedTensorType::Builder(sourceType).setShape(resultShape);
1560 StringRef gatherOrScatter, StringRef sourceOrDest) {
1562 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
1564 int64_t numGatherDims = dims.size();
1565 if (numGatherDims > rank)
1567 <<
"_dims overflow " << sourceOrDest <<
" rank";
1570 <<
"_dims length must match the size of last dimension of indices";
1574 <<
"_dims value must be non-negative";
1577 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
1579 for (
int64_t i = 1; i < numGatherDims; ++i) {
1580 if (dims[i - 1] >= dims[i])
1582 <<
"_dims values must be strictly increasing";
1587LogicalResult GatherOp::verify() {
1588 int64_t sourceRank = getSourceType().getRank();
1589 ArrayRef<int64_t> gatherDims = getGatherDims();
1591 getIndicesType().
getShape(), sourceRank,
1592 "gather",
"source")))
1595 RankedTensorType expectedResultType = GatherOp::inferResultType(
1596 getSourceType(), getIndicesType(), gatherDims,
false);
1597 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1598 getSourceType(), getIndicesType(), gatherDims,
true);
1599 if (getResultType() != expectedResultType &&
1600 getResultType() != expectedRankReducedResultType) {
1604 << expectedResultType <<
" or its rank-reduced variant "
1605 << expectedRankReducedResultType <<
" (got: " << getResultType()
1612OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1613 if (OpFoldResult reshapedSource = reshapeConstantSource(
1614 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1616 return reshapedSource;
1624void InsertOp::getAsmResultNames(
1626 setNameFn(getResult(),
"inserted");
1629LogicalResult InsertOp::verify() {
1631 auto destType = llvm::cast<RankedTensorType>(getDest().
getType());
1632 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1633 return emitOpError(
"incorrect number of indices");
1637OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1638 Attribute scalar = adaptor.getScalar();
1639 Attribute dest = adaptor.getDest();
1641 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1642 if (scalar == splatDest.getSplatValue<Attribute>())
1651void GenerateOp::getAsmResultNames(
1653 setNameFn(getResult(),
"generated");
1656LogicalResult GenerateOp::reifyResultShapes(
1658 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
1660 for (
auto dim : llvm::seq<int64_t>(0,
getType().getRank())) {
1661 if (
getType().isDynamicDim(dim)) {
1662 reifiedReturnShapes[0][dim] = getOperand(idx++);
1664 reifiedReturnShapes[0][dim] =
1671LogicalResult GenerateOp::verify() {
1674 RankedTensorType resultType = llvm::cast<RankedTensorType>(
getType());
1681LogicalResult GenerateOp::verifyRegions() {
1682 RankedTensorType resultTy = llvm::cast<RankedTensorType>(
getType());
1684 if (!llvm::all_of(getBody().getArgumentTypes(),
1685 [](Type ty) {
return ty.
isIndex(); }))
1686 return emitError(
"all body arguments must be index");
1687 if (getBody().getNumArguments() != resultTy.getRank())
1688 return emitError(
"must have one body argument per input dimension");
1691 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1693 if (yieldOp.getValue().getType() != resultTy.getElementType())
1695 "body must be terminated with a `yield` operation of the tensor "
1701void GenerateOp::build(
1702 OpBuilder &
b, OperationState &
result, Type resultTy,
1705 build(
b,
result, resultTy, dynamicExtents);
1708 OpBuilder::InsertionGuard guard(
b);
1709 Region *bodyRegion =
result.regions.front().get();
1710 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1711 SmallVector<Type, 2> argumentTypes(rank,
b.getIndexType());
1712 SmallVector<Location, 2> argumentLocs(rank,
result.location);
1714 b.createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1724struct StaticTensorGenerate :
public OpRewritePattern<GenerateOp> {
1725 using OpRewritePattern<GenerateOp>::OpRewritePattern;
1727 LogicalResult matchAndRewrite(GenerateOp generateOp,
1728 PatternRewriter &rewriter)
const final {
1729 SmallVector<Value> foldedDynamicSizes;
1731 generateOp.getType(), generateOp.getDynamicExtents(),
1732 foldedDynamicSizes);
1735 if (foldedTensorType == generateOp.getType())
1738 auto loc = generateOp.getLoc();
1740 GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
1742 newOp.getBody().begin());
1744 generateOp.getType(), newOp);
1760struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1761 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1763 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1764 PatternRewriter &rewriter)
const final {
1765 auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
1770 Block *body = &tensorFromElements.getBody().front();
1773 rewriter.
clone(op, mapping);
1784void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785 MLIRContext *context) {
1787 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1794void RankOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1795 setNameFn(getResult(),
"rank");
1798OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1800 auto type = getOperand().getType();
1801 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1802 if (shapedType && shapedType.hasRank())
1803 return IntegerAttr::get(IndexType::get(
getContext()), shapedType.getRank());
1804 return IntegerAttr();
1811void ReshapeOp::getAsmResultNames(
1813 setNameFn(getResult(),
"reshape");
1818 for (
auto dim : type.getShape())
1823LogicalResult ReshapeOp::verify() {
1824 TensorType operandType = llvm::cast<TensorType>(getSource().
getType());
1825 TensorType resultType = llvm::cast<TensorType>(getResult().
getType());
1828 return emitOpError(
"element types of source and destination tensor "
1829 "types should be the same");
1833 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1834 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1836 if (resultRankedType) {
1837 if (operandRankedType && resultRankedType.hasStaticShape() &&
1838 operandRankedType.hasStaticShape()) {
1840 return emitOpError(
"source and destination tensor should have the "
1841 "same number of elements");
1843 if (ShapedType::isDynamic(shapeSize))
1844 return emitOpError(
"cannot use shape operand with dynamic length to "
1845 "reshape to statically-ranked tensor type");
1846 if (shapeSize != resultRankedType.getRank())
1848 "length of shape operand differs from the result's tensor rank");
1853OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
1854 if (OpFoldResult reshapedSource = reshapeConstantSource(
1855 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1857 return reshapedSource;
1862 if (
auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1863 getSourceMutable().assign(reshapeOpProducer.getSource());
1867 auto source = getSource();
1868 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1869 auto resultTy = dyn_cast<RankedTensorType>(
getType());
1870 if (!sourceTy || !resultTy || sourceTy != resultTy)
1875 if (sourceTy.getRank() <= 1)
1878 if (
auto fromElements =
getShape().getDefiningOp<tensor::FromElementsOp>()) {
1879 auto elements = fromElements.getElements();
1881 sourceTy.getRank() ==
static_cast<int64_t
>(elements.size());
1882 for (
int id = 0, s = elements.size();
id < s && dynamicNoop; ++
id) {
1883 auto element = elements[id];
1886 dynamicNoop &= cst.value() == sourceTy.getDimSize(
id);
1890 if (
auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1891 dynamicNoop &= dimOp.getSource() == source;
1895 cst.has_value() && cst.value() ==
static_cast<int64_t
>(id);
1899 dynamicNoop =
false;
1914void CollapseShapeOp::getAsmResultNames(
1916 setNameFn(getResult(),
"collapsed");
1919void ExpandShapeOp::getAsmResultNames(
1921 setNameFn(getResult(),
"expanded");
1924int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1925 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1926 "invalid resultDim");
1927 for (
const auto &it : llvm::enumerate(getReassociationIndices()))
1928 if (llvm::is_contained(it.value(), resultDim))
1930 llvm_unreachable(
"could not find reassociation group");
1933FailureOr<SmallVector<OpFoldResult>>
1934ExpandShapeOp::inferOutputShape(OpBuilder &
b, Location loc,
1935 RankedTensorType expandedType,
1936 ArrayRef<ReassociationIndices> reassociation,
1937 ArrayRef<OpFoldResult> inputShape) {
1938 std::optional<SmallVector<OpFoldResult>> outputShape =
1943 return *outputShape;
1946SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
1950void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
1951 Type resultType, Value src,
1952 ArrayRef<ReassociationIndices> reassociation,
1953 ArrayRef<OpFoldResult> outputShape) {
1954 auto [staticOutputShape, dynamicOutputShape] =
1956 build(builder,
result, cast<RankedTensorType>(resultType), src,
1958 dynamicOutputShape, staticOutputShape);
1961void ExpandShapeOp::build(OpBuilder &builder, OperationState &
result,
1962 Type resultType, Value src,
1963 ArrayRef<ReassociationIndices> reassociation) {
1964 SmallVector<OpFoldResult> inputShape =
1966 auto tensorResultTy = cast<RankedTensorType>(resultType);
1967 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1968 builder,
result.location, tensorResultTy, reassociation, inputShape);
1969 SmallVector<OpFoldResult> outputShapeOrEmpty;
1970 if (succeeded(outputShape)) {
1971 outputShapeOrEmpty = *outputShape;
1973 build(builder,
result, tensorResultTy, src, reassociation,
1974 outputShapeOrEmpty);
1977SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1980SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1982 getReassociationIndices());
1985SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1988SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1990 getReassociationIndices());
1993RankedTensorType CollapseShapeOp::inferCollapsedType(
1994 RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
1995 return inferCollapsedType(
1997 type.getContext(), reassociation)));
2003CollapseShapeOp::inferCollapsedType(RankedTensorType type,
2004 ArrayRef<AffineMap> reassociation) {
2005 auto shape = type.getShape();
2006 SmallVector<int64_t, 4> newShape;
2007 newShape.reserve(reassociation.size());
2012 unsigned currentDim = 0;
2013 for (AffineMap m : reassociation) {
2014 unsigned dim = m.getNumResults();
2015 auto band = shape.slice(currentDim, dim);
2017 if (llvm::is_contained(band, ShapedType::kDynamic))
2018 size = ShapedType::kDynamic;
2020 for (
unsigned d = 0; d < dim; ++d)
2021 size *= shape[currentDim + d];
2022 newShape.push_back(size);
2026 return RankedTensorType::get(newShape, type.getElementType());
2029void CollapseShapeOp::build(OpBuilder &
b, OperationState &
result, Value src,
2030 ArrayRef<ReassociationIndices> reassociation,
2031 ArrayRef<NamedAttribute> attrs) {
2032 auto srcType = llvm::cast<RankedTensorType>(src.
getType());
2033 RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2035 RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2036 srcType.getEncoding());
2037 result.addAttribute(getReassociationAttrStrName(),
2039 build(
b,
result, resultType, src, attrs);
2042template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
2043 TensorReshapeOp, ExpandShapeOp>::value>
2045 RankedTensorType expandedType,
2046 RankedTensorType collapsedType) {
2048 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
2052 if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
2053 int64_t expandedNumElements = expandedType.getNumElements();
2054 int64_t collapsedNumElements = collapsedType.getNumElements();
2055 if (expandedNumElements != collapsedNumElements) {
2056 return op.emitOpError(
"number of elements must be preserved: ")
2057 << expandedNumElements <<
" != " << collapsedNumElements;
2061 auto maps = op.getReassociationMaps();
2062 RankedTensorType expectedType =
2063 CollapseShapeOp::inferCollapsedType(expandedType, maps);
2065 return op.emitOpError(
"expected collapsed type to be ")
2066 << expectedType <<
", but got " << collapsedType;
2070LogicalResult ExpandShapeOp::verify() {
2071 RankedTensorType srcType = getSrc().getType();
2072 RankedTensorType resultType = getResult().getType();
2074 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2075 return emitOpError(
"expected number of static shape dims to be equal to "
2076 "the output rank (")
2077 << resultType.getRank() <<
") but found "
2078 << getStaticOutputShape().size() <<
" inputs instead";
2080 if ((int64_t)getOutputShape().size() !=
2081 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2082 return emitOpError(
"mismatch in dynamic dims in output_shape and "
2083 "static_output_shape: static_output_shape has ")
2084 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2085 <<
" dynamic dims while output_shape has " << getOutputShape().size()
2096 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2097 for (
auto [pos, shape] : llvm::enumerate(resShape))
2098 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos])
2099 return emitOpError(
"invalid output shape provided at pos ") << pos;
2104LogicalResult CollapseShapeOp::verify() {
2105 CollapseShapeOp op = *
this;
2106 if (llvm::any_of(op.getReassociationIndices(),
2108 return op.emitOpError(
"reassociation indices must not be empty");
2110 RankedTensorType srcType = op.getSrc().getType();
2111 RankedTensorType resultType = op.getResult().getType();
2119template <
typename TensorReshapeOp>
2120struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
2121 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2122 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2123 PatternRewriter &rewriter)
const override {
2124 DenseElementsAttr attr;
2131 if (!reshapeOp.getResultType().hasStaticShape())
2134 reshapeOp.getResultType(), attr.
getRawData());
2141template <
typename TensorReshapeOp>
2142class FoldReshapeWithSplat :
public OpRewritePattern<TensorReshapeOp> {
2144 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2146 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2147 PatternRewriter &rewriter)
const override {
2148 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
2149 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
2153 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
2160template <
typename TensorReshapeOp>
2161struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
2162 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
2163 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
2164 PatternRewriter &rewriter)
const override {
2166 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
2170 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
2172 if (!shapedTy.hasStaticShape())
2176 fromElements.getElements());
2182struct FoldCollapseOfCastOp :
public OpRewritePattern<CollapseShapeOp> {
2183 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2185 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
2186 PatternRewriter &rewriter)
const override {
2187 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
2191 RankedTensorType srcType =
2192 llvm::cast<RankedTensorType>(castOp.getSource().getType());
2193 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
2194 srcType, collapseShapeOp.getReassociationMaps());
2196 if (newResultType == collapseShapeOp.getResultType()) {
2198 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
2201 auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
2202 newResultType, castOp.getSource(),
2203 collapseShapeOp.getReassociation());
2205 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
2215struct ConvertToStaticExpandShape :
public OpRewritePattern<ExpandShapeOp> {
2216 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2218 LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
2219 PatternRewriter &rewriter)
const override {
2220 auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
2224 ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
2225 SmallVector<ReassociationIndices, 4> reassoc =
2226 expandOp.getReassociationIndices();
2228 SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
2229 SmallVector<Value> dynamicOutputShape;
2230 auto outputIt = expandOp.getOutputShape().begin();
2232 for (
const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
2233 for (uint64_t outDim : innerReassoc) {
2234 if (ShapedType::isStatic(newOutputShape[outDim]))
2241 Value val = *outputIt;
2243 if (ShapedType::isDynamic(castSrcShape[inputDim])) {
2244 dynamicOutputShape.push_back(val);
2250 newOutputShape[outDim] = cst.getSExtValue();
2252 dynamicOutputShape.push_back(val);
2258 if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
2262 SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
2263 for (
auto inDim : llvm::seq<int>(0, newInputShape.size())) {
2264 for (
auto outDim : reassoc[inDim]) {
2265 auto ofr = newOutputShape[outDim];
2266 if (ShapedType::isDynamic(ofr)) {
2267 newInputShape[inDim] = ShapedType::kDynamic;
2270 newInputShape[inDim] *= ofr;
2274 SmallVector<OpFoldResult> outputOfr =
2276 auto inputType = RankedTensorType::get(
2277 newInputShape, expandOp.getSrcType().getElementType());
2278 auto outputType = RankedTensorType::get(
2279 newOutputShape, expandOp.getSrcType().getElementType());
2280 auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
2282 auto newExpand = ExpandShapeOp::create(
2283 rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
2284 expandOp.getReassociationIndices(), outputOfr);
2286 newExpand.getResult());
2292void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2293 MLIRContext *context) {
2295 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2296 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2297 ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
2298 FoldReshapeWithSplat<ExpandShapeOp>,
2299 FoldReshapeWithFromElements<ExpandShapeOp>>(context);
2302void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2303 MLIRContext *context) {
2305 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2306 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2307 tensor::DimOp, RankedTensorType>,
2308 FoldReshapeWithConstant<CollapseShapeOp>,
2309 FoldReshapeWithSplat<CollapseShapeOp>,
2310 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2314OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2316 adaptor.getOperands());
2319OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2321 adaptor.getOperands());
2328void ExtractSliceOp::getAsmResultNames(
2330 setNameFn(getResult(),
"extracted_slice");
2337ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2338 ArrayRef<int64_t> staticSizes) {
2342 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2343 sourceTensorType.getRank() &&
2344 "unexpected staticSizes not equal to rank of source");
2345 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2346 sourceTensorType.getEncoding());
2351ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
2352 ArrayRef<OpFoldResult> sizes) {
2353 SmallVector<int64_t> staticSizes;
2356 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2357 sourceTensorType.getRank() &&
2358 "unexpected staticSizes not equal to rank of source");
2359 return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2360 sourceTensorType.getEncoding());
2371RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2372 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2373 ArrayRef<int64_t> sizes) {
2375 auto inferredType = llvm::cast<RankedTensorType>(
2376 inferResultType(sourceRankedTensorType, sizes));
2377 int rankDiff = inferredType.getRank() - desiredResultRank;
2379 auto shape = inferredType.getShape();
2380 llvm::SmallBitVector dimsToProject =
2382 SmallVector<int64_t> projectedShape;
2384 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2385 if (!dimsToProject.test(pos))
2386 projectedShape.push_back(shape[pos]);
2388 RankedTensorType::get(projectedShape, inferredType.getElementType());
2390 return inferredType;
2393RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2394 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2395 ArrayRef<OpFoldResult> sizes) {
2396 SmallVector<int64_t> staticSizes;
2397 SmallVector<Value> dynamicSizes;
2399 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2400 desiredResultRank, sourceRankedTensorType, staticSizes);
2405void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2406 RankedTensorType resultType, Value source,
2407 ArrayRef<OpFoldResult> offsets,
2408 ArrayRef<OpFoldResult> sizes,
2409 ArrayRef<OpFoldResult> strides,
2410 ArrayRef<NamedAttribute> attrs) {
2411 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2412 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2416 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
2419 resultType = llvm::cast<RankedTensorType>(
2420 ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
2422 result.addAttributes(attrs);
2423 build(
b,
result, resultType, source, dynamicOffsets, dynamicSizes,
2424 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2425 b.getDenseI64ArrayAttr(staticSizes),
2426 b.getDenseI64ArrayAttr(staticStrides));
2431void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2432 ArrayRef<OpFoldResult> offsets,
2433 ArrayRef<OpFoldResult> sizes,
2434 ArrayRef<OpFoldResult> strides,
2435 ArrayRef<NamedAttribute> attrs) {
2436 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2441void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2442 ArrayRef<Range> ranges,
2443 ArrayRef<NamedAttribute> attrs) {
2445 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2450void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2451 RankedTensorType resultType, Value source,
2453 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2454 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2455 offsets, [](Value v) -> OpFoldResult {
return v; });
2456 SmallVector<OpFoldResult> sizeValues =
2457 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
2458 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2459 strides, [](Value v) -> OpFoldResult {
return v; });
2460 build(
b,
result, resultType, source, offsetValues, sizeValues, strideValues);
2464void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2466 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2467 build(
b,
result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2472 RankedTensorType expectedType) {
2477 return op->
emitError(
"expected rank to be smaller or equal to ")
2478 <<
"the other rank. ";
2480 return op->
emitError(
"expected type to be ")
2481 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2483 return op->
emitError(
"expected element type to be ")
2484 << expectedType.getElementType();
2486 llvm_unreachable(
"unexpected extract_slice op verification result");
2492void ExtractSliceOp::build(OpBuilder &
b, OperationState &
result,
2493 RankedTensorType resultType, Value source,
2494 ArrayRef<OpFoldResult> sizes,
2495 ArrayRef<NamedAttribute> attrs) {
2496 Attribute zeroIdxAttr =
b.getIndexAttr(0);
2497 Attribute oneIdxAttr =
b.getIndexAttr(1);
2498 SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
2499 SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
2500 build(
b,
result, resultType, source, readOffsets, sizes, readStrides, attrs);
2504LogicalResult ExtractSliceOp::verify() {
2505 RankedTensorType sourceType = getSourceType();
2508 RankedTensorType expectedType =
2509 ExtractSliceOp::inferResultType(sourceType,
getMixedSizes());
2517 sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
2518 getStaticStrides(),
true);
2520 return getOperation()->emitError(boundsResult.
errorMessage);
2525llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
2530ExtractSliceOp::rankReduceIfNeeded(OpBuilder &
b, Location loc, Value value,
2531 ArrayRef<int64_t> desiredShape) {
2532 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
2533 assert(sourceTensorType &&
"not a ranked tensor type");
2534 auto sourceShape = sourceTensorType.getShape();
2535 if (sourceShape.equals(desiredShape))
2537 auto maybeRankReductionMask =
2539 if (!maybeRankReductionMask)
2543 RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
2546LogicalResult ExtractSliceOp::reifyResultShapes(
2548 reifiedReturnShapes.resize(1);
2549 reifiedReturnShapes[0].reserve(
getType().getRank());
2552 for (
const auto &size :
enumerate(mixedSizes)) {
2553 if (droppedDims.test(size.index()))
2555 reifiedReturnShapes[0].push_back(size.value());
2576class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
2578 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2580 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2581 PatternRewriter &rewriter)
const override {
2583 if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
2584 return matchPattern(operand, matchConstantIndex());
2588 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2597 cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
2598 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2599 sliceOp.getStaticStrides());
2604 Location loc = sliceOp.getLoc();
2605 Value newResult = ExtractSliceOp::create(
2606 rewriter, loc, sliceOp.getType(), castOp.getSource(),
2607 sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
2608 sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
2609 sliceOp.getStaticStrides());
2618template <
typename IterTy,
typename ElemTy>
2619static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
2620 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2621 ArrayRef<int64_t> strides,
2622 llvm::SmallVectorImpl<ElemTy> *outValues) {
2623 assert(offsets.size() == sizes.size());
2624 assert(offsets.size() == strides.size());
2625 if (offsets.empty())
2628 int64_t offset = offsets.front();
2629 int64_t size = sizes.front();
2630 int64_t stride = strides.front();
2631 if (offsets.size() == 1) {
2632 for (int64_t i = 0; i < size; ++i, offset += stride)
2633 outValues->push_back(*(values + offset));
2638 for (int64_t i = 0; i < size; ++i, offset += stride) {
2639 auto begin = values + offset * counts.front();
2640 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2641 offsets.drop_front(), sizes.drop_front(),
2642 strides.drop_front(), outValues);
2649class ConstantOpExtractSliceFolder final
2650 :
public OpRewritePattern<ExtractSliceOp> {
2652 using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
2654 ConstantOpExtractSliceFolder(MLIRContext *context,
2656 : OpRewritePattern<ExtractSliceOp>(context),
2657 controlFn(std::move(controlFn)) {}
2659 LogicalResult matchAndRewrite(ExtractSliceOp op,
2660 PatternRewriter &rewriter)
const override {
2661 DenseElementsAttr attr;
2670 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2671 auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
2672 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2679 int64_t count = sourceType.getNumElements();
2684 auto offsets = op.getStaticOffsets();
2685 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2687 auto sizes = op.getStaticSizes();
2688 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2690 auto strides = op.getStaticStrides();
2691 if (llvm::is_contained(strides, ShapedType::kDynamic))
2695 SmallVector<int64_t> counts;
2696 ArrayRef<int64_t> shape = sourceType.getShape();
2697 counts.reserve(shape.size());
2698 for (int64_t v : shape) {
2700 counts.push_back(count);
2704 SmallVector<Attribute> outValues;
2705 outValues.reserve(resultType.getNumElements());
2706 sliceElements(attr.
value_begin<Attribute>(), counts, offsets, sizes,
2707 strides, &outValues);
2724 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
2734 RankedTensorType nonReducedType =
2735 ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);
2739 llvm::SmallBitVector droppedDims = op.getDroppedDims();
2740 if (droppedDims.none())
2741 return nonReducedType;
2745 for (
auto i : llvm::seq<int64_t>(mixedSizes.size()))
2746 if (!droppedDims.test(i))
2747 targetShape.push_back(nonReducedType.getDimSize(i));
2749 return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
2750 nonReducedType.getEncoding());
2757 ExtractSliceOp newOp) {
2760 replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
2766void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
2767 MLIRContext *context) {
2769 OpWithOffsetSizesAndStridesConstantArgumentFolder<
2770 ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
2771 ExtractSliceOpCastFolder>(context);
2777 ShapedType shapedType) {
2784 auto shape = shapedType.getShape();
2785 for (
auto it : llvm::zip(op.getMixedSizes(),
shape))
2799 auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
2802 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2803 insertOp.isSameAs(extractOp, isSame))
2804 return insertOp.getSource();
2809OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2810 if (OpFoldResult reshapedSource = reshapeConstantSource(
2811 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2813 return reshapedSource;
2814 if (getSourceType() ==
getType() &&
2816 return this->getSource();
2820 return OpFoldResult();
2825 auto rankedTensorType = llvm::cast<RankedTensorType>(
tensor.getType());
2826 unsigned rank = rankedTensorType.getRank();
2830 return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType,
tensor,
2831 offsets, sizes, strides);
2838void InsertSliceOp::getAsmResultNames(
2840 setNameFn(getResult(),
"inserted_slice");
2854 result.addAttributes(attrs);
2855 build(
b,
result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2856 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
2857 b.getDenseI64ArrayAttr(staticSizes),
2858 b.getDenseI64ArrayAttr(staticStrides));
2863void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2864 Value dest, ArrayRef<Range> ranges,
2865 ArrayRef<NamedAttribute> attrs) {
2867 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
2871void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
2873 ValueRange strides, ArrayRef<NamedAttribute> attrs) {
2874 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
2875 offsets, [](Value v) -> OpFoldResult {
return v; });
2876 SmallVector<OpFoldResult> sizeValues =
2877 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
2878 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2879 strides, [](Value v) -> OpFoldResult {
return v; });
2880 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
2886 RankedTensorType srcType, RankedTensorType dstType,
2891 RankedTensorType expected =
2892 ExtractSliceOp::inferResultType(dstType, staticSizes);
2894 *expectedType = expected;
2899LogicalResult InsertSliceOp::verify() {
2901 RankedTensorType expectedType;
2904 getStaticSizes(), getStaticStrides(), &expectedType);
2911 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
2912 getStaticStrides(),
true);
2914 return getOperation()->emitError(boundsResult.
errorMessage);
2937 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2940 if (!prevInsertOp ||
2941 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2942 !prevInsertOp.isSameAs(insertOp, isSame))
2945 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2957 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2960 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2961 !extractOp.isSameAs(insertOp, isSame))
2964 return extractOp.getSource();
2967OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
2968 if (getSourceType().hasStaticShape() &&
getType().hasStaticShape() &&
2969 getSourceType() ==
getType() &&
2971 return this->getSource();
2978 return OpFoldResult();
2981LogicalResult InsertSliceOp::reifyResultShapes(
2983 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
2992template <
typename InsertOpTy>
2993class InsertSliceOpConstantArgumentFolder final
2994 :
public OpRewritePattern<InsertOpTy> {
2996 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
2998 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2999 PatternRewriter &rewriter)
const override {
3000 SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
3001 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3002 SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
3011 SliceBoundsVerificationResult sliceResult =
3013 mixedOffsets, mixedSizes, mixedStrides);
3018 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
3019 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
3021 Value toInsert = insertSliceOp.getSource();
3022 if (sourceType != insertSliceOp.getSourceType()) {
3023 OpBuilder::InsertionGuard g(rewriter);
3027 if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
3029 toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3030 sourceType, toInsert);
3033 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
3034 mixedSizes, mixedStrides);
3059template <
typename InsertOpTy>
3060struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
3061 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3063 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3064 PatternRewriter &rewriter)
const override {
3065 if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
3066 return matchPattern(operand, matchConstantIndex());
3070 auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
3073 return std::nullopt;
3074 return castOp.getSource();
3076 std::optional<Value> sourceCastSource =
3077 getSourceOfCastOp(insertSliceOp.getSource());
3078 std::optional<Value> destCastSource =
3079 getSourceOfCastOp(insertSliceOp.getDest());
3080 if (!sourceCastSource && !destCastSource)
3084 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
3085 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
3086 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
3087 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
3088 if (!srcType || !dstType)
3094 SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
3096 staticSizes, srcType.getShape(),
true);
3097 if (!rankReductionMask.has_value())
3104 SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
3105 int64_t rankReducedIdx = 0;
3106 for (
auto [idx, size] :
enumerate(staticSizes)) {
3107 if (!rankReductionMask.value().contains(idx) &&
3108 !srcType.isDynamicDim(rankReducedIdx)) {
3110 rewriter.
getContext(), srcType.getDimSize(rankReducedIdx));
3111 size = srcType.getDimSize(rankReducedIdx++);
3117 staticSizes, insertSliceOp.getStaticStrides()) !=
3118 SliceVerificationResult::Success)
3120 SliceBoundsVerificationResult sliceResult =
3122 mixedSizes, insertSliceOp.getMixedStrides());
3127 InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
3128 insertSliceOp.getMixedOffsets(), mixedSizes,
3129 insertSliceOp.getMixedStrides());
3132 bool isParallelInsert =
3133 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
3134 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
3135 replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3136 insertSliceOp.getDestType(),
3165template <
typename InsertOpTy>
3166struct InsertSliceOpSourceCastInserter final
3167 :
public OpRewritePattern<InsertOpTy> {
3168 using OpRewritePattern<InsertOpTy>::OpRewritePattern;
3170 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
3171 PatternRewriter &rewriter)
const override {
3172 RankedTensorType srcType = insertSliceOp.getSourceType();
3173 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
3175 SmallVector<int64_t> newSrcShape(srcType.getShape());
3176 for (int64_t i = 0; i < srcType.getRank(); ++i) {
3177 if (std::optional<int64_t> constInt =
3182 newSrcShape[i] = *constInt;
3188 RankedTensorType newSrcType = RankedTensorType::get(
3189 newSrcShape, srcType.getElementType(), srcType.getEncoding());
3190 if (srcType == newSrcType ||
3192 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
3200 OpBuilder::InsertionGuard g(rewriter);
3204 if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
3206 Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
3207 newSrcType, insertSliceOp.getSource());
3209 insertSliceOp, cast, insertSliceOp.getDest(),
3210 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
3211 insertSliceOp.getMixedStrides());
3217llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
3221void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3222 MLIRContext *context) {
3223 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
3224 InsertSliceOpCastFolder<InsertSliceOp>,
3225 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
3232 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
3233 unsigned rank = rankedTensorType.getRank();
3237 return b.createOrFold<tensor::InsertSliceOp>(loc,
tensor, dest, offsets,
3246 setNameFn(getResult(),
"padded");
3249LogicalResult PadOp::verify() {
3250 auto sourceType = llvm::cast<RankedTensorType>(getSource().
getType());
3251 auto resultType = llvm::cast<RankedTensorType>(getResult().
getType());
3253 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
3254 if (!expectedType) {
3255 return emitError(
"failed to infer expectedType from sourceType ")
3256 << sourceType <<
", specified resultType is " << resultType;
3258 if (resultType.getRank() != expectedType.getRank()) {
3260 << resultType <<
" does not match the inferred type "
3263 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
3264 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
3266 if (expectedType.isDynamicDim(i))
3269 << resultType <<
" does not match the inferred type "
3276LogicalResult PadOp::verifyRegions() {
3277 auto ®ion = getRegion();
3278 unsigned rank = llvm::cast<RankedTensorType>(getResult().
getType()).getRank();
3279 Block &block = region.front();
3281 return emitError(
"expected the block to have ") << rank <<
" arguments";
3285 if (!en.value().isIndex())
3287 << (en.index() + 1) <<
" to be an index";
3292 if (yieldOp.getValue().getType() !=
3294 return emitOpError(
"expected yield type to match shape element type");
3299RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
3300 ArrayRef<int64_t> staticLow,
3301 ArrayRef<int64_t> staticHigh,
3302 ArrayRef<int64_t> resultShape) {
3303 unsigned rank = sourceType.getRank();
3304 if (staticLow.size() != rank)
3305 return RankedTensorType();
3306 if (staticHigh.size() != rank)
3307 return RankedTensorType();
3308 if (!resultShape.empty() && resultShape.size() != rank)
3309 return RankedTensorType();
3311 SmallVector<int64_t, 4> inferredShape;
3312 for (
auto i : llvm::seq<unsigned>(0, rank)) {
3313 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
3314 staticHigh[i] == ShapedType::kDynamic) {
3315 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
3318 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
3319 assert((resultShape.empty() || size == resultShape[i] ||
3320 resultShape[i] == ShapedType::kDynamic) &&
3321 "mismatch between inferred shape and result shape");
3322 inferredShape.push_back(size);
3326 return RankedTensorType::get(inferredShape, sourceType.getElementType());
3329void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3330 Value source, ArrayRef<int64_t> staticLow,
3332 bool nofold, ArrayRef<NamedAttribute> attrs) {
3333 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3335 resultType = inferResultType(sourceType, staticLow, staticHigh);
3336 result.addAttributes(attrs);
3337 build(
b,
result, resultType, source, low, high,
3338 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3339 nofold ?
b.getUnitAttr() : UnitAttr());
3342void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3344 ArrayRef<NamedAttribute> attrs) {
3345 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3346 unsigned rank = sourceType.getRank();
3347 SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
3348 build(
b,
result, resultType, source, staticVector, staticVector, low, high,
3352void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3353 Value source, ArrayRef<OpFoldResult> low,
3354 ArrayRef<OpFoldResult> high,
bool nofold,
3355 ArrayRef<NamedAttribute> attrs) {
3356 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3357 SmallVector<Value, 4> dynamicLow, dynamicHigh;
3358 SmallVector<int64_t, 4> staticLow, staticHigh;
3366 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3368 assert(llvm::isa<RankedTensorType>(resultType));
3369 result.addAttributes(attrs);
3370 build(
b,
result, resultType, source, dynamicLow, dynamicHigh,
3371 b.getDenseI64ArrayAttr(staticLow),
b.getDenseI64ArrayAttr(staticHigh),
3372 nofold ?
b.getUnitAttr() : UnitAttr());
3375void PadOp::build(OpBuilder &
b, OperationState &
result, Type resultType,
3376 Value source, ArrayRef<OpFoldResult> low,
3377 ArrayRef<OpFoldResult> high, Value constantPadValue,
3378 bool nofold, ArrayRef<NamedAttribute> attrs) {
3379 build(
b,
result, resultType, source, low, high, nofold, attrs);
3382 Region *region =
result.regions[0].get();
3383 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3384 SmallVector<Type> blockArgTypes(sourceRank,
b.getIndexType());
3385 SmallVector<Location> blockArgLocs(sourceRank,
result.location);
3389 OpBuilder::InsertionGuard guard(
b);
3390 b.createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3391 tensor::YieldOp::create(
b,
result.location, constantPadValue);
3394llvm::SmallBitVector PadOp::getPaddedDims() {
3395 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3396 auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
3397 for (
const auto &en :
enumerate(paddingWidths))
3399 paddedDims.set(en.index());
3401 extractPaddedDims(getMixedLowPad());
3402 extractPaddedDims(getMixedHighPad());
3409struct FoldStaticZeroPadding :
public OpRewritePattern<PadOp> {
3410 using OpRewritePattern<PadOp>::OpRewritePattern;
3412 LogicalResult matchAndRewrite(PadOp padTensorOp,
3413 PatternRewriter &rewriter)
const override {
3414 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3416 if (padTensorOp.getNofold())
3419 padTensorOp, padTensorOp.getResult().
getType(),
3420 padTensorOp.getSource());
3426struct FoldSourceTensorCast :
public OpRewritePattern<PadOp> {
3427 using OpRewritePattern<PadOp>::OpRewritePattern;
3429 LogicalResult matchAndRewrite(PadOp padTensorOp,
3430 PatternRewriter &rewriter)
const override {
3431 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3435 auto newResultType = PadOp::inferResultType(
3436 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3437 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3438 padTensorOp.getResultType().getShape());
3440 if (newResultType == padTensorOp.getResultType()) {
3442 padTensorOp.getSourceMutable().assign(castOp.getSource());
3445 auto newOp = PadOp::create(
3446 rewriter, padTensorOp->getLoc(), newResultType,
3447 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3448 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3449 padTensorOp.getHigh(), padTensorOp.getNofold(),
3452 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3455 padTensorOp, padTensorOp.getResultType(), newOp);
3463struct FoldTargetTensorCast :
public OpRewritePattern<PadOp> {
3464 using OpRewritePattern<PadOp>::OpRewritePattern;
3466 LogicalResult matchAndRewrite(PadOp padTensorOp,
3467 PatternRewriter &rewriter)
const override {
3468 if (!padTensorOp.getResult().hasOneUse())
3471 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3475 tensorCastOp.getDest().getType()))
3478 auto replacementOp = PadOp::create(
3479 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3480 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3481 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3482 padTensorOp.getHigh(), padTensorOp.getNofold(),
3484 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3486 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
3487 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3527struct FoldOrthogonalPaddings :
public OpRewritePattern<PadOp> {
3528 using OpRewritePattern<PadOp>::OpRewritePattern;
3530 LogicalResult matchAndRewrite(PadOp padOp,
3531 PatternRewriter &rewriter)
const override {
3532 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3535 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3536 if (!outerPadOp || outerPadOp.getNofold())
3538 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3543 int64_t rank = padOp.getSourceType().getRank();
3544 if (outerSliceOp.getSourceType().getRank() != rank) {
3546 "cannot fold rank-reducing chain");
3550 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3552 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3556 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3558 "cannot fold PadOps with low padding");
3562 Attribute innerAttr, outerAttr;
3563 Value innerValue = padOp.getConstantPaddingValue();
3564 Value outerValue = outerPadOp.getConstantPaddingValue();
3565 if (!innerValue || !outerValue ||
3568 innerAttr != outerAttr) {
3570 padOp,
"cannot fold PadOps with different padding values");
3574 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3575 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3576 if (innerDims.anyCommon(outerDims)) {
3578 padOp,
"cannot fold PadOps with common padding dimensions");
3586 SmallVector<OpFoldResult> newOffsets(rank, rewriter.
getIndexAttr(0));
3588 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3589 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3590 if (!innerDims.test(en.index()) &&
3592 en.value() = outerOffset;
3595 if (!outerDims.test(en.index()) &&
3597 en.value() = innerOffset;
3601 padOp,
"cannot find zero-offset and zero-padding pair");
3609 SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
3611 if (!outerDims.test(en.index()))
3613 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3614 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3615 assert(ShapedType::isStatic(sourceSize) &&
3616 "expected padded dimension to have a static size");
3619 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3620 "match the size of the outer padding");
3622 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3626 SmallVector<OpFoldResult> newHighPad(rank, rewriter.
getIndexAttr(0));
3628 if (innerDims.test(en.index()))
3629 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3630 if (outerDims.test(en.index()))
3631 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3636 auto newSliceOp = ExtractSliceOp::create(
3637 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3638 newSizes, innerSliceOp.getMixedStrides());
3639 auto newPadOp = PadOp::create(
3640 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3641 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3644 newPadOp.getRegion().begin());
3645 rewriter.
replaceOp(padOp, newPadOp.getResult());
3650struct FoldStaticPadding :
public OpRewritePattern<PadOp> {
3651 using OpRewritePattern<PadOp>::OpRewritePattern;
3653 LogicalResult matchAndRewrite(PadOp padTensorOp,
3654 PatternRewriter &rewriter)
const override {
3655 Value input = padTensorOp.getSource();
3656 if (!llvm::isa<RankedTensorType>(input.
getType()))
3658 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
3659 auto inputRank = inputDims.size();
3661 auto oldResultType =
3662 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3666 auto outputDims = oldResultType.getShape();
3669 SmallVector<int64_t> constOperandsLow;
3670 SmallVector<Value> newLows;
3671 for (
auto operand : padTensorOp.getLow()) {
3674 constOperandsLow.push_back(ShapedType::kDynamic);
3675 newLows.push_back(operand);
3678 constOperandsLow.push_back(intOp.getExtValue());
3680 SmallVector<int64_t> constOperandsHigh;
3681 SmallVector<Value> newHighs;
3682 for (
auto operand : padTensorOp.getHigh()) {
3685 constOperandsHigh.push_back(ShapedType::kDynamic);
3686 newHighs.push_back(operand);
3689 constOperandsHigh.push_back(intOp.getExtValue());
3692 SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
3693 SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
3696 if (inputDims.size() != outputDims.size() ||
3697 inputDims.size() != constLow.size() ||
3698 inputDims.size() != constHigh.size())
3703 for (
size_t i = 0; i < inputRank; i++) {
3704 if (constLow[i] == ShapedType::kDynamic)
3705 constLow[i] = constOperandsLow[lowCount++];
3706 if (constHigh[i] == ShapedType::kDynamic)
3707 constHigh[i] = constOperandsHigh[highCount++];
3710 auto staticLow = ArrayRef<int64_t>(constLow);
3711 auto staticHigh = ArrayRef<int64_t>(constHigh);
3714 SmallVector<int64_t> newOutDims;
3715 for (
size_t i = 0; i < inputRank; i++) {
3716 if (outputDims[i] == ShapedType::kDynamic) {
3717 newOutDims.push_back(
3718 (staticLow[i] == ShapedType::kDynamic ||
3719 staticHigh[i] == ShapedType::kDynamic ||
3720 inputDims[i] == ShapedType::kDynamic
3721 ? ShapedType::kDynamic
3722 : inputDims[i] + staticLow[i] + staticHigh[i]));
3724 newOutDims.push_back(outputDims[i]);
3728 if (SmallVector<int64_t>(outputDims) == newOutDims ||
3729 llvm::all_of(newOutDims,
3730 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
3734 auto newResultType = RankedTensorType::get(
3735 newOutDims, padTensorOp.getType().getElementType());
3736 auto newOp = PadOp::create(
3737 rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
3738 staticHigh, newLows, newHighs, padTensorOp.getNofold(),
3742 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3770struct FoldConsecutiveConstantPadding :
public OpRewritePattern<tensor::PadOp> {
3771 using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
3773 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3774 PatternRewriter &rewriter)
const override {
3775 if (padOp.getNofold()) {
3779 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3780 if (!producerPad || producerPad.getNofold()) {
3782 padOp,
"producer is not a foldable tensor.pad op");
3786 Value consumerPadValue = padOp.getConstantPaddingValue();
3787 Value producerPadValue = producerPad.getConstantPaddingValue();
3788 if (!consumerPadValue || !producerPadValue ||
3789 consumerPadValue != producerPadValue) {
3792 "cannot fold PadOps with different or non-constant padding values");
3795 Location loc = padOp.getLoc();
3800 auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
3801 ArrayRef<OpFoldResult> producerPaddings) {
3802 SmallVector<OpFoldResult> sumPaddings;
3803 for (
auto [consumerIndex, producerIndex] :
3804 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3805 sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
3806 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3811 SmallVector<OpFoldResult> newHighPad =
3812 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3813 SmallVector<OpFoldResult> newLowPad =
3814 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3816 auto newPadOp = tensor::PadOp::create(
3817 rewriter, padOp.getLoc(), padOp.getResultType(),
3818 producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
3821 newPadOp.getRegion().begin());
3822 rewriter.
replaceOp(padOp, newPadOp.getResult());
3830PadOp::reifyResultShapes(OpBuilder &
b,
3832 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
3833 SmallVector<OpFoldResult> lp = getMixedLowPad();
3834 SmallVector<OpFoldResult> hp = getMixedHighPad();
3835 for (int64_t i = 0; i < getResultType().getRank(); ++i) {
3836 if (!
getType().isDynamicDim(i)) {
3837 reifiedReturnShapes[0][i] =
b.getIndexAttr(
getType().getDimSize(i));
3840 Location loc = getLoc();
3841 Value dim =
b.createOrFold<tensor::DimOp>(
3844 AffineExpr d0, d1, d2;
3847 b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
3852void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3853 MLIRContext *context) {
3854 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3855 FoldOrthogonalPaddings, FoldStaticPadding,
3856 FoldConsecutiveConstantPadding>(context);
3868Value PadOp::getConstantPaddingValue() {
3869 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3872 Value padValue = yieldOp.getValue();
3883OpFoldResult PadOp::fold(FoldAdaptor) {
3884 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3894OpResult ParallelInsertSliceOp::getTiedOpResult() {
3895 InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
3896 for (
const auto &it :
3897 llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
3898 Operation &nextOp = it.value();
3899 if (&nextOp == getOperation())
3900 return parallelCombiningParent.getParentResult(it.index());
3902 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3906void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3907 Value source, Value dest,
3908 ArrayRef<OpFoldResult> offsets,
3909 ArrayRef<OpFoldResult> sizes,
3910 ArrayRef<OpFoldResult> strides,
3911 ArrayRef<NamedAttribute> attrs) {
3912 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3913 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3917 result.addAttributes(attrs);
3918 build(
b,
result, {}, source, dest, dynamicOffsets, dynamicSizes,
3919 dynamicStrides,
b.getDenseI64ArrayAttr(staticOffsets),
3920 b.getDenseI64ArrayAttr(staticSizes),
3921 b.getDenseI64ArrayAttr(staticStrides));
3926void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3927 Value source, Value dest,
3928 ArrayRef<Range> ranges,
3929 ArrayRef<NamedAttribute> attrs) {
3931 build(
b,
result, source, dest, offsets, sizes, strides, attrs);
3935void ParallelInsertSliceOp::build(OpBuilder &
b, OperationState &
result,
3936 Value source, Value dest,
ValueRange offsets,
3938 ArrayRef<NamedAttribute> attrs) {
3939 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3940 offsets, [](Value v) -> OpFoldResult {
return v; });
3941 SmallVector<OpFoldResult> sizeValues =
3942 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult {
return v; });
3943 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3944 strides, [](Value v) -> OpFoldResult {
return v; });
3945 build(
b,
result, source, dest, offsetValues, sizeValues, strideValues);
3950void InsertSliceOp::build(OpBuilder &
b, OperationState &
result, Value source,
3951 Value dest, ArrayRef<OpFoldResult> sizes,
3952 ArrayRef<NamedAttribute> attrs) {
3953 Attribute zeroIdxAttr =
b.getIndexAttr(0);
3954 Attribute oneIdxAttr =
b.getIndexAttr(1);
3955 SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
3956 SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
3957 build(
b,
result, source, dest, writeOffsets, sizes, writeStrides, attrs);
3960LogicalResult ParallelInsertSliceOp::verify() {
3961 if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
3962 return this->
emitError(
"expected InParallelOpInterface parent, got:")
3963 << *(getOperation()->getParentOp());
3966 RankedTensorType expectedType;
3969 getStaticSizes(), getStaticStrides(), &expectedType);
3976 getDestType().
getShape(), getStaticOffsets(), getStaticSizes(),
3977 getStaticStrides(),
true);
3979 return getOperation()->emitError(boundsResult.
errorMessage);
3984void ParallelInsertSliceOp::getCanonicalizationPatterns(
3985 RewritePatternSet &results, MLIRContext *context) {
3986 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3987 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3988 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3991llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
3996MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
3997 return getDestMutable();
4000Operation *ParallelInsertSliceOp::getIteratingParent() {
4002 if (
auto combiningOp =
4003 dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
4004 return combiningOp->getParentOp();
4012void ScatterOp::getAsmResultNames(
4014 setNameFn(getResult(),
"scatter");
4017LogicalResult ScatterOp::verify() {
4018 int64_t destRank = getDestType().getRank();
4019 ArrayRef<int64_t> scatterDims = getScatterDims();
4021 getIndicesType().
getShape(), destRank,
4022 "scatter",
"dest")))
4026 return emitOpError(
"requires 'unique' attribute to be set");
4033 RankedTensorType expectedSourceType = GatherOp::inferResultType(
4034 getDestType(), getIndicesType(), scatterDims,
false);
4035 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
4036 getDestType(), getIndicesType(), scatterDims,
true);
4037 if (getSourceType() != expectedSourceType &&
4038 getSourceType() != expectedRankReducedSourceType) {
4042 << expectedSourceType <<
" or its rank-reduced variant "
4043 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
4054void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4055 Type aggregateType,
ValueRange dynamicSizes) {
4056 build(builder,
result, aggregateType, element, dynamicSizes);
4059void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4060 ArrayRef<int64_t> staticShape,
ValueRange dynamicSizes) {
4061 auto aggregateType = RankedTensorType::get(staticShape, element.
getType());
4062 build(builder,
result, aggregateType, element, dynamicSizes);
4065void SplatOp::build(OpBuilder &builder, OperationState &
result, Value element,
4066 ArrayRef<OpFoldResult> sizes) {
4067 SmallVector<int64_t> staticShape;
4068 SmallVector<Value> dynamicSizes;
4070 build(builder,
result, element, staticShape, dynamicSizes);
4073void SplatOp::getAsmResultNames(
4075 setNameFn(getResult(),
"splat");
4078LogicalResult SplatOp::verify() {
4084SplatOp::reifyResultShapes(OpBuilder &builder,
4086 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(
getType().getRank()));
4088 for (int64_t i = 0; i <
getType().getRank(); ++i) {
4089 if (
getType().isDynamicDim(i)) {
4098OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
4099 auto constOperand = adaptor.getInput();
4100 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
4104 if (!
getType().hasStaticShape())
4119 if (isa<InsertSliceOp>(op.getOperation()) ||
4120 isa<LoopLikeOpInterface>(op.getOperation()))
4153 isa<linalg::RelayoutOpInterface>(*op))
4161 auto newOp =
clone(rewriter, op, newResultTypes, newOperands);
4164 replacements.reserve(newOp->getNumResults());
4165 for (
auto [oldResult, newResult] :
4166 llvm::zip(op->getResults(), newOp->getResults())) {
4167 if (newResult.getType() != oldResult.getType()) {
4168 replacements.push_back(tensor::CastOp::create(
4169 rewriter, op->getLoc(), oldResult.
getType(), newResult));
4171 replacements.push_back(newResult);
4184void TensorDialect::getCanonicalizationPatterns(
4185 RewritePatternSet &results)
const {
4193#define GET_OP_CLASSES
4194#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
If copies could not be generated due to yet-unimplemented cases, copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be placed. The output argument nBegin is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
auto value_begin() const
Get an iterator of the given type to the start of the held element values.
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1].
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1].
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source type.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by dynamicValues.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N].
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices of the dropped dimensions.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.