#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);

  auto tensorType = llvm::cast<RankedTensorType>(value.getType());

  for (int64_t i = 0; i < tensorType.getRank(); ++i)

  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();

  return destOp.getTiedOpOperand(opResult)->get();

  if (!tensorType.hasStaticShape()) {

  for (int64_t sz : tensorType.getShape())

      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());

  if (llvm::isa<TensorType>(opResult.getType())) {

    result.push_back(*destination);

  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();

  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;

    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);

    if (!isStaticUnitSize) {

    if (reducedShape[shapePos] == 1) {

      droppedDims.set(idx);

  assert(shapePos < 0 && "dimension mismatch");

static RankedTensorType

  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];

      if (cst.has_value()) {

        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);

        staticShape[i] = *cst;

        foldedDynamicSizes.push_back(dynamicSize);

  if (inputs.size() != 1 || outputs.size() != 1)

  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,

    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());

  results.add<ChainedTensorBitcast>(context);
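// Illustrative example (assumed IR, not taken from this file): the
// ChainedTensorBitcast pattern rewrites a bitcast-of-bitcast into a single op.
// All element types share the same bit width, as required above:
//
//   %0 = tensor.bitcast %src : tensor<4xi32> to tensor<4xui32>
//   %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xf32>
//
// becomes
//
//   %1 = tensor.bitcast %src : tensor<4xi32> to tensor<4xf32>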
  setNameFn(getResult(), "cast");

  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  if (!sourceType || !targetType)

  if (sourceType.getElementType() != targetType.getElementType())

  if (sourceType.getRank() != targetType.getRank())

  if (sourceType.getEncoding() != targetType.getEncoding())

  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))

      castOp.getSource().getType());

  if (llvm::isa<BlockArgument>(opOperand.get()))

  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
  return castOp && canFoldIntoConsumerOp(castOp);

  newOperands.reserve(op->getNumOperands());

  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();

    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();

    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();

      operand.set(castOp.getOperand());

  return success(folded);

  if (inputs.size() != 1 || outputs.size() != 1)

  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);

  if (aT.getElementType() != bT.getElementType())

  int64_t rank = one.getRank();
  if (rank != two.getRank())

  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));

    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));

    if (one.getDimSize(i) != two.getDimSize(i))

    join.push_back(one.getDimSize(i));
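// Worked example (illustrative): joinShapes picks the static size from
// whichever operand provides one, so joining tensor<?x4xf32> with
// tensor<8x?xf32> yields tensor<8x4xf32>, while joining tensor<8x4xf32> with
// tensor<7x4xf32> fails because the static sizes 8 and 7 conflict.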
  LogicalResult matchAndRewrite(CastOp tensorCast,

    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)

        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());

  LogicalResult matchAndRewrite(CastOp tensorCast,

    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)

        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())

        extractOperand.getStaticSizes(), extractOperand.getType().getShape());

    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))

      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))

      sizes[i] = rewriter.getIndexAttr(dim);

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());

  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);

RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
  int64_t concatRank = tensorTypes[0].getRank();

  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  for (int64_t i = 0, e = concatRank; i < e; ++i) {

    for (auto tensorType : tensorTypes)

  for (auto tensorType : tensorTypes)

  sizes[dim] = concatSize.asInteger();

  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
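// Illustrative example (assumed shapes): concatenating tensor<3x4xf32> and
// tensor<5x4xf32> along dim 0 infers tensor<8x4xf32>. If any input is dynamic
// along the concatenated dimension, the saturated sum makes the result size
// dynamic as well, e.g. tensor<?x4xf32>.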
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  for (int64_t i = 0, e = resultRank; i < e; ++i) {

    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =

        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;

  for (auto tensorType : inputTypes)

  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);
  concatOffsets.reserve(numInputs);

      outputShape = inputShape;
      concatOffsets.push_back(zero);

      concatOffsets.push_back(outputShape[concatDim]);
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});

    inputShapes.emplace_back(std::move(inputShape));

  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,

  int64_t rank = getType().getRank();

    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
        builder, loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();

    replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
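// Illustrative decomposition (assumed IR, for exposition): a concat such as
//
//   %r = tensor.concat dim(0) %a, %b
//        : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// lowers to an empty destination plus one insert_slice per input, each offset
// by the running size along the concatenated dimension:
//
//   %e  = tensor.empty() : tensor<5x4xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0] [2, 4] [1, 1]
//         : tensor<2x4xf32> into tensor<5x4xf32>
//   %i1 = tensor.insert_slice %b into %i0[2, 0] [3, 4] [1, 1]
//         : tensor<3x4xf32> into tensor<5x4xf32>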
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  for (int64_t i = 0; i < rank; ++i) {

    if (!getType().isDynamicDim(i)) {

    } else if (!inferredResultType.isDynamicDim(i)) {

          builder.getIndexAttr(inferredResultType.getDimSize(i)));

      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();

  if (getType().isDynamicDim(dim)) {

        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));

  reifiedReturnShapes[0][dim] =

void ConcatOp::getAsmResultNames(
  setNameFn(getResult(), "concat");

  if (inputs.size() == 1 && inputs[0].getType() == getResultType())

  LogicalResult matchAndRewrite(ConcatOp concatOp,
    if (concatOp.getInputs().size() != 1)

        concatOp.getInputs()[0]);

  LogicalResult matchAndRewrite(ConcatOp concatOp,
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    LogicalResult matched = failure();

    for (auto [operandIdx, operandType] :

      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
          inferredOperandShape, inferredResultType.getElementType());

          CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                         concatOp.getOperand(operandIdx));

      concatOp->setOperand(operandIdx, castOp->getResult(0));

  LogicalResult matchAndRewrite(ConcatOp concatOp,
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

        concatOp.getResultType())) {

        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());

      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(

  setNameFn(getResult(), "dim");

  build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)

  setResultRange(getResult(),

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())

  if (!tensorType.isDynamicDim(index.getInt())) {
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);

  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());

    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))

    return Value{*dynExtents};

  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};

  LogicalResult matchAndRewrite(DimOp dimOp,
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();

    Value newSource = castOp.getOperand();

  LogicalResult matchAndRewrite(DimOp dimOp,
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });

  LogicalResult matchAndRewrite(DimOp dim,
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
        arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);

  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
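// Illustrative example (assumed IR): DimOfReshapeOp reads the requested extent
// directly out of the shape operand, so
//
//   %r = tensor.reshape %t(%shape)
//        : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
//   %d = tensor.dim %r, %idx : tensor<?x?xf32>
//
// becomes
//
//   %d = tensor.extract %shape[%idx] : tensor<2xindex>
//
// with an arith.index_cast inserted when the shape element type is not index.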
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);

  build(builder, result, tensorType, dynamicSizes);

  build(builder, result, staticShape, elementType, dynamicSizes, encoding);

    return emitOpError("incorrect number of dynamic sizes, has ")

           << getType().getNumDynamicDims();

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

  assert(getType().isDynamicDim(idx) && "expected dynamic dim");

  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {

  LogicalResult matchAndRewrite(EmptyOp op,
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    if (foldedTensorType == op.getType())

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)

    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))

        emptyTensorOp.getDynamicSize(*maybeConstantIndex));

  LogicalResult matchAndRewrite(CastOp castOp,

    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();

        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());

    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);

      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");

        newMixedSizes.push_back(attr);

      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));

      newMixedSizes.push_back(currDim);

        resultType.getElementType());

  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
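// Illustrative example (assumed IR): FoldEmptyTensorWithCastOp folds the cast
// into the producing empty op when the static information agrees:
//
//   %0 = tensor.empty(%c4) : tensor<?xf32>
//   %1 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
//
// becomes
//
//   %1 = tensor.empty() : tensor<4xf32>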
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();

    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))

        extract, tensorCast.getSource(), extract.getIndices());

struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();

    if (!collapseOp.getSrcType().hasStaticShape())

    auto sourceSizes = collapseOp.getSrcType().getShape();
        extractOp.getIndices().end());

    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);

          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, true);
      llvm::append_range(sourceIndices, delinearize.getResults());

    if (collapseOp.getReassociationIndices().empty()) {
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
          rewriter, extractOp.getLoc(), zeroAffineMap,
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(

        extractOp, collapseOp.getSrc(), sourceIndices);
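// Illustrative example (assumed IR): extracting from a collapse_shape of a
// statically shaped source delinearizes the collapsed index over the group's
// sizes, so
//
//   %c = tensor.collapse_shape %src [[0, 1]]
//        : tensor<3x4xf32> into tensor<12xf32>
//   %e = tensor.extract %c[%i] : tensor<12xf32>
//
// becomes
//
//   %i0:2 = affine.delinearize_index %i into (3, 4) : index, index
//   %e = tensor.extract %src[%i0#0, %i0#1] : tensor<3x4xf32>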
void ExtractOp::getAsmResultNames(
  setNameFn(getResult(), "extracted");

  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");

  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  if (Attribute tensor = adaptor.getTensor()) {

    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    if (isa<DenseResourceElementsAttr>(tensor))

  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());

  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&

    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);

    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||

    return fromElementsOp.getElements()[flatIndex];
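// Worked example (illustrative): for a tensor<2x3xf32> produced by
// from_elements and indices [1, 2], the loop above walks dims from the
// innermost outwards: flatIndex = 2 * 1 = 2 with stride becoming 3 after
// dim 1, then flatIndex = 2 + 1 * 3 = 5, i.e. the sixth element in row-major
// order.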
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];

  results.add<ExtractFromTensorCast>(context);

void FromElementsOp::getAsmResultNames(
  setNameFn(getResult(), "from_elements");

  assert(!elements.empty() && "expected at least one element");
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))

struct ExtractElementFromIndexCast

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

  results.add<ExtractElementFromIndexCast>(context);

void GatherOp::getAsmResultNames(
  setNameFn(getResult(), "gather");

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      resultShape.push_back(1);

    resultShape.push_back(sourceType.getDimSize(idx));
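// Illustrative example (shapes adapted from the tensor.gather documentation):
// with source tensor<4x4x4xf32>, gather_dims([1]), and indices
// tensor<1x2x1xindex>, the leading indices dims 1x2 are kept and the gathered
// source dim becomes 1, inferring tensor<1x2x4x1x4xf32>; with rankReduced set,
// the unit dim is elided, giving tensor<1x2x4x4xf32>.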
static LogicalResult
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
           << "_dims value must be non-negative";
           << "_dims value must be smaller than " << sourceOrDest << " rank";
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
             << "_dims values must be strictly increasing";

  int64_t sourceRank = getSourceType().getRank();
          getIndicesType().getShape(), sourceRank,
          "gather", "source")))

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type "
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()

  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
    return reshapedSource;

void InsertOp::getAsmResultNames(
  setNameFn(getResult(), "inserted");

  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())

void GenerateOp::getAsmResultNames(
  setNameFn(getResult(), "generated");

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);

      reifiedReturnShapes[0][dim] =

  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());

  if (!llvm::all_of(getBody().getArgumentTypes(),
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
        "body must be terminated with a `yield` operation of the tensor "

void GenerateOp::build(
  build(b, result, resultTy, dynamicExtents);

  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();

  b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
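// Illustrative example (assumed IR): tensor.generate takes one index operand
// per dynamic extent and an index-typed body that yields each element:
//
//   %t = tensor.generate %m, %n {
//   ^bb0(%i : index, %j : index):
//     %sum = arith.addi %i, %j : index
//     %elem = arith.index_cast %sum : index to i64
//     tensor.yield %elem : i64
//   } : tensor<?x?xi64>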
  LogicalResult matchAndRewrite(GenerateOp generateOp,
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    if (foldedTensorType == generateOp.getType())

    auto loc = generateOp.getLoc();
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
        newOp.getBody().begin());
        generateOp.getType(), newOp);

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();

    Block *body = &tensorFromElements.getBody().front();

      rewriter.clone(op, mapping);

  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);

void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");

  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())

  return IntegerAttr();

void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");

  int64_t numElements = 1;
  for (auto dim : type.getShape())

    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
          "length of shape operand differs from the result's tensor rank");

  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
    return reshapedSource;

  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)

  if (sourceTy.getRank() <= 1)

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;
            cst.has_value() && cst.value() == static_cast<int64_t>(id);

      dynamicNoop = false;

void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapsed");

void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expanded");

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
    if (llvm::is_contained(it.value(), resultDim))
  llvm_unreachable("could not find reassociation group");

FailureOr<SmallVector<OpFoldResult>>
                                  RankedTensorType expandedType,
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;

  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, cast<RankedTensorType>(resultType), src,
        dynamicOutputShape, staticOutputShape);

  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);

      getReassociationIndices());

      getReassociationIndices());

RankedTensorType CollapseShapeOp::inferCollapsedType(
  return inferCollapsedType(
          type.getContext(), reassociation)));

CollapseShapeOp::inferCollapsedType(RankedTensorType type,
  auto shape = type.getShape();
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);

  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),
  build(b, result, resultType, src, attrs);
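// Worked example (illustrative): collapsing tensor<2x3x4xf32> with
// reassociation [[0, 1], [2]] multiplies the sizes within each group, giving
// tensor<6x4xf32>; a dynamic size anywhere in a group makes the whole
// collapsed dimension dynamic.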
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
                                   RankedTensorType expandedType,
                                   RankedTensorType collapsedType) {

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;

  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()

template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
        reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())

        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())

        fromElements.getElements());

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());

    auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                         newResultType, castOp.getSource(),
                                         collapseShapeOp.getReassociation());
        collapseShapeOp, collapseShapeOp.getResultType(), newOp);

struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();

        expandOp.getReassociationIndices();

    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))

        Value val = *outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);

        newOutputShape[outDim] = cst.getSExtValue();
          dynamicOutputShape.push_back(val);

    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())

    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;

        newInputShape[inDim] *= ofr;

        newInputShape, expandOp.getSrcType().getElementType());
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
        newExpand.getResult());

      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);

                                       tensor::DimOp, RankedTensorType>,
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithSplat<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());

void ExtractSliceOp::getAsmResultNames(
  setNameFn(getResult(), "extracted_slice");

RankedTensorType ExtractSliceOp::inferResultType(
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
      sourceTensorType.getEncoding());

RankedTensorType ExtractSliceOp::inferResultType(
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
      sourceTensorType.getEncoding());

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,

  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
  return inferredType;

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
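// Worked example (illustrative): slicing tensor<8x16x4xf32> with sizes
// [1, 16, 4] infers tensor<1x16x4xf32>; requesting a desired result rank of 2
// projects away the leading unit dimension, producing the canonical
// rank-reduced type tensor<16x4xf32>.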
                           RankedTensorType resultType, Value source,
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

                           RankedTensorType resultType, Value source,
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

                                            RankedTensorType expectedType) {
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  llvm_unreachable("unexpected extract_slice op verification result");

  RankedTensorType sourceType = getSourceType();

  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());

      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
    return getOperation()->emitError(boundsResult.errorMessage);

  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
  auto maybeRankReductionMask =
  if (!maybeRankReductionMask)

  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
    reifiedReturnShapes[0].push_back(size.value());

class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());

    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());

template <typename IterTy, typename ElemTy>
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
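// Worked example (illustrative): for a 3x4 source, counts is [4, 1]; offsets
// [1, 0], sizes [2, 2], strides [1, 2] first advance row by row (begin =
// values + row * 4 for rows 1 and 2) and then, in the innermost call, copy
// the elements at columns 0 and 2 of each selected row.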
class ConstantOpExtractSliceFolder final

  ConstantOpExtractSliceFolder(MLIRContext *context,
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,

    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())

    int64_t count = sourceType.getNumElements();

    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))

    counts.reserve(shape.size());
    for (int64_t v : shape) {
      counts.push_back(count);

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);

  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);

  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,

                     ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),

              ExtractSliceOpCastFolder>(context);

static LogicalResult
                                             ShapedType shapedType) {
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))

  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
    return reshapedSource;
  if (getSourceType() == getType() &&
    return this->getSource();

  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);

void InsertSliceOp::getAsmResultNames(
  setNameFn(getResult(), "inserted_slice");

  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

  build(b, result, source, dest, offsets, sizes, strides, attrs);

  build(b, result, source, dest, offsetValues, sizeValues, strideValues);

                    RankedTensorType srcType, RankedTensorType dstType,
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
    *expectedType = expected;

  RankedTensorType expectedType;
      getStaticSizes(), getStaticStrides(), &expectedType);

      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
    return getOperation()->emitError(boundsResult.errorMessage);

  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))

  insertOp.getDestMutable().assign(prevInsertOp.getDest());

  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))

  return extractOp.getSource();

  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
    return this->getSource();

template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

        mixedOffsets, mixedSizes, mixedStrides);

    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {

      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);

        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);

template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
        return std::nullopt;
      return castOp.getSource();
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)

        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)

        staticSizes, srcType.getShape(), true);
    if (!rankReductionMask.has_value())

    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);

        staticSizes, insertSliceOp.getStaticStrides()) !=
        mixedSizes, insertSliceOp.getMixedStrides());

        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),

template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())

    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =

        newSrcShape[i] = *constInt;

        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))

    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());

  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
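// Illustrative example (assumed IR): InsertSliceOpSourceCastInserter uses the
// static slice sizes to tighten a dynamically shaped source, e.g.
//
//   %r = tensor.insert_slice %src into %dst[0] [4] [1]
//        : tensor<?xf32> into tensor<8xf32>
//
// becomes
//
//   %s = tensor.cast %src : tensor<?xf32> to tensor<4xf32>
//   %r = tensor.insert_slice %s into %dst[0] [4] [1]
//        : tensor<4xf32> into tensor<8xf32>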
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,

  setNameFn(getResult(), "padded");

  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  if (resultType.getRank() != expectedType.getRank()) {
           << resultType << " does not match the inferred type "
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
    if (expectedType.isDynamicDim(i))
           << resultType << " does not match the inferred type "

LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
    return emitError("expected the block to have ") << rank << " arguments";

    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";

  if (yieldOp.getValue().getType() !=
    return emitOpError("expected yield type to match shape element type");

RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic

      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
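// Worked example (illustrative): padding tensor<2x3xf32> with static low
// padding [1, 0] and high padding [0, 2] infers tensor<3x5xf32>; a dynamic
// source dim or a dynamic padding amount makes the corresponding result dim
// dynamic.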
3284 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3286 resultType = inferResultType(sourceType, staticLow, staticHigh);
3288 build(b, result, resultType, source, low, high,
3296 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3297 unsigned rank = sourceType.getRank();
3299 build(b, result, resultType, source, staticVector, staticVector, low, high,
3307 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3317 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3319 assert(llvm::isa<RankedTensorType>(resultType));
3321 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3330 build(b, result, resultType, source, low, high, nofold, attrs);
3334 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3341 b.
createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3342 tensor::YieldOp::create(b, result.
location, constantPadValue);
3345 llvm::SmallBitVector PadOp::getPaddedDims() {
3346 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3348 for (
const auto &en :
enumerate(paddingWidths))
3350 paddedDims.set(en.index());
3352 extractPaddedDims(getMixedLowPad());
3353 extractPaddedDims(getMixedHighPad());
3363 LogicalResult matchAndRewrite(PadOp padTensorOp,
3365 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3367 if (padTensorOp.getNofold())
3370 padTensorOp, padTensorOp.getResult().getType(),
3371 padTensorOp.getSource());
3380 LogicalResult matchAndRewrite(PadOp padTensorOp,
3382 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3386 auto newResultType = PadOp::inferResultType(
3387 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3388 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3389 padTensorOp.getResultType().getShape());
3391 if (newResultType == padTensorOp.getResultType()) {
3393 padTensorOp.getSourceMutable().assign(castOp.getSource());
3396 auto newOp = PadOp::create(
3397 rewriter, padTensorOp->getLoc(), newResultType,
3398 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3399 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3400 padTensorOp.getHigh(), padTensorOp.getNofold(),
3403 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3406 padTensorOp, padTensorOp.getResultType(), newOp);
3417 LogicalResult matchAndRewrite(PadOp padTensorOp,
3419 if (!padTensorOp.getResult().hasOneUse())
3422 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3426 tensorCastOp.getDest().getType()))
3429 auto replacementOp = PadOp::create(
3430 rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3431 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3432 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3433 padTensorOp.getHigh(), padTensorOp.getNofold(),
3435 replacementOp.getRegion().takeBody(padTensorOp.getRegion());
3437 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
3438 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3481 LogicalResult matchAndRewrite(PadOp padOp,
3483 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3486 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3487 if (!outerPadOp || outerPadOp.getNofold())
3489 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3494 int64_t rank = padOp.getSourceType().getRank();
3495 if (outerSliceOp.getSourceType().getRank() != rank) {
3497 "cannot fold rank-reducing chain");
3501 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3503 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3507 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3509 "cannot fold PadOps with low padding");
3514 Value innerValue = padOp.getConstantPaddingValue();
3515 Value outerValue = outerPadOp.getConstantPaddingValue();
3516 if (!innerValue || !outerValue ||
3519 innerAttr != outerAttr) {
3521 padOp,
"cannot fold PadOps with different padding values");
3525 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3526 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3527 if (innerDims.anyCommon(outerDims)) {
3529 padOp,
"cannot fold PadOps with common padding dimensions");
3539 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3540 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3541 if (!innerDims.test(en.index()) &&
3543 en.value() = outerOffset;
3546 if (!outerDims.test(en.index()) &&
3548 en.value() = innerOffset;
3552 padOp,
"cannot find zero-offset and zero-padding pair");
3562 if (!outerDims.test(en.index()))
3564 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3565 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3566 assert(ShapedType::isStatic(sourceSize) &&
3567 "expected padded dimension to have a static size");
3570 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3571 "match the size of the outer padding");
3573 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3579 if (innerDims.test(en.index()))
3580 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3581 if (outerDims.test(en.index()))
3582 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3587 auto newSliceOp = ExtractSliceOp::create(
3588 rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
3589 newSizes, innerSliceOp.getMixedStrides());
3590 auto newPadOp = PadOp::create(
3591 rewriter, padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3592 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3595 newPadOp.getRegion().begin());
3596 rewriter.
replaceOp(padOp, newPadOp.getResult());
/// Pattern to fold a tensor.pad into a static form when its padding amounts
/// are constants, casting the result back to the original type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();
    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();
    auto outputDims = oldResultType.getShape();
    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();
    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }
    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);
    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(staticLow[i] == ShapedType::kDynamic ||
                                     staticHigh[i] == ShapedType::kDynamic ||
                                     inputDims[i] == ShapedType::kDynamic
                                 ? ShapedType::kDynamic
                                 : inputDims[i] + staticLow[i] + staticHigh[i]);
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }
    // Nothing to fold if no dimension became static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();
    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
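
// Sketch of the rewrite above on hypothetical IR (names and shapes are
// illustrative only): padding a tensor<4x8xf32> by a constant 2 on both
// sides of dim 1,
//
//   %c2 = arith.constant 2 : index
//   %p = tensor.pad %t low[0, %c2] high[0, %c2] { ... }
//       : tensor<4x8xf32> to tensor<?x?xf32>
//
// becomes a tensor.pad with the inferred static type tensor<4x12xf32>,
// followed by a tensor.cast back to the original tensor<?x?xf32> type.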
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold())
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad op");
    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }
    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }
    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };
    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
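
// Sketch with hypothetical shapes: two foldable pads with the same constant
// padding value %val compose by adding their low/high paddings:
//
//   %1 = tensor.pad %0 low[0, 1] high[0, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<1x5xf32>
//   %res = tensor.pad %1 low[0, 2] high[3, 0] { tensor.yield %val }
//       : tensor<1x5xf32> to tensor<4x7xf32>
//
// folds to:
//
//   %res = tensor.pad %0 low[0, 3] high[3, 2] { tensor.yield %val }
//       : tensor<1x2xf32> to tensor<4x7xf32>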
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));
    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
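
// E.g., for %pad = tensor.pad %t low[%l, 0] high[%h, 0] on tensor<?x8xf32>
// (a hypothetical example), the reified size of dim 0 is the folded
// affine.apply of (d0 + d1 + d2) over (%dim0, %l, %h), where
// %dim0 = tensor.dim %t, 0; dim 1 reifies to the constant 8.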
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // The padding value is constant only if defined outside of the pad region.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
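
// E.g. a pad whose result type matches its source type (all-zero padding)
// and that is not marked nofold folds away (illustrative IR):
//
//   %p = tensor.pad %t low[0, 0] high[0, 0] { ... }
//       : tensor<4x8xf32> to tensor<4x8xf32>
//
// folds to %t directly.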
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
  // Delegating calls of the build overloads (surrounding signatures follow
  // the usual mixed static/dynamic offset-size-stride decomposition).
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  build(b, result, source, dest, offsets, sizes, strides, attrs);
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that the offsets/sizes/strides do not run out of bounds of the
  // destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}

Operation *ParallelInsertSliceOp::getIteratingParent() {
  // The iterating parent is the op that encloses the combining op, e.g. the
  // scf.forall loop.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
  build(builder, result, aggregateType, element, dynamicSizes);
  build(builder, result, aggregateType, element, dynamicSizes);
  build(builder, result, element, staticShape, dynamicSizes);
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};
  // Only fold splats with a statically shaped result type.
  if (!getType().hasStaticShape())
    return {};
  return SplatElementsAttr::get(getType(), {constOperand});
}
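
// E.g. splatting a constant into a statically shaped tensor (illustrative):
//
//   %cst = arith.constant 1.0 : f32
//   %t = tensor.splat %cst : tensor<4x4xf32>
//
// folds to the attribute dense<1.000000e+00> : tensor<4x4xf32>.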
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp and loop-like ops have their own tensor.cast folding logic.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;
  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the cast makes the operand type more static.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject ops handled by their own patterns, e.g. linalg relayout ops.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes;
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
    // Clone the op with the more static operand and result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Cast results back to the original types where they changed.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
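
// Sketch (hypothetical destination-style op "some.dps_op", illustrative
// shapes): a cast that hides static shape information on the init operand is
// absorbed, and the original result type is restored with a result cast:
//
//   %d = tensor.cast %dest : tensor<4x8xf32> to tensor<?x8xf32>
//   %r = some.dps_op ins(%in : tensor<4x8xf32>)
//        outs(%d : tensor<?x8xf32>) -> tensor<?x8xf32>
//
// becomes
//
//   %r0 = some.dps_op ins(%in : tensor<4x8xf32>)
//         outs(%dest : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %r  = tensor.cast %r0 : tensor<4x8xf32> to tensor<?x8xf32>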
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"