#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
  return op;
if (complex::ConstantOp::isBuildableWith(value, type))
  return builder.create<complex::ConstantOp>(loc, type,
                                             llvm::cast<ArrayAttr>(value));
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
if (tensorType.isDynamicDim(dim))
  return builder.createOrFold<tensor::DimOp>(loc, value, dim);
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> result;
for (int64_t i = 0; i < tensorType.getRank(); ++i)
  result.push_back(getMixedSize(builder, loc, value, i));
auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");

// If the op has a destination, query the tied init operand directly.
auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
if (destOp)
  return destOp.getTiedOpOperand(opResult)->get();
SmallVector<OpFoldResult> mixedSizes;
if (!tensorType.hasStaticShape()) {
  // Dynamic shape: reify the result shape.
  // ...
} else {
  // Static shape: take the sizes directly from the type.
  for (int64_t sz : tensorType.getShape())
    mixedSizes.push_back(b.getIndexAttr(sz));
}
Value emptyTensor =
    b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
return emptyTensor;
if (llvm::isa<TensorType>(opResult.getType())) {
  FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
  if (failed(destination))
    return failure();
  result.push_back(*destination);
}
if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
    return rtp1.getShape() == rtp2.getShape() &&
           rtp1.getElementType() == rtp2.getElementType();
  return false;
}
llvm::SmallBitVector droppedDims(mixedSizes.size());
int64_t shapePos = reducedShape.size() - 1;

for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
  size_t idx = mixedSizes.size() - size.index() - 1;
  // Rank-reduced dims must have a static unit dimension.
  bool isStaticUnitSize =
      isa<Attribute>(size.value()) &&
      llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

  if (shapePos < 0) {
    // There are no more dims in the reduced shape. All remaining sizes must
    // be rank-reduced dims.
    assert(isStaticUnitSize && "expected unit dim");
    droppedDims.set(idx);
    continue;
  }

  // Dim is preserved if the size is not a static 1.
  if (!isStaticUnitSize) {
    --shapePos;
    continue;
  }

  // Dim is preserved if the reduced shape dim is also 1.
  if (reducedShape[shapePos] == 1) {
    --shapePos;
    continue;
  }

  // Otherwise: the dim is dropped.
  droppedDims.set(idx);
}

assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative to be folded to a static size.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }
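// Example (illustrative): for tensor<?x?xf32> with dynamic sizes (%c4, %sz),
// where %c4 is a constant 4, the folded type becomes tensor<4x?xf32> and
// only %sz remains in foldedDynamicSizes.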
if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = dyn_cast<TensorType>(a);
auto bT = dyn_cast<TensorType>(b);
if (!aT || !bT)
  return false;

if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
  return false;
LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                              PatternRewriter &rewriter) const final {
  auto tensorBitcastOperand =
      tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
  if (!tensorBitcastOperand)
    return failure();

  auto resultType = cast<TensorType>(tensorBitcast.getType());
  rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                         tensorBitcastOperand.getOperand());
  return success();
}
results.add<ChainedTensorBitcast>(context);
setNameFn(getResult(), "cast");
auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
auto targetType = llvm::dyn_cast<RankedTensorType>(target);

// Requires RankedTensorType.
if (!sourceType || !targetType)
  return false;

// Requires the same element type.
if (sourceType.getElementType() != targetType.getElementType())
  return false;

// Requires the same rank.
if (sourceType.getRank() != targetType.getRank())
  return false;

// Requires the same encoding.
if (sourceType.getEncoding() != targetType.getEncoding())
  return false;

// If the cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
  if (!ShapedType::isDynamic(std::get<0>(t)) &&
      ShapedType::isDynamic(std::get<1>(t)))
    return false;
}
return preservesStaticInformation(castOp.getType(),
                                  castOp.getSource().getType());
if (llvm::isa<BlockArgument>(opOperand.get()))
  return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
newOperands.reserve(op->getNumOperands());

int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
  auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
  bool fold = canFoldIntoConsumerOp(tensorCastOp);
  newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
  if (op.isDpsInit(&opOperand) &&
      !llvm::isa<MemRefType>(newOperands.back().getType()))
    newResTy[dpsInitIdx++] = newOperands.back().getType();
}
auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
  operand.set(castOp.getOperand());
  folded = true;
}

return success(folded);
if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<TensorType>(a);
auto bT = llvm::dyn_cast<TensorType>(b);
if (!aT || !bT)
  return false;

if (aT.getElementType() != bT.getElementType())
  return false;
int64_t rank = one.getRank();
if (rank != two.getRank())
  return {};

SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
  if (one.isDynamicDim(i)) {
    join.push_back(two.getDimSize(i));
    continue;
  }
  if (two.isDynamicDim(i)) {
    join.push_back(one.getDimSize(i));
    continue;
  }
  if (one.getDimSize(i) != two.getDimSize(i))
    return {};
  join.push_back(one.getDimSize(i));
}
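// Example (illustrative): joining tensor<?x4xf32> with tensor<2x?xf32>
// yields tensor<2x4xf32>; joining tensor<3x4xf32> with tensor<2x4xf32>
// fails because the static sizes 3 and 2 conflict.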
LogicalResult matchAndRewrite(CastOp tensorCast,
                              PatternRewriter &rewriter) const final {
  auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

  if (!tensorCastOperand)
    return failure();

  auto sourceType =
      llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
  auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
  auto resultType = llvm::cast<TensorType>(tensorCast.getType());

  // We can fold if joining all three shapes gives the same result as joining
  // just the source and result shapes.
  auto firstJoin =
      joinShapes(joinShapes(sourceType, intermediateType), resultType);
  if (!firstJoin)
    return failure();
  auto newJoin = joinShapes(sourceType, resultType);
  if (firstJoin != newJoin)
    return failure();

  rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                      tensorCastOperand.getOperand());
  return success();
}
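// Example rewrite (illustrative):
//
//   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
//   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
//
// folds into a single cast:
//
//   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>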
LogicalResult matchAndRewrite(CastOp tensorCast,
                              PatternRewriter &rewriter) const final {
  auto extractOperand =
      tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

  // Cannot fold a cast to an unranked tensor.
  auto rankedResultType =
      llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
  if (!rankedResultType)
    return failure();

  if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
      rankedResultType.getShape() ==
          llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
              .getShape())
    return failure();

  SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
  auto dimMask = computeRankReductionMask(
      extractOperand.getStaticSizes(), extractOperand.getType().getShape());
  size_t dimIndex = 0;
  for (size_t i = 0, e = sizes.size(); i < e; i++) {
    if (dimMask && dimMask->count(i))
      continue;
    int64_t dim = rankedResultType.getShape()[dimIndex++];
    if (ShapedType::isDynamic(dim))
      continue;
    sizes[i] = rewriter.getIndexAttr(dim);
  }

  rewriter.replaceOpWithNewOp<ExtractSliceOp>(
      tensorCast, rankedResultType, extractOperand.getSource(),
      extractOperand.getMixedOffsets(), sizes,
      extractOperand.getMixedStrides());
  return success();
}
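// Example rewrite (illustrative): a cast adding static size information to
// an extract_slice is folded into the slice itself:
//
//   %1 = tensor.extract_slice %0[0, 0][%s, 4][1, 1]
//          : tensor<?x8xf32> to tensor<?x4xf32>
//   %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<2x4xf32>
//
// becomes
//
//   %2 = tensor.extract_slice %0[0, 0][2, 4][1, 1]
//          : tensor<?x8xf32> to tensor<2x4xf32>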
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  SaturatedInteger concatSize;
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
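// Example (illustrative): concatenating tensor<3x?xf32> and tensor<5x4xf32>
// along dim 0 infers tensor<8x4xf32>: sizes along the concatenation dim add
// up (saturating to ? if any of them is dynamic), while every other dim
// takes the most static size seen among the inputs.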
FailureOr<RankedTensorType> resultType =
    inferResultType(dim, inputs.getTypes());
assert(succeeded(resultType) && "failed to infer concatenation result type");
build(builder, result, *resultType, dim, inputs);
if (getInputs().size() < 1)
  return emitOpError("requires at least one input");
SmallVector<RankedTensorType, 4> inputTypes;
for (auto input : getInputs())
  inputTypes.push_back(cast<RankedTensorType>(input.getType()));
RankedTensorType resultType = getResultType();
int64_t resultRank = getRank();
if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
      return type.getRank() != resultRank;
    }))
  return emitOpError("rank of concatenated inputs must match result rank");

Type resultElementType = resultType.getElementType();
if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
      return type.getElementType() != resultElementType;
    }))
  return emitOpError("inputs and result element type must match");

int64_t dim = getDim();
if (dim >= resultRank)
  return emitOpError("concatenation dim must be less than the tensor rank");
SmallVector<int64_t> sizes(resultRank);
for (int64_t i = 0, e = resultRank; i < e; ++i) {
  if (i == dim)
    continue;
  SaturatedInteger size;
  for (auto tensorType : inputTypes) {
    FailureOr<SaturatedInteger> maybeSize =
        size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    if (failed(maybeSize))
      return emitOpError("static concatenation size mismatch along ")
             << "non-concatenated dimension " << i;
    size = *maybeSize;
  }
  sizes[i] = size.asInteger();
}
SaturatedInteger concatSize;
for (auto tensorType : inputTypes)
  concatSize =
      concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
sizes[dim] = concatSize.asInteger();
auto inferredResultType =
    RankedTensorType::get(sizes, resultType.getElementType());

for (auto [inferredSize, actualSize] :
     llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                    ShapedType::isDynamic(actualSize);
  if (!hasDynamic && inferredSize != actualSize)
    return emitOpError("result type ")
           << resultType << " does not match inferred shape "
           << inferredResultType << " static sizes";
}

return success();
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }
Value replacement = builder.create<tensor::EmptyOp>(
    loc, outputShape, getType().getElementType());

int64_t rank = getType().getRank();
OpFoldResult one = builder.getIndexAttr(1);
SmallVector<OpFoldResult> strides(rank, one);
SmallVector<OpFoldResult> offsets(rank, zero);
for (auto [index, input] : llvm::enumerate(getInputs())) {
  offsets[concatDim] = concatOffsets[index];
  auto insertSlice = builder.create<tensor::InsertSliceOp>(
      loc, input, replacement, offsets, inputShapes[index], strides);
  replacement = insertSlice.getResult();
}
if (replacement.getType() != getType()) {
  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
}
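// Example decomposition (illustrative):
//
//   %r = tensor.concat dim(0) %a, %b
//          : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// becomes
//
//   %e  = tensor.empty() : tensor<5x4xf32>
//   %i0 = tensor.insert_slice %a into %e[0, 0][2, 4][1, 1]
//   %r  = tensor.insert_slice %b into %i0[2, 0][3, 4][1, 1]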
int64_t dim = getDim();
RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

Value init = inputs[0];
int64_t rank = getType().getRank();

reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
for (int64_t i = 0; i < rank; ++i) {
  if (i == dim)
    continue;
  if (!getType().isDynamicDim(i)) {
    reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
  } else if (!inferredResultType.isDynamicDim(i)) {
    reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        builder.getIndexAttr(inferredResultType.getDimSize(i)));
  } else {
    reifiedReturnShapes[0][i] =
        builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
  }
}
if (getType().isDynamicDim(dim)) {
  // Sum up the sizes of all inputs along the concatenated dimension.
  SmallVector<OpFoldResult> sizes;
  for (auto input : getInputs())
    sizes.push_back(
        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
  // ...
} else {
  reifiedReturnShapes[0][dim] =
      builder.getIndexAttr(getType().getDimSize(dim));
}
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
if (inputs.size() == 1 && inputs[0].getType() == getResultType())
  return inputs[0];
LogicalResult matchAndRewrite(ConcatOp concatOp,
                              PatternRewriter &rewriter) const override {
  if (concatOp.getInputs().size() != 1)
    return failure();
  rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                      concatOp.getInputs()[0]);
  return success();
}
results.add<SingleInputConcatOp>(context);
setNameFn(getResult(), "dim");
Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
  return Speculation::NotSpeculatable;

setResultRange(getResult(),
               intrange::inferShapedDimOpInterface(*this, argRanges[1]));
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
  return {};

// Folding for unranked types is not supported.
auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
if (!tensorType)
  return {};

// Out-of-bounds indices produce undefined behavior but are still valid IR.
// Don't choke on them.
int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= tensorType.getRank())
  return {};

// Fold if the shape extent along the given index is known.
if (!tensorType.isDynamicDim(index.getInt())) {
  Builder builder(getContext());
  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
}
Operation *definingOp = getSource().getDefiningOp();

// Fold dim to the operand of tensor.generate.
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
  auto resultType =
      llvm::cast<RankedTensorType>(fromElements.getResult().getType());
  // The case where the type encodes the size of the dimension is handled
  // above.
  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

  // Find the operand of the fromElements that corresponds to this index.
  auto dynExtents = fromElements.getDynamicExtents().begin();
  for (auto dim : resultType.getShape().take_front(index.getInt()))
    if (ShapedType::isDynamic(dim))
      dynExtents++;

  return Value{*dynExtents};
}
// The size at the given index is now known to be a dynamic size.
unsigned unsignedIndex = index.getValue().getZExtValue();

if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
  // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
  // on the `resolve-shaped-type-result-dims` pass.
  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
      sliceOp.isDynamicSize(unsignedIndex)) {
    return {sliceOp.getDynamicSize(unsignedIndex)};
  }
}
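// Example folds (illustrative):
//   tensor.dim %t, %c1 : tensor<4x8xf32>      -->  constant 8
//   tensor.dim of a tensor.generate result    -->  the matching dynamic
//                                                  extent operand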
LogicalResult matchAndRewrite(DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
  if (!castOp)
    return failure();
  Value newSource = castOp.getOperand();
  rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
  return success();
}
LogicalResult matchAndRewrite(DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto source = dimOp.getSource();
  auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
  if (!destOp)
    return failure();

  auto resultIndex = cast<OpResult>(source).getResultNumber();
  auto *initOperand = destOp.getDpsInitOperand(resultIndex);

  rewriter.modifyOpInPlace(
      dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
  return success();
}
LogicalResult matchAndRewrite(DimOp dim,
                              PatternRewriter &rewriter) const override {
  auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
  if (!reshape)
    return failure();

  // Since tensors are immutable we don't need to worry about where to place
  // the extract call.
  rewriter.setInsertionPointAfter(dim);
  Location loc = dim.getLoc();
  Value extract =
      rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
  if (extract.getType() != dim.getType())
    extract =
        rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
  rewriter.replaceOp(dim, extract);
  return success();
}
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
assert(all_of(staticShape,
              [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
       "expected only static sizes");
build(builder, result, staticShape, elementType, ValueRange{}, encoding);

build(builder, result, tensorType, dynamicSizes);

build(builder, result, staticShape, elementType, dynamicSizes, encoding);
return emitOpError("incorrect number of dynamic sizes, has ")
       << getDynamicSizes().size() << ", expected "
       << getType().getNumDynamicDims();
for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {
    reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
  } else {
    reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
  }
}
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}
for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {
    result.push_back(getDynamicSizes()[ctr++]);
  } else {
    result.push_back(b.getIndexAttr(getType().getShape()[i]));
  }
}
LogicalResult matchAndRewrite(EmptyOp op,
                              PatternRewriter &rewriter) const override {
  SmallVector<Value> foldedDynamicSizes;
  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
      op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

  // Stop here if no dynamic size was promoted to static.
  if (foldedTensorType == op.getType())
    return failure();

  auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                        foldedDynamicSizes);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
  return success();
}
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
  auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
  if (!emptyTensorOp || !maybeConstantIndex)
    return failure();
  auto emptyTensorType = emptyTensorOp.getType();
  if (*maybeConstantIndex < 0 ||
      *maybeConstantIndex >= emptyTensorType.getRank() ||
      !emptyTensorType.isDynamicDim(*maybeConstantIndex))
    return failure();
  rewriter.replaceOp(dimOp,
                     emptyTensorOp.getDynamicSize(*maybeConstantIndex));
  return success();
}
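// Example (illustrative): `tensor.dim %0, %c0` where
// `%0 = tensor.empty(%sz) : tensor<?x12xf32>` folds directly to %sz.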
LogicalResult matchAndRewrite(CastOp castOp,
                              PatternRewriter &rewriter) const override {
  if (!canFoldIntoProducerOp(castOp))
    return failure();
  auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
  if (!producer)
    return failure();

  auto resultType =
      llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
  ArrayRef<int64_t> resultShape = resultType.getShape();
  SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
  SmallVector<OpFoldResult> newMixedSizes;
  newMixedSizes.reserve(currMixedSizes.size());
  assert(resultShape.size() == currMixedSizes.size() &&
         "mismatch in result shape and sizes of empty op");
  for (auto it : llvm::zip(resultShape, currMixedSizes)) {
    int64_t newDim = std::get<0>(it);
    OpFoldResult currDim = std::get<1>(it);
    // Case 1: The empty tensor dim is static. Check that the tensor cast
    // result dim matches.
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
      if (ShapedType::isDynamic(newDim) ||
          newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
        return rewriter.notifyMatchFailure(
            producer, "mismatch in static value of shape of empty tensor "
                      "result and cast result");
      }
      newMixedSizes.push_back(attr);
      continue;
    }

    // Case 2: The current dim of the empty tensor is dynamic and the cast
    // result dim is static: use the static size.
    if (!ShapedType::isDynamic(newDim)) {
      newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
      continue;
    }

    // Case 3: Both are dynamic. Use the dynamic size of the empty tensor.
    newMixedSizes.push_back(currDim);
  }

  rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                       resultType.getElementType());
  return success();
}
results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
            ReplaceEmptyTensorStaticShapeDims>(context);
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices for extract_element");
return success();
// If this is a splat elements attribute, simply return the value.
// All the elements of a splat attribute are the same.
if (Attribute tensor = adaptor.getTensor()) {
  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
    return splatTensor.getSplatValue<Attribute>();

  // Expensive to fold into a DenseResourceElementsAttr; skip it.
  if (isa<DenseResourceElementsAttr>(tensor))
    return {};
}

// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
  if (!indice || !llvm::isa<IntegerAttr>(indice))
    return {};
  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
}
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
  auto rank = tensorType.getRank();
  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
         "rank mismatch");
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= tensorType.getDimSize(i);
  }
  // Prevent out-of-bounds accesses. This can happen in invalid code that
  // will never execute.
  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
      flatIndex < 0)
    return {};
  return fromElementsOp.getElements()[flatIndex];
}
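// Example (illustrative): for a tensor<2x3xf32>, indices [1, 2] map to the
// row-major flat index 1 * 3 + 2 * 1 = 5 under the stride computation above.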
if (Attribute tensor = adaptor.getTensor()) {
  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValues<Attribute>()[indices];
}
results.add<ExtractFromTensorCast>(context);
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}
assert(!elements.empty() && "expected at least one element");
Type resultType = RankedTensorType::get(
    {static_cast<int64_t>(elements.size())}, elements.front().getType());
build(builder, result, resultType, elements);
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);
    return success();
  }
};
results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
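// Example (illustrative, following the tensor dialect docs): gathering along
// dim 1 of a tensor<4x4x4xf32> with indices tensor<1x2x1xindex> yields
// tensor<1x2x4x1x4xf32>, or tensor<1x2x4x4xf32> in rank-reduced form where
// the unit dim standing in for the gathered dimension is elided.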
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}
if (OpFoldResult reshapedSource = reshapeConstantSource(
        llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
        getResult().getType()))
  return reshapedSource;
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices");
return success();
if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
  if (scalar == splatDest.getSplatValue<Attribute>())
    return dest;
return {};
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
  if (getType().isDynamicDim(dim)) {
    reifiedReturnShapes[0][dim] = getOperand(idx++);
  } else {
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
}
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultType.getNumDynamicDims())
  return emitError("must have as many index operands as dynamic extents "
                   "in the result type");
return success();
LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");
  return success();
}
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Create the body block with one index argument per result dimension.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
LogicalResult matchAndRewrite(GenerateOp generateOp,
                              PatternRewriter &rewriter) const final {
  SmallVector<Value> foldedDynamicSizes;
  RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
      generateOp.getType(), generateOp.getDynamicExtents(),
      foldedDynamicSizes);

  // Stop here if no dynamic size was promoted to static.
  if (foldedTensorType == generateOp.getType())
    return failure();

  auto loc = generateOp.getLoc();
  auto newOp =
      rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
  rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                              newOp.getBody().begin());
  rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                              generateOp.getType(), newOp);
  return success();
}
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements)
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());
    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};
results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}
auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}
return emitOpError("element types of source and destination tensor "
                   "types should be the same");

int64_t shapeSize =
    llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

if (resultRankedType) {
  if (operandRankedType && resultRankedType.hasStaticShape() &&
      operandRankedType.hasStaticShape()) {
    if (getNumElements(operandRankedType) !=
        getNumElements(resultRankedType))
      return emitOpError("source and destination tensor should have the "
                         "same number of elements");
  }
  if (ShapedType::isDynamic(shapeSize))
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked tensor type");
  if (shapeSize != resultRankedType.getRank())
    return emitOpError(
        "length of shape operand differs from the result's tensor rank");
}
if (OpFoldResult reshapedSource = reshapeConstantSource(
        llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
        getResult().getType()))
  return reshapedSource;
if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
  getSourceMutable().assign(reshapeOpProducer.getSource());
  return getResult();
}

auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
if (!sourceTy || !resultTy || sourceTy != resultTy)
  return {};

// If the source and result are both 1D tensors and have the same type,
// the reshape has no effect, even if the tensor is dynamically shaped.
if (sourceTy.getRank() == 1)
  return source;

if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
  auto elements = fromElements.getElements();
  bool dynamicNoop =
      sourceTy.getRank() == static_cast<int64_t>(elements.size());
  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
    auto element = elements[id];

    if (auto cst = getConstantIntValue(element)) {
      dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
      continue;
    }

    if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
      dynamicNoop &= dimOp.getSource() == source;

      auto cst = getConstantIntValue(dimOp.getIndex());
      dynamicNoop &=
          cst.has_value() && cst.value() == static_cast<int64_t>(id);
      continue;
    }

    dynamicNoop = false;
    break;
  }

  if (dynamicNoop)
    return source;
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

auto [staticOutputShape, dynamicOutputShape] =
    decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, cast<RankedTensorType>(resultType), src,
      getReassociationIndicesAttribute(builder, reassociation),
      dynamicOutputShape, staticOutputShape);
auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, tensorResultTy, reassociation, inputShape);
SmallVector<OpFoldResult> outputShapeOrEmpty;
if (succeeded(outputShape)) {
  outputShapeOrEmpty = *outputShape;
}
build(builder, result, tensorResultTy, src, reassociation,
      outputShapeOrEmpty);
return convertReassociationIndicesToExprs(getContext(),
                                          getReassociationIndices());

return convertReassociationIndicesToExprs(getContext(),
                                          getReassociationIndices());
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Each reassociation group collapses into one output dimension.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }
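// Example (illustrative): collapsing tensor<2x3x?x5xf32> with reassociation
// [[0, 1], [2, 3]] yields tensor<6x?xf32>: 2 * 3 = 6, while the second group
// collapses to ? because it contains a dynamic size.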
auto resultType = inferCollapsedType(
    llvm::cast<RankedTensorType>(src.getType()),
    getSymbolLessAffineMaps(
        convertReassociationIndicesToExprs(b.getContext(), reassociation)));
// ...
build(b, result, resultType, src, attrs);
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  // ...
  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
auto srcType = getSrcType();
auto resultType = getResultType();

if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
  return emitOpError("expected number of static shape dims to be equal to "
                     "the output rank (")
         << resultType.getRank() << ") but found "
         << getStaticOutputShape().size() << " inputs instead";

if ((int64_t)getOutputShape().size() !=
    llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
  return emitOpError("mismatch in dynamic dims in output_shape and "
                     "static_output_shape: static_output_shape has ")
         << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
         << " dynamic dims while output_shape has " << getOutputShape().size()
         << " values";
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)) || !attr ||
        !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};
template <typename TensorReshapeOp>
struct FoldReshapeWithSplat : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                              PatternRewriter &rewriter) const override {
  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
  if (!tensor::canFoldIntoConsumerOp(castOp))
    return failure();

  RankedTensorType srcType =
      llvm::cast<RankedTensorType>(castOp.getSource().getType());
  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
      srcType, collapseShapeOp.getReassociationMaps());

  if (newResultType == collapseShapeOp.getResultType()) {
    rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());
    });
  } else {
    auto newOp = rewriter.create<CollapseShapeOp>(
        collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
        collapseShapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        collapseShapeOp, collapseShapeOp.getResultType(), newOp);
  }
  return success();
}
LogicalResult matchAndRewrite(DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  // Only constant dimension values are supported.
  std::optional<int64_t> dim = dimOp.getConstantIndex();
  if (!dim.has_value())
    return failure();

  // Skip static dims. These are folded to constant ops.
  RankedTensorType resultType = expandShapeOp.getResultType();
  if (!resultType.isDynamicDim(*dim))
    return failure();

  // `dim` must be the only dynamic dim in its reassociation group; otherwise
  // the ExpandShapeOp would be ambiguous.
  int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
  ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
  int64_t product = 1;
  for (int64_t d : grp) {
    if (d != dim) {
      assert(!resultType.isDynamicDim(d) && "expected static dim");
      product *= resultType.getDimSize(d);
    }
  }

  // result dim size = src dim size / product(static dims in the group)
  Value srcDimSz =
      rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
  AffineExpr expr;
  bindSymbols(dimOp.getContext(), expr);
  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
      dimOp, expr.floorDiv(product), srcDimSz);
  return success();
}
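// Example (illustrative): for
//   %e = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// the dynamic result dim 0 is rewritten as
//   affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%dim_of_t]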
LogicalResult matchAndRewrite(DimOp dimOp,
                              PatternRewriter &rewriter) const override {
  auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  // Only constant dimension values are supported.
  std::optional<int64_t> dim = dimOp.getConstantIndex();
  if (!dim.has_value() ||
      dim.value() >= collapseShapeOp.getResultType().getRank())
    return failure();

  // Skip static dims. These are folded to constant ops.
  RankedTensorType resultType = collapseShapeOp.getResultType();
  if (!resultType.isDynamicDim(*dim))
    return failure();

  // result dim size = product(dims in the reassociation group)
  ReassociationIndices group =
      collapseShapeOp.getReassociationIndices()[*dim];
  SmallVector<Value> srcDimSizes;
  SmallVector<AffineExpr> syms;
  AffineExpr product;
  for (const auto &it : llvm::enumerate(group)) {
    srcDimSizes.push_back(rewriter.create<DimOp>(
        dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
    syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
    product = product ? product * syms.back() : syms.back();
  }
  rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
                                                     srcDimSizes);
  return success();
}
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (!ShapedType::isDynamic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires
        // at least one of the expanded dimensions to be dynamic if the input
        // is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
                                             expandOp.getSrc());
    auto newExpand = rewriter.create<ExpandShapeOp>(
        expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
results.add<
    ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
    FoldReshapeWithSplat<ExpandShapeOp>,
    FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
    FoldDimOfCollapseShape>(context);
results.add<
    ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                              tensor::DimOp, RankedTensorType>,
    FoldReshapeWithConstant<CollapseShapeOp>,
    FoldReshapeWithSplat<CollapseShapeOp>,
    FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
    context);
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best-effort rank reduction: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // If no result type was provided, infer it from the source and the slice.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  // ...
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
static LogicalResult verifyInBoundsSlice(Operation *op,
                                         RankedTensorType tensorType,
                                         ArrayRef<int64_t> staticOffsets,
                                         ArrayRef<int64_t> staticSizes,
                                         ArrayRef<int64_t> staticStrides) {
  for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
    // Nothing to verify for dynamic source dims.
    if (tensorType.isDynamicDim(i))
      continue;
    // Nothing to verify if the offset is dynamic.
    if (ShapedType::isDynamic(staticOffsets[i]))
      continue;
    if (staticOffsets[i] >= tensorType.getDimSize(i))
      return op->emitOpError("offset ")
             << i << " is out-of-bounds: " << staticOffsets[i]
             << " >= " << tensorType.getDimSize(i);
    if (ShapedType::isDynamic(staticSizes[i]) ||
        ShapedType::isDynamic(staticStrides[i]))
      continue;
    int64_t lastPos =
        staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
    if (lastPos >= tensorType.getDimSize(i))
      return op->emitOpError("slice along dimension ")
             << i << " runs out-of-bounds: " << lastPos
             << " >= " << tensorType.getDimSize(i);
  }
  return success();
}
RankedTensorType sourceType = getSourceType();

// Verify result type against inferred type.
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
    sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
  return produceSliceErrorMsg(result, *this, expectedType);

// Verify that offsets, sizes, strides are in-bounds of the source.
return verifyInBoundsSlice(getOperation(), sourceType, getStaticOffsets(),
                           getStaticSizes(), getStaticStrides());
auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
  return value;
auto maybeRankReductionMask =
    mlir::computeRankReductionMask(sourceShape, desiredShape);
if (!maybeRankReductionMask)
  return failure();
reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
llvm::SmallBitVector droppedDims = getDroppedDims();
for (const auto &size : enumerate(mixedSizes)) {
  if (droppedDims.test(size.index()))
    continue;
  reifiedReturnShapes[0].push_back(size.value());
}
/// Fold tensor.cast ops into tensor.extract_slice ops.
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    // Create the folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = rewriter.create<ExtractSliceOp>(
        loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();
    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed from the sliced values.
    DenseElementsAttr newAttr;
    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }

    if (newAttr) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
      return success();
    }
    return failure();
  }

private:
  ControlConstantExtractSliceFusionFn controlFn;
};
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
    op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
    mixedStrides);
void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                ExtractSliceOp newOp) {
  Value replacement = newOp.getResult();
  if (replacement.getType() != op.getType())
    replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                  replacement);
  rewriter.replaceOp(op, replacement);
}

results.add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                ExtractSliceOp, SliceReturnTypeCanonicalizer,
                SliceCanonicalizer>,
            ExtractSliceOpCastFolder>(context);
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  // A no-op slice has all-zero offsets, sizes matching the shape, and
  // all-one strides.
  OpBuilder b(op.getContext());
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
    insertOp.isSameAs(extractOp, isSame))
  return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                              offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
      dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
      b.getDenseI64ArrayAttr(staticSizes),
      b.getDenseI64ArrayAttr(staticStrides));

build(b, result, source, dest, offsets, sizes, strides, attrs);

build(b, result, source, dest, offsetValues, sizeValues, strideValues);
static SliceVerificationResult
verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType,
                    ArrayRef<int64_t> staticOffsets,
                    ArrayRef<int64_t> staticSizes,
                    ArrayRef<int64_t> staticStrides,
                    RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice; use the same type
  // inference.
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}
RankedTensorType expectedType;
SliceVerificationResult result =
    verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                        getStaticSizes(), getStaticStrides(), &expectedType);
if (result != SliceVerificationResult::Success)
  return produceSliceErrorMsg(result, *this, expectedType);

// Verify that offsets, sizes, strides are in-bounds of the destination.
return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
                           getStaticSizes(), getStaticStrides());
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!prevInsertOp ||
    prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
    !prevInsertOp.isSameAs(insertOp, isSame))
  return failure();

insertOp.getDestMutable().assign(prevInsertOp.getDest());
auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
    !extractOp.isSameAs(insertOp, isSame))
  return nullptr;

return extractOp.getSource();
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
    getSourceType() == getType() &&
    succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
  return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // If no constant operand was folded, just return.
    if (failed(foldDynamicIndexList(mixedOffsets, /*onlyNonNegative=*/true)) &&
        failed(foldDynamicIndexList(mixedSizes, /*onlyNonNegative=*/true)) &&
        failed(foldDynamicIndexList(mixedStrides)))
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // For ParallelInsertSliceOp, the insertion point must be just before
      // the enclosing ParallelCombiningOp.
      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
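// Example rewrite (illustrative): with %c0 a constant 0 and %c4 a constant 4,
//   tensor.insert_slice %src into %dst[%c0][%c4][1] ...
// becomes the canonical static form
//   tensor.insert_slice %src into %dst[0][4][1] ...
// with a tensor.cast inserted on %src if its type can be refined.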
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The cast source could have additional static information not present
    // in the insert_slice static sizes, so ignore dynamic dims when computing
    // the rank-reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dims in staticSizes with the static sizes from srcType where
    // possible.
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        mixedSizes, insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                    insertSliceOp.getDestType(),
                                                    replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is different from, more static than, and cast-compatible
    // with srcType: insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // For ParallelInsertSliceOp, the insertion point must be just before the
    // enclosing ParallelCombiningOp.
    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
            InsertSliceOpCastFolder<InsertSliceOp>,
            InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                             sizes, strides);
setNameFn(getResult(), "padded");
static void printInferType(OpAsmPrinter &printer, Operation *op,
                           Value optOperand, Type typeToInfer,
                           Type typeToInferFrom) {}

static ParseResult
parseInferType(OpAsmParser &parser,
               std::optional<OpAsmParser::UnresolvedOperand> optOperand,
               Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
auto expectedType =
    PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
  return emitError("failed to infer expectedType from sourceType ")
         << sourceType << ", specified resultType is " << resultType;
}
if (resultType.getRank() != expectedType.getRank()) {
  return emitError("specified type ")
         << resultType << " does not match the inferred type "
         << expectedType;
}
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
    continue;
  if (expectedType.isDynamicDim(i))
    continue;
  return emitError("specified type ")
         << resultType << " does not match the inferred type "
         << expectedType;
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }
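// Example (illustrative): padding tensor<2x?xf32> with staticLow = [1, 2]
// and staticHigh = [3, 4] infers tensor<6x?xf32>: 2 + 1 + 3 = 6 for the
// static dim, and ? wherever the source dim or a padding amount is dynamic.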
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
if (!resultType)
  resultType = inferResultType(sourceType, staticLow, staticHigh);
result.addAttributes(attrs);
build(b, result, resultType, source, low, high,
      b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
      b.getBoolAttr(nofold));

// (next PadOp::build overload)
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
build(b, result, resultType, source, staticVector, staticVector, low, high,
      nofold, attrs);

// (next PadOp::build overload)
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
SmallVector<Value, 4> dynamicLow, dynamicHigh;
SmallVector<int64_t, 4> staticLow, staticHigh;
dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
if (!resultType) {
  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
}
assert(llvm::isa<RankedTensorType>(resultType));
build(b, result, resultType, source, dynamicLow, dynamicHigh,
      b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
      b.getBoolAttr(nofold));

// (next PadOp::build overload)
build(b, result, resultType, source, low, high, nofold, attrs);
int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
// Create the region's block with `sourceRank` index arguments.
SmallVector<Type, 4> blockArgTypes(sourceRank, b.getIndexType());
SmallVector<Location, 4> blockArgLocs(sourceRank, result.location);
// ...
b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
LogicalResult matchAndRewrite(PadOp padTensorOp,
                              PatternRewriter &rewriter) const override {
  if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
    return failure();
  if (padTensorOp.getNofold())
    return failure();
  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      padTensorOp, padTensorOp.getResult().getType(),
      padTensorOp.getSource());
  return success();
}
LogicalResult matchAndRewrite(PadOp padTensorOp,
                              PatternRewriter &rewriter) const override {
  auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
  if (!tensor::canFoldIntoConsumerOp(castOp))
    return failure();

  auto newResultType = PadOp::inferResultType(
      llvm::cast<RankedTensorType>(castOp.getSource().getType()),
      padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
      padTensorOp.getResultType().getShape());

  if (newResultType == padTensorOp.getResultType()) {
    rewriter.modifyOpInPlace(padTensorOp, [&]() {
      padTensorOp.getSourceMutable().assign(castOp.getSource());
    });
  } else {
    auto newOp = rewriter.create<PadOp>(
        padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResultType(), newOp);
  }
  return success();
}
LogicalResult matchAndRewrite(PadOp padTensorOp,
                              PatternRewriter &rewriter) const override {
  if (!padTensorOp.getResult().hasOneUse())
    return failure();
  auto tensorCastOp =
      dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
  if (!tensorCastOp)
    return failure();
  if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                          tensorCastOp.getDest().getType()))
    return failure();

  auto replacementOp = rewriter.create<PadOp>(
      padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
      padTensorOp.getSource(), padTensorOp.getStaticLow(),
      padTensorOp.getStaticHigh(), padTensorOp.getLow(),
      padTensorOp.getHigh(), padTensorOp.getNofold(),
      getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
  replacementOp.getRegion().takeBody(padTensorOp.getRegion());

  rewriter.replaceOp(padTensorOp, replacementOp.getResult());
  rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
  return success();
}
LogicalResult matchAndRewrite(PadOp padOp,
                              PatternRewriter &rewriter) const override {
  auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
  if (!innerSliceOp)
    return failure();
  auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
  if (!outerPadOp || outerPadOp.getNofold())
    return failure();
  auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
  if (!outerSliceOp)
    return failure();

  // 1) Fail if the chain is rank-reducing.
  int64_t rank = padOp.getSourceType().getRank();
  if (outerSliceOp.getSourceType().getRank() != rank) {
    return rewriter.notifyMatchFailure(padOp,
                                       "cannot fold rank-reducing chain");
  }

  // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
  if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        padOp, "cannot fold non-unit stride ExtractSliceOps");
  }

  // 3) Fail if the tensor::PadOps have non-zero low padding.
  if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
    return rewriter.notifyMatchFailure(padOp,
                                       "cannot fold PadOps with low padding");
  }

  // 4) Fail if the tensor::PadOps padding values do not match.
  Attribute innerAttr, outerAttr;
  Value innerValue = padOp.getConstantPaddingValue();
  Value outerValue = outerPadOp.getConstantPaddingValue();
  if (!innerValue || !outerValue ||
      !matchPattern(innerValue, m_Constant(&innerAttr)) ||
      !matchPattern(outerValue, m_Constant(&outerAttr)) ||
      innerAttr != outerAttr) {
    return rewriter.notifyMatchFailure(
        padOp, "cannot fold PadOps with different padding values");
  }

  // 5) Fail if a dimension is padded by both tensor::PadOps.
  llvm::SmallBitVector innerDims = padOp.getPaddedDims();
  llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
  if (innerDims.anyCommon(outerDims)) {
    return rewriter.notifyMatchFailure(
        padOp, "cannot fold PadOps with common padding dimensions");
  }

  // 6) Combine the offsets of the two tensor::ExtractSliceOps: for every
  // dimension, find the zero-offset/zero-padding side and take the offset of
  // the other side.
  SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
  for (auto &en : enumerate(newOffsets)) {
    OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
    OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
    if (!innerDims.test(en.index()) &&
        (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
      en.value() = outerOffset;
      continue;
    }
    if (!outerDims.test(en.index()) &&
        (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
      en.value() = innerOffset;
      continue;
    }
    return rewriter.notifyMatchFailure(
        padOp, "cannot find zero-offset and zero-padding pair");
  }

  // 7) Combine the sizes of the two tensor::ExtractSliceOps: take the size
  // of the outer slice for the dimensions it pads.
  SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
  for (auto &en : enumerate(newSizes)) {
    if (!outerDims.test(en.index()))
      continue;
    OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
    int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
    assert(!ShapedType::isDynamic(sourceSize) &&
           "expected padded dimension to have a static size");
    if (getConstantIntValue(sliceSize) != sourceSize) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold since the inner ExtractSliceOp size does not "
                 "match the size of the outer padding");
    }
    en.value() = outerSliceOp.getMixedSizes()[en.index()];
  }

  // Combine the high paddings of the two tensor::PadOps.
  SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
  for (auto &en : enumerate(newHighPad)) {
    if (innerDims.test(en.index()))
      newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
    if (outerDims.test(en.index()))
      newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
  }

  // Create a new tensor::ExtractSliceOp / tensor::PadOp pair.
  auto newSliceOp = rewriter.create<ExtractSliceOp>(
      padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
      innerSliceOp.getMixedStrides());
  auto newPadOp = rewriter.create<PadOp>(
      padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
      padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
      getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
  rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                              newPadOp.getRegion().begin());
  rewriter.replaceOp(padOp, newPadOp.getResult());
  return success();
}
LogicalResult matchAndRewrite(PadOp padTensorOp,
                              PatternRewriter &rewriter) const override {
  Value input = padTensorOp.getSource();
  if (!llvm::isa<RankedTensorType>(input.getType()))
    return failure();
  auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
  auto inputRank = inputDims.size();

  auto oldResultType =
      dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
  if (!oldResultType)
    return failure();

  auto outputDims = oldResultType.getShape();

  // Extract the static info from the high and low operands.
  SmallVector<int64_t> constOperandsLow;
  SmallVector<Value> newLows;
  for (auto operand : padTensorOp.getLow()) {
    APSInt intOp;
    if (!matchPattern(operand, m_ConstantInt(&intOp))) {
      constOperandsLow.push_back(ShapedType::kDynamic);
      newLows.push_back(operand);
      continue;
    }
    constOperandsLow.push_back(intOp.getExtValue());
  }
  SmallVector<int64_t> constOperandsHigh;
  SmallVector<Value> newHighs;
  for (auto operand : padTensorOp.getHigh()) {
    APSInt intOp;
    if (!matchPattern(operand, m_ConstantInt(&intOp))) {
      constOperandsHigh.push_back(ShapedType::kDynamic);
      newHighs.push_back(operand);
      continue;
    }
    constOperandsHigh.push_back(intOp.getExtValue());
  }

  SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
  SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

  // Verify the op is well-formed.
  if (inputDims.size() != outputDims.size() ||
      inputDims.size() != constLow.size() ||
      inputDims.size() != constHigh.size())
    return failure();

  auto lowCount = 0;
  auto highCount = 0;
  for (size_t i = 0; i < inputRank; i++) {
    if (constLow[i] == ShapedType::kDynamic)
      constLow[i] = constOperandsLow[lowCount++];
    if (constHigh[i] == ShapedType::kDynamic)
      constHigh[i] = constOperandsHigh[highCount++];
  }

  auto staticLow = ArrayRef<int64_t>(constLow);
  auto staticHigh = ArrayRef<int64_t>(constHigh);

  // Calculate the output sizes with the static information.
  SmallVector<int64_t> newOutDims;
  for (size_t i = 0; i < inputRank; i++) {
    if (outputDims[i] == ShapedType::kDynamic) {
      newOutDims.push_back(
          (staticLow[i] == ShapedType::kDynamic ||
                   staticHigh[i] == ShapedType::kDynamic ||
                   inputDims[i] == ShapedType::kDynamic
               ? ShapedType::kDynamic
               : inputDims[i] + staticLow[i] + staticHigh[i]));
    } else {
      newOutDims.push_back(outputDims[i]);
    }
  }

  if (SmallVector<int64_t>(outputDims) == newOutDims ||
      llvm::all_of(newOutDims,
                   [&](int64_t x) { return x == ShapedType::kDynamic; }))
    return failure();

  // Rewrite the op using the new static type.
  auto newResultType = RankedTensorType::get(
      newOutDims, padTensorOp.getType().getElementType());
  auto newOp = rewriter.create<PadOp>(
      padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
      newLows, newHighs, padTensorOp.getNofold(),
      getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

  IRMapping mapper;
  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                              newOp);
  return success();
}
/// Folds a chain of two tensor.pad ops sharing the same constant padding
/// value into a single pad.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings by summing them per dimension.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = rewriter.create<tensor::PadOp>(
        padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
        newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
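// Illustrative sketch, not part of the original listing (shapes
// hypothetical): two foldable pads with the same constant padding value
// collapse into one whose amounts are the per-dimension sums computed by
// the d0 + d1 affine applies above:
//
//   %0 = tensor.pad %src low[1] high[2] { tensor.yield %cst }
//       : tensor<8xf32> to tensor<11xf32>
//   %1 = tensor.pad %0 low[2] high[1] { tensor.yield %cst }
//       : tensor<11xf32> to tensor<14xf32>
// =>
//   %1 = tensor.pad %src low[3] high[3] { tensor.yield %cst }
//       : tensor<8xf32> to tensor<14xf32>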
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // ... (constant and defined-outside-the-region checks elided in this
  // listing)
}

OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
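// Illustrative sketch, not part of the original listing: the fold above
// removes pads that change nothing, e.g.
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] { ... }
//       : tensor<4x4xf32> to tensor<4x4xf32>
//
// folds to %src directly, provided the op is not marked nofold.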
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
  // ParallelInsertSliceOp build overloads (signatures elided in this
  // listing); each delegates to the next:
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, /*...*/);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that the slice is in-bounds.
  return verifyInBoundsSlice(getOperation(), getDestType(), getStaticOffsets(),
                             getStaticSizes(), getStaticStrides());
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
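// Illustrative sketch, not part of the original listing (the scf.forall IR
// is hypothetical but follows its documented form): parallel_insert_slice is
// only valid inside a ParallelCombiningOpInterface region such as the
// scf.forall.in_parallel terminator:
//
//   %r = scf.forall (%i) = (0) to (16) step (4)
//        shared_outs(%o = %dest) -> (tensor<16xf32>) {
//     %tile = ... : tensor<4xf32>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %o[%i] [4] [1]
//           : tensor<4xf32> into tensor<16xf32>
//     }
//   }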
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  // Use GatherOp::inferResultType on the `dest` type and verify that the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }
  return success();
}
// SplatOp build overloads delegate to the generic builder (signatures and
// enclosing definitions elided in this listing):
  build(builder, result, aggregateType, element, dynamicSizes);
  // ...
  build(builder, result, aggregateType, element, dynamicSizes);
  // ...
  build(builder, result, element, staticShape, dynamicSizes);

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

// SplatOp::reifyResultShapes (signature elided) reports static dims as index
// attributes and forwards the dynamic size operands:
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};
  // Do not fold if the result is not statically shaped.
  if (!getType().hasStaticShape())
    return {};
  return SplatElementsAttr::get(getType(), {constOperand});
}
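// Illustrative sketch, not part of the original listing: splatting a
// constant scalar into a statically shaped tensor folds to a splat
// attribute, e.g.
//
//   %c = arith.constant 1.0 : f32
//   %0 = tensor.splat %c : tensor<4x4xf32>
//
// folds to dense<1.0> : tensor<4x4xf32>.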
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // ... (cast-operand checks elided in this listing)
  // Reject ops that need special handling of their attached regions.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;
  return true;
}

// Inside FoldTensorCastProducerOp::matchAndRewrite (the enclosing pattern
// declaration is elided in this listing):
    // Reject relayout (pack/unpack) ops; they need special handling of their
    // padding value.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes;
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the more static operand/result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    // Re-cast results whose type changed back to the original type.
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
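// Illustrative sketch, not part of the original listing ("some.dps_op" is a
// placeholder for any DestinationStyleOpInterface op): the pattern pulls the
// more static type through the op and re-casts the result:
//
//   %0 = tensor.cast %a : tensor<4xf32> to tensor<?xf32>
//   %1 = some.dps_op ins(%0, ...) outs(%d, ...) -> tensor<?xf32>
// =>
//   %2 = some.dps_op ins(%a, ...) outs(%d, ...) -> tensor<4xf32>
//   %1 = tensor.cast %2 : tensor<4xf32> to tensor<?xf32>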
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"