27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/SmallBitVector.h"
30 #include "llvm/ADT/StringRef.h"
31 #include "llvm/Support/MathExtras.h"
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
47 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
49 if (complex::ConstantOp::isBuildableWith(value, type))
50 return builder.
create<complex::ConstantOp>(loc, type,
51 llvm::cast<ArrayAttr>(value));
57 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
59 if (tensorType.isDynamicDim(dim))
60 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
67 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
69 for (int64_t i = 0; i < tensorType.getRank(); ++i)
76 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
77 assert(tensorType &&
"expected tensor type");
81 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
83 return destOp.getTiedOpOperand(opResult)->get();
91 if (!tensorType.hasStaticShape()) {
99 for (int64_t sz : tensorType.getShape())
105 b.
create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
113 if (llvm::isa<TensorType>(opResult.getType())) {
115 if (failed(destination))
117 result.push_back(*destination);
124 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
125 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
126 return rtp1.getShape() == rtp2.getShape() &&
127 rtp1.getElementType() == rtp2.getElementType();
137 llvm::SmallBitVector droppedDims(mixedSizes.size());
138 int64_t shapePos = reducedShape.size() - 1;
140 for (
const auto &size :
enumerate(llvm::reverse(mixedSizes))) {
141 size_t idx = mixedSizes.size() - size.index() - 1;
143 bool isStaticUnitSize =
145 llvm::cast<IntegerAttr>(size.value().get<
Attribute>()).getInt() == 1;
150 assert(isStaticUnitSize &&
"expected unit dim");
151 droppedDims.set(idx);
156 if (!isStaticUnitSize) {
162 if (reducedShape[shapePos] == 1) {
168 droppedDims.set(idx);
171 assert(shapePos < 0 &&
"dimension mismatch");
178 static RankedTensorType
182 type.getShape().end());
183 assert(type.getNumDynamicDims() ==
184 static_cast<int64_t
>(dynamicSizes.size()) &&
185 "incorrect number of dynamic sizes");
189 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
190 if (type.isDynamicDim(i)) {
191 Value dynamicSize = dynamicSizes[ctr++];
193 if (cst.has_value()) {
195 if (cst.value() < 0) {
196 foldedDynamicSizes.push_back(dynamicSize);
199 staticShape[i] = *cst;
201 foldedDynamicSizes.push_back(dynamicSize);
215 if (inputs.size() != 1 || outputs.size() != 1)
217 Type a = inputs.front(), b = outputs.front();
218 auto aT = dyn_cast<TensorType>(a);
219 auto bT = dyn_cast<TensorType>(b);
223 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
236 LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
238 auto tensorBitcastOperand =
239 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
240 if (!tensorBitcastOperand)
243 auto resultType = cast<TensorType>(tensorBitcast.getType());
244 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
245 tensorBitcastOperand.getOperand());
254 results.
add<ChainedTensorBitcast>(context);
262 setNameFn(getResult(),
"cast");
268 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
269 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
272 if (!sourceType || !targetType)
276 if (sourceType.getElementType() != targetType.getElementType())
280 if (sourceType.getRank() != targetType.getRank())
284 if (sourceType.getEncoding() != targetType.getEncoding())
288 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
289 if (!ShapedType::isDynamic(std::get<0>(t)) &&
290 ShapedType::isDynamic(std::get<1>(t)))
326 castOp.getSource().getType());
361 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
363 operand.set(castOp.getOperand());
367 return success(folded);
371 if (inputs.size() != 1 || outputs.size() != 1)
373 Type a = inputs.front(), b = outputs.front();
374 auto aT = llvm::dyn_cast<TensorType>(a);
375 auto bT = llvm::dyn_cast<TensorType>(b);
379 if (aT.getElementType() != bT.getElementType())
395 int64_t rank = one.getRank();
396 if (rank != two.getRank())
401 for (int64_t i = 0; i < rank; ++i) {
402 if (one.isDynamicDim(i)) {
403 join.push_back(two.getDimSize(i));
406 if (two.isDynamicDim(i)) {
407 join.push_back(one.getDimSize(i));
410 if (one.getDimSize(i) != two.getDimSize(i))
412 join.push_back(one.getDimSize(i));
424 LogicalResult matchAndRewrite(CastOp tensorCast,
426 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
428 if (!tensorCastOperand)
432 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
433 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
434 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
448 auto newJoin =
joinShapes(sourceType, resultType);
449 if (firstJoin != newJoin)
452 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
453 tensorCastOperand.getOperand());
473 LogicalResult matchAndRewrite(CastOp tensorCast,
475 auto extractOperand =
476 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
479 auto rankedResultType =
480 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
481 if (!rankedResultType)
485 rankedResultType.getShape() ==
486 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
492 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
494 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
495 if (dimMask && dimMask->count(i))
497 int64_t dim = rankedResultType.getShape()[dimIndex++];
498 if (ShapedType::isDynamic(dim))
500 sizes[i] = rewriter.getIndexAttr(dim);
503 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
504 tensorCast, rankedResultType, extractOperand.getSource(),
505 extractOperand.getMixedOffsets(), sizes,
506 extractOperand.getMixedStrides());
515 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
522 RankedTensorType ConcatOp::inferResultType(int64_t dim,
TypeRange inputTypes) {
523 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
525 llvm::to_vector<4>(llvm::map_range(inputTypes, [](
Type type) {
526 return llvm::cast<RankedTensorType>(type);
528 int64_t concatRank = tensorTypes[0].getRank();
531 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
534 for (int64_t i = 0, e = concatRank; i < e; ++i) {
538 for (
auto tensorType : tensorTypes)
543 for (
auto tensorType : tensorTypes)
546 sizes[dim] = concatSize.asInteger();
552 FailureOr<RankedTensorType> resultType =
553 inferResultType(dim, inputs.
getTypes());
554 assert(succeeded(resultType) &&
"failed to infer concatenation result type");
555 build(builder, result, *resultType, dim, inputs);
559 if (getInputs().size() < 1)
560 return emitOpError(
"requires at least one input");
563 for (
auto input : getInputs())
564 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
566 RankedTensorType resultType = getResultType();
567 int64_t resultRank = getRank();
568 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
569 return type.getRank() != resultRank;
571 return emitOpError(
"rank of concatenated inputs must match result rank");
573 Type resultElementType = resultType.getElementType();
574 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
575 return type.getElementType() != resultElementType;
577 return emitOpError(
"inputs and result element type must match");
579 int64_t dim = getDim();
580 if (dim >= resultRank)
581 return emitOpError(
"concatenation dim must be less than the tensor rank");
584 for (int64_t i = 0, e = resultRank; i < e; ++i) {
588 for (
auto tensorType : inputTypes) {
589 FailureOr<SaturatedInteger> maybeSize =
591 if (failed(maybeSize))
592 return emitOpError(
"static concatenation size mismatch along ")
593 <<
"non-concatenated dimension " << i;
599 for (
auto tensorType : inputTypes)
602 sizes[dim] = concatSize.asInteger();
603 auto inferredResultType =
606 for (
auto [inferredSize, actualSize] :
607 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
608 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
609 ShapedType::isDynamic(actualSize);
610 if (!hasDynamic && inferredSize != actualSize)
611 return emitOpError(
"result type ")
612 << resultType <<
"does not match inferred shape "
613 << inferredResultType <<
" static sizes";
623 int64_t dim = getDim();
624 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
626 Value init = inputs[0];
627 int64_t rank =
getType().getRank();
634 for (int64_t i = 0; i < rank; ++i) {
637 if (!
getType().isDynamicDim(i)) {
639 }
else if (!inferredResultType.isDynamicDim(i)) {
642 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
644 reifiedReturnShapes[0][i] =
645 builder.
create<tensor::DimOp>(init.
getLoc(), init, i).getResult();
649 if (
getType().isDynamicDim(dim)) {
657 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
665 reifiedReturnShapes[0][dim] =
671 void ConcatOp::getAsmResultNames(
673 setNameFn(getResult(),
"concat");
678 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
688 LogicalResult matchAndRewrite(ConcatOp concatOp,
690 if (concatOp.getInputs().size() != 1)
693 concatOp.getInputs()[0]);
701 results.
add<SingleInputConcatOp>(context);
709 setNameFn(getResult(),
"dim");
715 Value indexValue = builder.
create<arith::ConstantIndexOp>(loc, index);
716 build(builder, result, source, indexValue);
719 std::optional<int64_t> DimOp::getConstantIndex() {
728 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().
getType());
729 if (!rankedSourceType)
740 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
745 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().
getType());
751 int64_t indexVal = index.getInt();
752 if (indexVal < 0 || indexVal >= tensorType.getRank())
756 if (!tensorType.isDynamicDim(index.getInt())) {
758 return builder.
getIndexAttr(tensorType.getShape()[index.getInt()]);
761 Operation *definingOp = getSource().getDefiningOp();
764 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
766 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
769 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
772 auto dynExtents = fromElements.getDynamicExtents().begin();
773 for (
auto dim : resultType.getShape().take_front(index.getInt()))
774 if (ShapedType::isDynamic(dim))
777 return Value{*dynExtents};
781 unsigned unsignedIndex = index.getValue().getZExtValue();
783 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
786 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
787 sliceOp.isDynamicSize(unsignedIndex)) {
788 return {sliceOp.getDynamicSize(unsignedIndex)};
804 LogicalResult matchAndRewrite(DimOp dimOp,
806 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
809 Value newSource = castOp.getOperand();
820 LogicalResult matchAndRewrite(DimOp dimOp,
822 auto source = dimOp.getSource();
823 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
827 auto resultIndex = cast<OpResult>(source).getResultNumber();
828 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
831 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
841 LogicalResult matchAndRewrite(DimOp dim,
843 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
853 rewriter.
create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
854 if (extract.
getType() != dim.getType())
856 rewriter.
create<arith::IndexCastOp>(loc, dim.getType(), extract);
865 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
875 assert(all_of(staticShape,
876 [](int64_t sz) {
return !ShapedType::isDynamic(sz); }) &&
877 "expected only static sizes");
878 build(builder, result, staticShape, elementType,
ValueRange{}, encoding);
885 build(builder, result, tensorType, dynamicSizes);
894 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
898 if (
getType().getNumDynamicDims() !=
900 return emitOpError(
"incorrect number of dynamic sizes, has ")
902 <<
getType().getNumDynamicDims();
911 for (int64_t i = 0; i <
getType().getRank(); ++i) {
912 if (
getType().isDynamicDim(i)) {
921 Value EmptyOp::getDynamicSize(
unsigned idx) {
922 assert(
getType().isDynamicDim(idx) &&
"expected dynamic dim");
924 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
934 for (int64_t i = 0; i <
getType().getRank(); ++i) {
935 if (
getType().isDynamicDim(i)) {
959 LogicalResult matchAndRewrite(EmptyOp op,
963 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
966 if (foldedTensorType == op.getType())
969 auto newOp = rewriter.
create<EmptyOp>(op.
getLoc(), foldedTensorType,
979 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
981 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
982 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
983 if (!emptyTensorOp || !maybeConstantIndex)
985 if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
988 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1011 LogicalResult matchAndRewrite(CastOp castOp,
1015 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1020 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1024 newMixedSizes.reserve(currMixedSizes.size());
1025 assert(resultShape.size() == currMixedSizes.size() &&
1026 "mismatch in result shape and sizes of empty op");
1027 for (
auto it : llvm::zip(resultShape, currMixedSizes)) {
1028 int64_t newDim = std::get<0>(it);
1032 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1033 if (ShapedType::isDynamic(newDim) ||
1034 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1039 producer,
"mismatch in static value of shape of empty tensor "
1040 "result and cast result");
1042 newMixedSizes.push_back(attr);
1048 if (!ShapedType::isDynamic(newDim)) {
1049 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
1055 newMixedSizes.push_back(currDim);
1060 resultType.getElementType());
1069 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1070 ReplaceEmptyTensorStaticShapeDims>(context);
1079 std::optional<Attribute> cst = std::nullopt) {
1080 if (source && source.
isSplat() && result.hasStaticShape() &&
1101 struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
1104 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1106 auto tensorCast = extract.getTensor().
getDefiningOp<tensor::CastOp>();
1109 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1112 extract, tensorCast.getSource(), extract.getIndices());
1119 void ExtractOp::getAsmResultNames(
1121 setNameFn(getResult(),
"extracted");
1126 auto tensorType = llvm::cast<RankedTensorType>(getTensor().
getType());
1127 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1128 return emitOpError(
"incorrect number of indices for extract_element");
1135 if (
Attribute tensor = adaptor.getTensor())
1136 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1137 return splatTensor.getSplatValue<
Attribute>();
1141 for (
Attribute indice : adaptor.getIndices()) {
1142 if (!indice || !llvm::isa<IntegerAttr>(indice))
1144 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1148 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1149 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1150 auto rank = tensorType.getRank();
1151 assert(
static_cast<int64_t
>(indices.size()) == tensorType.getRank() &&
1155 for (
int i = rank - 1; i >= 0; --i) {
1156 flatIndex += indices[i] * stride;
1157 stride *= tensorType.getDimSize(i);
1161 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1164 return fromElementsOp.getElements()[flatIndex];
1168 if (
Attribute tensor = adaptor.getTensor()) {
1169 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1170 if (elementsAttr && elementsAttr.isValidIndex(indices))
1171 return elementsAttr.getValues<
Attribute>()[indices];
1179 results.
add<ExtractFromTensorCast>(context);
1186 void FromElementsOp::getAsmResultNames(
1188 setNameFn(getResult(),
"from_elements");
1193 assert(!elements.empty() &&
"expected at least one element");
1195 {
static_cast<int64_t
>(elements.size())}, elements.front().
getType());
1196 build(builder, result, resultType, elements);
1199 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1200 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
1223 struct ExtractElementFromIndexCast
1227 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1230 auto indexCast = extract.getTensor().
getDefiningOp<arith::IndexCastOp>();
1236 auto newExtract = rewriter.
create<tensor::ExtractOp>(
1237 loc, elementTy, indexCast.getIn(), extract.getIndices());
1250 results.
add<ExtractElementFromIndexCast>(context);
1257 void GatherOp::getAsmResultNames(
1259 setNameFn(getResult(),
"gather");
1274 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1275 RankedTensorType indicesType,
1279 resultShape.reserve(resultShape.size() + sourceType.getRank());
1280 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1281 if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1283 resultShape.push_back(1);
1286 resultShape.push_back(sourceType.getDimSize(idx));
1291 static LogicalResult
1293 StringRef gatherOrScatter, StringRef sourceOrDest) {
1295 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
1297 int64_t numGatherDims = dims.size();
1298 if (numGatherDims > rank)
1300 <<
"_dims overflow " << sourceOrDest <<
" rank";
1301 for (int64_t val : dims) {
1304 <<
"_dims value must be non-negative";
1307 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
1309 for (int64_t i = 1; i < numGatherDims; ++i) {
1310 if (dims[i - 1] >= dims[i])
1312 <<
"_dims values must be strictly increasing";
1318 int64_t sourceRank = getSourceType().getRank();
1321 "gather",
"source")))
1324 RankedTensorType expectedResultType = GatherOp::inferResultType(
1325 getSourceType(), getIndicesType(), gatherDims,
false);
1326 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1327 getSourceType(), getIndicesType(), gatherDims,
true);
1328 if (getResultType() != expectedResultType &&
1329 getResultType() != expectedRankReducedResultType) {
1330 return emitOpError(
"result type "
1333 << expectedResultType <<
" or its rank-reduced variant "
1334 << expectedRankReducedResultType <<
" (got: " << getResultType()
1343 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1345 return reshapedSource;
1353 void InsertOp::getAsmResultNames(
1355 setNameFn(getResult(),
"inserted");
1360 auto destType = llvm::cast<RankedTensorType>(getDest().
getType());
1361 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1362 return emitOpError(
"incorrect number of indices");
1370 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1371 if (scalar == splatDest.getSplatValue<
Attribute>())
1380 void GenerateOp::getAsmResultNames(
1382 setNameFn(getResult(),
"generated");
1389 for (
auto dim : llvm::seq<int64_t>(0,
getType().getRank())) {
1390 if (
getType().isDynamicDim(dim)) {
1391 reifiedReturnShapes[0][dim] = getOperand(idx++);
1393 reifiedReturnShapes[0][dim] =
1403 RankedTensorType resultType = llvm::cast<RankedTensorType>(
getType());
1404 if (getNumOperands() != resultType.getNumDynamicDims())
1405 return emitError(
"must have as many index operands as dynamic extents "
1406 "in the result type");
1410 LogicalResult GenerateOp::verifyRegions() {
1411 RankedTensorType resultTy = llvm::cast<RankedTensorType>(
getType());
1413 if (!llvm::all_of(getBody().getArgumentTypes(),
1415 return emitError(
"all body arguments must be index");
1416 if (getBody().getNumArguments() != resultTy.getRank())
1417 return emitError(
"must have one body argument per input dimension");
1420 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1422 if (yieldOp.getValue().getType() != resultTy.getElementType())
1424 "body must be terminated with a `yield` operation of the tensor "
1430 void GenerateOp::build(
1434 build(b, result, resultTy, dynamicExtents);
1439 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1443 b.
createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1456 LogicalResult matchAndRewrite(GenerateOp generateOp,
1460 generateOp.getType(), generateOp.getDynamicExtents(),
1461 foldedDynamicSizes);
1464 if (foldedTensorType == generateOp.getType())
1467 auto loc = generateOp.getLoc();
1469 rewriter.
create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1471 newOp.getBody().begin());
1473 generateOp.getType(), newOp);
1489 struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1492 LogicalResult matchAndRewrite(tensor::ExtractOp extract,
1494 auto tensorFromElements = extract.getTensor().
getDefiningOp<GenerateOp>();
1499 Block *body = &tensorFromElements.getBody().
front();
1502 rewriter.
clone(op, mapping);
1516 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1523 void RankOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1524 setNameFn(getResult(),
"rank");
1529 auto type = getOperand().getType();
1530 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1531 if (shapedType && shapedType.hasRank())
1533 return IntegerAttr();
1540 void ReshapeOp::getAsmResultNames(
1542 setNameFn(getResult(),
"reshape");
1546 int64_t numElements = 1;
1547 for (
auto dim : type.getShape())
1557 return emitOpError(
"element types of source and destination tensor "
1558 "types should be the same");
1562 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1563 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1565 if (resultRankedType) {
1566 if (operandRankedType && resultRankedType.hasStaticShape() &&
1567 operandRankedType.hasStaticShape()) {
1569 return emitOpError(
"source and destination tensor should have the "
1570 "same number of elements");
1572 if (ShapedType::isDynamic(shapeSize))
1573 return emitOpError(
"cannot use shape operand with dynamic length to "
1574 "reshape to statically-ranked tensor type");
1575 if (shapeSize != resultRankedType.getRank())
1577 "length of shape operand differs from the result's tensor rank");
1584 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1586 return reshapedSource;
1591 if (
auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
1592 getSourceMutable().assign(reshapeOpProducer.getSource());
1596 auto source = getSource();
1597 auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1598 auto resultTy = dyn_cast<RankedTensorType>(
getType());
1599 if (!sourceTy || !resultTy || sourceTy != resultTy)
1604 if (sourceTy.getRank() == 1)
1607 if (
auto fromElements =
getShape().getDefiningOp<tensor::FromElementsOp>()) {
1608 auto elements = fromElements.getElements();
1610 sourceTy.getRank() ==
static_cast<int64_t
>(elements.size());
1611 for (
int id = 0, s = elements.size();
id < s && dynamicNoop; ++id) {
1612 auto element = elements[id];
1615 dynamicNoop &= cst.value() == sourceTy.getDimSize(
id);
1619 if (
auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1620 dynamicNoop &= dimOp.getSource() == source;
1625 cst.has_value() && cst.value() ==
static_cast<int64_t
>(id);
1629 dynamicNoop =
false;
1644 void CollapseShapeOp::getAsmResultNames(
1646 setNameFn(getResult(),
"collapsed");
1649 void ExpandShapeOp::getAsmResultNames(
1651 setNameFn(getResult(),
"expanded");
1654 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1655 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1656 "invalid resultDim");
1658 if (llvm::is_contained(it.value(), resultDim))
1660 llvm_unreachable(
"could not find reassociation group");
1663 FailureOr<SmallVector<OpFoldResult>>
1665 RankedTensorType expandedType,
1668 std::optional<SmallVector<OpFoldResult>> outputShape =
1673 return *outputShape;
1680 auto [staticOutputShape, dynamicOutputShape] =
1682 build(builder, result, cast<RankedTensorType>(resultType), src,
1684 dynamicOutputShape, staticOutputShape);
1692 auto tensorResultTy = cast<RankedTensorType>(resultType);
1693 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
1694 builder, result.
location, tensorResultTy, reassociation, inputShape);
1696 if (succeeded(outputShape)) {
1697 outputShapeOrEmpty = *outputShape;
1699 build(builder, result, tensorResultTy, src, reassociation,
1700 outputShapeOrEmpty);
1708 getReassociationIndices());
1716 getReassociationIndices());
1719 RankedTensorType CollapseShapeOp::inferCollapsedType(
1721 return inferCollapsedType(
1723 type.getContext(), reassociation)));
1729 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1731 auto shape = type.getShape();
1733 newShape.reserve(reassociation.size());
1738 unsigned currentDim = 0;
1740 unsigned dim = m.getNumResults();
1741 auto band = shape.slice(currentDim, dim);
1743 if (llvm::is_contained(band, ShapedType::kDynamic))
1744 size = ShapedType::kDynamic;
1746 for (
unsigned d = 0; d < dim; ++d)
1747 size *= shape[currentDim + d];
1748 newShape.push_back(size);
1758 auto resultType = inferCollapsedType(
1759 llvm::cast<RankedTensorType>(src.
getType()),
1764 build(b, result, resultType, src, attrs);
1767 template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
1768 TensorReshapeOp, ExpandShapeOp>::value>
1770 RankedTensorType expandedType,
1771 RankedTensorType collapsedType) {
1776 auto maps = op.getReassociationMaps();
1777 RankedTensorType expectedType =
1778 CollapseShapeOp::inferCollapsedType(expandedType, maps);
1780 return op.
emitOpError(
"expected collapsed type to be ")
1781 << expectedType <<
", but got " << collapsedType;
1786 auto srcType = getSrcType();
1787 auto resultType = getResultType();
1789 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
1790 return emitOpError(
"expected number of static shape dims to be equal to "
1791 "the output rank (")
1792 << resultType.getRank() <<
") but found "
1793 << getStaticOutputShape().size() <<
" inputs instead";
1795 if ((int64_t)getOutputShape().size() !=
1796 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
1797 return emitOpError(
"mismatch in dynamic dims in output_shape and "
1798 "static_output_shape: static_output_shape has ")
1799 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
1800 <<
" dynamic dims while output_shape has " << getOutputShape().size()
1813 template <
typename TensorReshapeOp>
1816 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1824 reshapeOp.getResultType(), attr.
getRawData());
1831 template <
typename TensorReshapeOp>
1836 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1838 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1839 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1843 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1850 template <
typename TensorReshapeOp>
1853 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
1856 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1860 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1862 if (!shapedTy.hasStaticShape())
1866 fromElements.getElements());
1875 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1877 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1881 RankedTensorType srcType =
1882 llvm::cast<RankedTensorType>(castOp.getSource().getType());
1883 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1884 srcType, collapseShapeOp.getReassociationMaps());
1886 if (newResultType == collapseShapeOp.getResultType()) {
1888 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1891 auto newOp = rewriter.
create<CollapseShapeOp>(
1892 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1893 collapseShapeOp.getReassociation());
1895 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1904 LogicalResult matchAndRewrite(DimOp dimOp,
1906 auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1911 std::optional<int64_t> dim = dimOp.getConstantIndex();
1912 if (!dim.has_value())
1916 RankedTensorType resultType = expandShapeOp.getResultType();
1917 if (!resultType.isDynamicDim(*dim))
1921 int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1927 for (int64_t d : grp) {
1929 assert(!resultType.isDynamicDim(d) &&
"expected static dim");
1930 product *= resultType.getDimSize(d);
1936 rewriter.
create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1940 dimOp, expr.floorDiv(
product), srcDimSz);
1948 LogicalResult matchAndRewrite(DimOp dimOp,
1950 auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1951 if (!collapseShapeOp)
1955 std::optional<int64_t> dim = dimOp.getConstantIndex();
1956 if (!dim.has_value())
1960 RankedTensorType resultType = collapseShapeOp.getResultType();
1961 if (!resultType.isDynamicDim(*dim))
1966 collapseShapeOp.getReassociationIndices()[*dim];
1973 srcDimSizes.push_back(rewriter.
create<DimOp>(
1974 dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1990 FoldReshapeWithConstant<ExpandShapeOp>,
1991 FoldReshapeWithSplat<ExpandShapeOp>,
1992 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1993 FoldDimOfCollapseShape>(context);
2001 tensor::DimOp, RankedTensorType>,
2002 FoldReshapeWithConstant<CollapseShapeOp>,
2003 FoldReshapeWithSplat<CollapseShapeOp>,
2004 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
2008 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2009 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*
this,
2010 adaptor.getOperands());
2013 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2014 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*
this,
2015 adaptor.getOperands());
2022 void ExtractSliceOp::getAsmResultNames(
2024 setNameFn(getResult(),
"extracted_slice");
2030 RankedTensorType ExtractSliceOp::inferResultType(
2036 assert(
static_cast<int64_t
>(staticSizes.size()) ==
2037 sourceTensorType.getRank() &&
2038 "unexpected staticSizes not equal to rank of source");
2040 sourceTensorType.getEncoding());
2043 RankedTensorType ExtractSliceOp::inferResultType(
2051 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
2052 staticSizes, staticStrides);
2063 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2064 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2068 auto inferredType = llvm::cast<RankedTensorType>(
2069 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2070 int rankDiff = inferredType.getRank() - desiredResultRank;
2072 auto shape = inferredType.getShape();
2073 llvm::SmallBitVector dimsToProject =
2077 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
2078 if (!dimsToProject.test(pos))
2079 projectedShape.push_back(shape[pos]);
2083 return inferredType;
2086 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
2087 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
2095 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2096 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
2103 RankedTensorType resultType,
Value source,
2113 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
2116 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2117 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2120 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2133 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2142 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2148 RankedTensorType resultType,
Value source,
2157 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2164 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2169 RankedTensorType expectedType) {
2174 return op->
emitError(
"expected rank to be smaller or equal to ")
2175 <<
"the other rank. ";
2177 return op->
emitError(
"expected type to be ")
2178 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2180 return op->
emitError(
"expected element type to be ")
2181 << expectedType.getElementType();
2183 llvm_unreachable(
"unexpected extract_slice op verification result");
2190 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2191 getSourceType(), getMixedOffsets(),
getMixedSizes(), getMixedStrides());
2203 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
2204 assert(sourceTensorType &&
"not a ranked tensor type");
2205 auto sourceShape = sourceTensorType.getShape();
2206 if (sourceShape.equals(desiredShape))
2208 auto maybeRankReductionMask =
2210 if (!maybeRankReductionMask)
2219 reifiedReturnShapes.resize(1);
2220 reifiedReturnShapes[0].reserve(
getType().getRank());
2223 for (
const auto &size :
enumerate(mixedSizes)) {
2224 if (droppedDims.test(size.index()))
2226 reifiedReturnShapes[0].push_back(size.value());
2247 class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
2251 LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
2254 if (llvm::any_of(sliceOp.getOperands(), [](
Value operand) {
2255 return matchPattern(operand, matchConstantIndex());
2259 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2268 Value newResult = rewriter.
create<ExtractSliceOp>(
2269 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2270 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2271 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2272 if (newResult.
getType() != sliceOp.getType())
2273 newResult = rewriter.
create<CastOp>(loc, sliceOp.getType(), newResult);
2282 template <
typename IterTy,
typename ElemTy>
2287 assert(offsets.size() == sizes.size());
2288 assert(offsets.size() == strides.size());
2289 if (offsets.empty())
2292 int64_t offset = offsets.front();
2293 int64_t size = sizes.front();
2294 int64_t stride = strides.front();
2295 if (offsets.size() == 1) {
2296 for (int64_t i = 0; i < size; ++i, offset += stride)
2297 outValues->push_back(*(values + offset));
2302 for (int64_t i = 0; i < size; ++i, offset += stride) {
2303 auto begin = values + offset * counts.front();
2304 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2305 offsets.drop_front(), sizes.drop_front(),
2306 strides.drop_front(), outValues);
2313 class ConstantOpExtractSliceFolder final
2318 ConstantOpExtractSliceFolder(
MLIRContext *context,
2321 controlFn(std::move(controlFn)) {}
2323 LogicalResult matchAndRewrite(ExtractSliceOp op,
2334 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2336 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2343 int64_t count = sourceType.getNumElements();
2348 auto offsets = op.getStaticOffsets();
2349 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2351 auto sizes = op.getStaticSizes();
2352 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2354 auto strides = op.getStaticStrides();
2355 if (llvm::is_contained(strides, ShapedType::kDynamic))
2361 counts.reserve(shape.size());
2362 for (int64_t v : shape) {
2364 counts.push_back(count);
2370 if (
auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2372 outValues.reserve(sourceType.getNumElements());
2373 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2374 elems.begin(), counts, offsets, sizes, strides, &outValues);
2376 }
else if (
auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2378 outValues.reserve(sourceType.getNumElements());
2379 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2380 elems.begin(), counts, offsets, sizes, strides, &outValues);
2403 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
2412 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2413 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2421 ExtractSliceOp newOp) {
2422 Value replacement = newOp.getResult();
2423 if (replacement.
getType() != op.getType())
2424 replacement = rewriter.
create<tensor::CastOp>(op.
getLoc(), op.getType(),
2435 ExtractSliceOpCastFolder>(context);
2439 static LogicalResult
2441 ShapedType shapedType) {
2448 auto shape = shapedType.getShape();
2449 for (
auto it : llvm::zip(op.getMixedSizes(), shape))
2463 auto insertOp = extractOp.getSource().
getDefiningOp<InsertSliceOp>();
2466 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2467 insertOp.isSameAs(extractOp, isSame))
2468 return insertOp.getSource();
2473 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2475 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2477 return reshapedSource;
2478 if (getSourceType() ==
getType() &&
2480 return this->getSource();
2489 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.
getType());
2490 unsigned rank = rankedTensorType.getRank();
2494 return b.
createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2495 offsets, sizes, strides);
2502 void InsertSliceOp::getAsmResultNames(
2504 setNameFn(getResult(),
"inserted_slice");
2519 build(b, result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2531 build(b, result, source, dest, offsets, sizes, strides, attrs);
2544 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2550 RankedTensorType srcType, RankedTensorType dstType,
2555 RankedTensorType expected = ExtractSliceOp::inferResultType(
2556 dstType, staticOffsets, staticSizes, staticStrides);
2558 *expectedType = expected;
2564 RankedTensorType expectedType;
2567 getStaticSizes(), getStaticStrides(), &expectedType);
2589 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2592 if (!prevInsertOp ||
2593 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2594 !prevInsertOp.isSameAs(insertOp, isSame))
2597 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2609 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2612 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2613 !extractOp.isSameAs(insertOp, isSame))
2616 return extractOp.getSource();
2620 if (getSourceType().hasStaticShape() &&
getType().hasStaticShape() &&
2621 getSourceType() ==
getType() &&
2623 return this->getSource();
2645 template <
typename InsertOpTy>
2646 class InsertSliceOpConstantArgumentFolder final
2651 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2664 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2665 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2666 mixedOffsets, mixedSizes, mixedStrides);
2667 Value toInsert = insertSliceOp.getSource();
2668 if (sourceType != insertSliceOp.getSourceType()) {
2673 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2675 toInsert = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2676 sourceType, toInsert);
2679 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2680 mixedSizes, mixedStrides);
2705 template <
typename InsertOpTy>
2706 struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
2709 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2711 if (llvm::any_of(insertSliceOp.getOperands(), [](
Value operand) {
2712 return matchPattern(operand, matchConstantIndex());
2716 auto getSourceOfCastOp = [](
Value v) -> std::optional<Value> {
2717 auto castOp = v.getDefiningOp<tensor::CastOp>();
2719 return std::nullopt;
2720 return castOp.getSource();
2722 std::optional<Value> sourceCastSource =
2723 getSourceOfCastOp(insertSliceOp.getSource());
2724 std::optional<Value> destCastSource =
2725 getSourceOfCastOp(insertSliceOp.getDest());
2726 if (!sourceCastSource && !destCastSource)
2730 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2731 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2732 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
2733 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2734 if (!srcType || !dstType)
2742 staticSizes, srcType.getShape(),
true);
2743 if (!rankReductionMask.has_value())
2751 int64_t rankReducedIdx = 0;
2752 for (
auto [idx, size] :
enumerate(staticSizes)) {
2753 if (!rankReductionMask.value().contains(idx) &&
2754 !srcType.isDynamicDim(rankReducedIdx)) {
2756 rewriter.
getContext(), srcType.getDimSize(rankReducedIdx));
2757 size = srcType.getDimSize(rankReducedIdx++);
2761 staticSizes, insertSliceOp.getStaticStrides()) !=
2766 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2767 mixedSizes, insertSliceOp.getMixedStrides());
2770 bool isParallelInsert =
2771 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2772 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2773 replacement = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2774 insertSliceOp.getDestType(),
2803 template <
typename InsertOpTy>
2804 struct InsertSliceOpSourceCastInserter final
2808 LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
2810 RankedTensorType srcType = insertSliceOp.getSourceType();
2811 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2814 srcType.getShape().end());
2815 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2816 if (std::optional<int64_t> constInt =
2821 newSrcShape[i] = *constInt;
2828 newSrcShape, srcType.getElementType(), srcType.getEncoding());
2829 if (srcType == newSrcType ||
2831 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2843 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2846 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2848 insertSliceOp, cast, insertSliceOp.getDest(),
2849 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2850 insertSliceOp.getMixedStrides());
2862 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2863 InsertSliceOpCastFolder<InsertSliceOp>,
2864 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2871 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
2872 unsigned rank = rankedTensorType.getRank();
2876 return b.
createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2885 setNameFn(getResult(),
"padded");
2891 Type typeToInfer,
Type typeToInferFrom) {}
2895 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2896 Type &typeToInfer,
Type typeToInferFrom) {
2898 typeToInfer = typeToInferFrom;
2903 auto sourceType = llvm::cast<RankedTensorType>(getSource().
getType());
2904 auto resultType = llvm::cast<RankedTensorType>(getResult().
getType());
2906 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2907 if (!expectedType) {
2908 return emitError(
"failed to infer expectedType from sourceType ")
2909 << sourceType <<
", specified resultType is " << resultType;
2911 if (resultType.getRank() != expectedType.getRank()) {
2913 << resultType <<
" does not match the inferred type "
2916 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
2917 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2919 if (expectedType.isDynamicDim(i))
2922 << resultType <<
" does not match the inferred type "
2929 LogicalResult PadOp::verifyRegions() {
2930 auto ®ion = getRegion();
2931 unsigned rank = llvm::cast<RankedTensorType>(getResult().
getType()).getRank();
2934 return emitError(
"expected the block to have ") << rank <<
" arguments";
2938 if (!en.value().isIndex())
2939 return emitOpError(
"expected block argument ")
2940 << (en.index() + 1) <<
" to be an index";
2945 if (yieldOp.getValue().getType() !=
2947 return emitOpError(
"expected yield type to match shape element type");
2952 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2956 unsigned rank = sourceType.getRank();
2957 if (staticLow.size() != rank)
2958 return RankedTensorType();
2959 if (staticHigh.size() != rank)
2960 return RankedTensorType();
2961 if (!resultShape.empty() && resultShape.size() != rank)
2962 return RankedTensorType();
2965 for (
auto i : llvm::seq<unsigned>(0, rank)) {
2966 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2967 staticHigh[i] == ShapedType::kDynamic) {
2968 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2971 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2972 assert((resultShape.empty() || size == resultShape[i] ||
2973 resultShape[i] == ShapedType::kDynamic) &&
2974 "mismatch between inferred shape and result shape");
2975 inferredShape.push_back(size);
2986 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2988 resultType = inferResultType(sourceType, staticLow, staticHigh);
2990 build(b, result, resultType, source, low, high,
2998 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2999 unsigned rank = sourceType.getRank();
3001 build(b, result, resultType, source, staticVector, staticVector, low, high,
3009 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
3019 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
3021 assert(llvm::isa<RankedTensorType>(resultType));
3023 build(b, result, resultType, source, dynamicLow, dynamicHigh,
3032 build(b, result, resultType, source, low, high, nofold, attrs);
3036 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
3043 b.
createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
3047 llvm::SmallBitVector PadOp::getPaddedDims() {
3048 llvm::SmallBitVector paddedDims(getSourceType().getRank());
3050 for (
const auto &en :
enumerate(paddingWidths))
3052 paddedDims.set(en.index());
3054 extractPaddedDims(getMixedLowPad());
3055 extractPaddedDims(getMixedHighPad());
3065 LogicalResult matchAndRewrite(PadOp padTensorOp,
3067 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
3069 if (padTensorOp.getNofold())
3072 padTensorOp, padTensorOp.getResult().getType(),
3073 padTensorOp.getSource());
3082 LogicalResult matchAndRewrite(PadOp padTensorOp,
3084 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
3088 auto newResultType = PadOp::inferResultType(
3089 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
3090 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3091 padTensorOp.getResultType().getShape());
3093 if (newResultType == padTensorOp.getResultType()) {
3095 padTensorOp.getSourceMutable().assign(castOp.getSource());
3098 auto newOp = rewriter.
create<PadOp>(
3099 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
3100 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
3101 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3104 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3107 padTensorOp, padTensorOp.getResultType(), newOp);
3118 LogicalResult matchAndRewrite(PadOp padTensorOp,
3120 if (!padTensorOp.getResult().hasOneUse())
3123 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
3127 tensorCastOp.getDest().getType()))
3130 auto replacementOp = rewriter.
create<PadOp>(
3131 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
3132 padTensorOp.getSource(), padTensorOp.getStaticLow(),
3133 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
3134 padTensorOp.getHigh(), padTensorOp.getNofold(),
3138 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
3139 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3182 LogicalResult matchAndRewrite(PadOp padOp,
3184 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3187 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3188 if (!outerPadOp || outerPadOp.getNofold())
3190 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3195 int64_t rank = padOp.getSourceType().getRank();
3196 if (outerSliceOp.getSourceType().getRank() != rank) {
3198 "cannot fold rank-reducing chain");
3202 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3204 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3208 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3210 "cannot fold PadOps with low padding");
3215 Value innerValue = padOp.getConstantPaddingValue();
3216 Value outerValue = outerPadOp.getConstantPaddingValue();
3217 if (!innerValue || !outerValue ||
3220 innerAttr != outerAttr) {
3222 padOp,
"cannot fold PadOps with different padding values");
3226 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3227 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3228 if (innerDims.anyCommon(outerDims)) {
3230 padOp,
"cannot fold PadOps with common padding dimensions");
3240 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3241 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3242 if (!innerDims.test(en.index()) &&
3244 en.value() = outerOffset;
3247 if (!outerDims.test(en.index()) &&
3249 en.value() = innerOffset;
3253 padOp,
"cannot find zero-offset and zero-padding pair");
3263 if (!outerDims.test(en.index()))
3265 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3266 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3267 assert(!ShapedType::isDynamic(sourceSize) &&
3268 "expected padded dimension to have a static size");
3271 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3272 "match the size of the outer padding");
3274 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3280 if (innerDims.test(en.index()))
3281 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3282 if (outerDims.test(en.index()))
3283 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3288 auto newSliceOp = rewriter.
create<ExtractSliceOp>(
3289 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3290 innerSliceOp.getMixedStrides());
3291 auto newPadOp = rewriter.
create<PadOp>(
3292 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3293 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3296 newPadOp.getRegion().begin());
3297 rewriter.
replaceOp(padOp, newPadOp.getResult());
3305 LogicalResult matchAndRewrite(PadOp padTensorOp,
3307 Value input = padTensorOp.getSource();
3308 if (!llvm::isa<RankedTensorType>(input.
getType()))
3310 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
3311 auto inputRank = inputDims.size();
3313 auto oldResultType =
3314 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3318 auto outputDims = oldResultType.getShape();
3323 for (
auto operand : padTensorOp.getLow()) {
3326 constOperandsLow.push_back(ShapedType::kDynamic);
3327 newLows.push_back(operand);
3330 constOperandsLow.push_back(intOp.getExtValue());
3334 for (
auto operand : padTensorOp.getHigh()) {
3337 constOperandsHigh.push_back(ShapedType::kDynamic);
3338 newHighs.push_back(operand);
3341 constOperandsHigh.push_back(intOp.getExtValue());
3348 if (inputDims.size() != outputDims.size() ||
3349 inputDims.size() != constLow.size() ||
3350 inputDims.size() != constHigh.size())
3355 for (
size_t i = 0; i < inputRank; i++) {
3356 if (constLow[i] == ShapedType::kDynamic)
3357 constLow[i] = constOperandsLow[lowCount++];
3358 if (constHigh[i] == ShapedType::kDynamic)
3359 constHigh[i] = constOperandsHigh[highCount++];
3367 for (
size_t i = 0; i < inputRank; i++) {
3368 if (outputDims[i] == ShapedType::kDynamic) {
3369 newOutDims.push_back(
3370 (staticLow[i] == ShapedType::kDynamic ||
3371 staticHigh[i] == ShapedType::kDynamic ||
3372 inputDims[i] == ShapedType::kDynamic
3373 ? ShapedType::kDynamic
3374 : inputDims[i] + staticLow[i] + staticHigh[i]));
3376 newOutDims.push_back(outputDims[i]);
3381 llvm::all_of(newOutDims,
3382 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
3387 newOutDims, padTensorOp.getType().getElementType());
3388 auto newOp = rewriter.
create<PadOp>(
3389 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3390 newLows, newHighs, padTensorOp.getNofold(),
3394 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3406 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3407 FoldOrthogonalPaddings, FoldStaticPadding>(context);
3419 Value PadOp::getConstantPaddingValue() {
3420 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3423 Value padValue = yieldOp.getValue();
3435 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3445 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3446 ParallelCombiningOpInterface parallelCombiningParent =
3447 getParallelCombiningParent();
3448 for (
const auto &it :
3451 if (&nextOp == getOperation())
3452 return parallelCombiningParent.getParentResult(it.index());
3454 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3470 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3483 build(b, result, source, dest, offsets, sizes, strides, attrs);
3497 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3501 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3502 return this->
emitError(
"expected ParallelCombiningOpInterface parent, got:")
3503 << *(getOperation()->getParentOp());
3505 RankedTensorType expectedType;
3508 getStaticSizes(), getStaticStrides(), &expectedType);
3512 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3514 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3515 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3516 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3527 void ScatterOp::getAsmResultNames(
3529 setNameFn(getResult(),
"scatter");
3533 int64_t destRank = getDestType().getRank();
3536 "scatter",
"dest")))
3540 return emitOpError(
"requires 'unique' attribute to be set");
3547 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3548 getDestType(), getIndicesType(), scatterDims,
false);
3549 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3550 getDestType(), getIndicesType(), scatterDims,
true);
3551 if (getSourceType() != expectedSourceType &&
3552 getSourceType() != expectedRankReducedSourceType) {
3553 return emitOpError(
"source type "
3556 << expectedSourceType <<
" or its rank-reduced variant "
3557 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
3570 build(builder, result, aggregateType, element, dynamicSizes);
3576 build(builder, result, aggregateType, element, dynamicSizes);
3584 build(builder, result, element, staticShape, dynamicSizes);
3587 void SplatOp::getAsmResultNames(
3589 setNameFn(getResult(),
"splat");
3593 if (
getType().getNumDynamicDims() !=
3595 return emitOpError(
"incorrect number of dynamic sizes, has ")
3597 <<
getType().getNumDynamicDims();
3606 for (int64_t i = 0; i <
getType().getRank(); ++i) {
3607 if (
getType().isDynamicDim(i)) {
3617 auto constOperand = adaptor.getInput();
3618 if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
3622 if (!
getType().hasStaticShape())
3634 template <
typename OpTy>
3635 static LogicalResult
3638 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3639 "applies to only pack or unpack operations");
3640 int64_t destRank = op.getDestRank();
3642 reifiedReturnShapes[0] =
3647 template <
typename OpTy>
3649 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3650 "applies to only pack or unpack operations");
3654 assert(tiles.size() == dimsToTile.size() &&
3655 "tiles must match indices of dimension to block");
3657 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3658 dimAndTileMapping[dimsToTile[i]] = tiles[i];
3659 return dimAndTileMapping;
3662 template <
typename OpTy>
3664 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3665 "applies to only pack or unpack operations");
3668 unsigned dynamicValIndex = 0;
3669 for (int64_t staticTile : op.getStaticInnerTiles()) {
3670 if (!ShapedType::isDynamic(staticTile))
3673 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3675 return mixedInnerTiles;
3678 template <
typename OpTy>
3680 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3681 "applies to only pack or unpack operations");
3694 size_t dimsPosSize = dimsPos.size();
3695 if (dimsPosSize > rank)
3698 for (int64_t dim : dimsPos)
3699 uniqued.insert(dim);
3700 if (dimsPosSize != uniqued.size())
3702 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3703 return dimPos < 0 || dimPos >=
static_cast<int64_t
>(rank);
3712 sourceShape.size() == limitShape.size() &&
3713 "expected source shape rank, and limit of the shape to have same rank");
3714 return llvm::all_of(
3715 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3716 int64_t sourceExtent = std::get<0>(it);
3717 int64_t limit = std::get<1>(it);
3718 return ShapedType::isDynamic(sourceExtent) ||
3719 ShapedType::isDynamic(limit) || sourceExtent <= limit;
3723 template <
typename OpTy>
3725 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3726 "applies to only pack or unpack operations");
3727 Operation *op = packOrUnPack.getOperation();
3731 return llvm::any_of(
3737 if (hasZeros(mixedTiles))
3738 return op->
emitError(
"invalid zero tile factor");
3741 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3742 ? packOrUnPack.getSourceType()
3743 : packOrUnPack.getDestType();
3744 size_t unpackedRank = unpackedType.getRank();
3748 return op->
emitError(
"invalid inner_dims_pos vector");
3750 return op->
emitError(
"invalid outer_dims_perm vector");
3751 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3752 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
3756 if (mixedTiles.size() > unpackedRank) {
3757 return op->
emitError(
"tiling factors must be less than or equal to the "
3758 "input rank for pack or output rank for unpack");
3760 if (mixedTiles.size() != innerDimsPos.size()) {
3762 "tiling factors must equal the number of dimensions to tile");
3765 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3766 ? packOrUnPack.getDestType()
3767 : packOrUnPack.getSourceType();
3768 size_t packedRank = packedType.getRank();
3770 if (unpackedRank + mixedTiles.size() != packedRank) {
3772 "packed rank must equal unpacked rank + tiling factors");
3778 RankedTensorType expectedPackedType = PackOp::inferPackedType(
3779 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3780 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3781 return op->
emitError(
"the shape of output is not large enough to hold the "
3782 "packed data. Expected at least ")
3783 << expectedPackedType <<
", got " << packedType;
3786 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3788 [](std::tuple<int64_t, OpFoldResult> it) {
3789 std::optional<int64_t> constTileSize =
3790 getConstantIntValue(std::get<1>(it));
3791 int64_t shape = std::get<0>(it);
3792 if (!constTileSize) {
3795 return ShapedType::isDynamic(shape);
3797 if (ShapedType::isDynamic(shape)) {
3804 return shape == constTileSize.value();
3806 return op->
emitError(
"mismatch in inner tile sizes specified and shaped of "
3807 "tiled dimension in the packed type");
3819 struct PackOrUnPackTransposeResult {
3826 template <
typename OpTy>
3827 static PackOrUnPackTransposeResult
3831 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3832 "applies to only pack or unpack operations");
3833 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3834 "some permutation must be non-empty");
3835 PackOrUnPackTransposeResult metadata;
3836 metadata.innerDimsPos =
3838 metadata.innerTiles =
3840 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3841 ? packOrUnPackOp.getSourceRank()
3842 : packOrUnPackOp.getDestRank();
3843 metadata.outerDimsPerm =
3844 packOrUnPackOp.getOuterDimsPerm().empty()
3845 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3847 if (!innerPermutation.empty()) {
3848 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3850 "invalid inner permutation");
3854 if (!outerPermutation.empty()) {
3855 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3857 "invalid outer permutation");
3867 void PackOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
3868 setNameFn(getResult(),
"pack");
3874 std::optional<Value> paddingValue,
3876 assert(innerDimsPos.size() == innerTiles.size() &&
3877 "number of tile sizes specified must match the specified number of "
3878 "original dimensions to be tiled");
3882 build(builder, state, dest.
getType(), source, dest,
3883 paddingValue ? *paddingValue :
nullptr,
3884 outerDimsPerm.empty() ?
nullptr
3914 outputShape.take_front(inputShape.size()));
3915 if (!outerDimsPerm.empty()) {
3916 assert(outerDimsPerm.size() == outputTileSizes.size() &&
3917 "expected output and outer_dims_perm to have same size");
3921 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3922 if (ShapedType::isDynamic(inputShape[pos]))
3926 if (!constantTile) {
3927 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3928 (inputShape[pos] % outputTileSizes[pos] != 0))
3930 }
else if (inputShape[pos] % (*constantTile) != 0) {
3944 auto paddingValue = getPaddingValue();
3947 return emitOpError(
"expected padding_value has ")
3948 << getSourceType().getElementType()
3949 <<
" but got: " << paddingValue.getType();
3952 if (!paddingValue &&
3953 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
3954 getDestType().
getShape(), getOuterDimsPerm(),
3957 "invalid tile factor or output size provided. Only full tiles are "
3958 "supported when padding_value is not set");
3968 for (
auto o : ofrs) {
3970 if (llvm::dyn_cast_if_present<Value>(o))
3971 result.push_back(ShapedType::kDynamic);
3985 for (
auto tiledDim :
llvm::enumerate(llvm::to_vector(innerDimsPos))) {
3986 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3988 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3989 resultShape[tiledDim.value()] = ShapedType::kDynamic;
3992 resultShape[tiledDim.value()] = divideCeilSigned(
3993 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
3997 if (!outerDimsPerm.empty())
4001 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4014 for (
auto tiledDim :
llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4016 builder, loc, ceilDivExpr,
4017 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4019 if (!outerDimsPerm.empty())
4021 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4026 innerDimsPos, outerDimsPerm);
4032 for (
unsigned i = 0; i < resultDims.size(); ++i) {
4033 if (!ShapedType::isDynamic(resultTypeShape[i]))
4044 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4049 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4066 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
4067 if (ShapedType::isDynamic(value))
4068 mixedSizes.push_back(b.
create<DimOp>(loc, source, index).
getResult());
4072 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4073 int64_t dimPos = std::get<0>(it);
4075 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4077 if (!outerDimsPerm.empty())
4078 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4080 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4081 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
4082 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4089 *
this, innerPermutation, outerPermutation);
4090 Value transposedDest =
4091 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4092 metadata.innerDimsPos, metadata.outerDimsPerm);
4093 return b.
create<PackOp>(loc, getSource(), transposedDest,
4094 metadata.innerDimsPos, metadata.innerTiles,
4095 getPaddingValue(), metadata.outerDimsPerm);
4099 template <
typename OpTy>
4101 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4102 "applies to only pack or unpack operations");
4103 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4105 : op.getSourceType();
4107 for (
auto [dimDest,
tile] : llvm::zip(
4108 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4110 if (!constTileSize || ShapedType::isDynamic(dimDest))
4117 if (getPaddingValue())
4132 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4134 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4146 auto packTiles = packOp.getMixedTiles();
4147 auto unPackTiles = unPackOp.getMixedTiles();
4148 if (packTiles.size() != unPackTiles.size())
4150 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
4159 auto srcType = op.getSourceType();
4160 if (llvm::any_of(op.getInnerDimsPos(),
4161 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4163 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4165 return !PackOp::requirePaddingValue(
4166 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4167 op.getOuterDimsPerm(), op.getMixedTiles());
4174 bool changeNeeded =
false;
4175 srcShape.assign(packOp.getSourceType().getShape().begin(),
4176 packOp.getSourceType().getShape().end());
4177 destShape.assign(packOp.getDestType().getShape().begin(),
4178 packOp.getDestType().getShape().end());
4179 llvm::SmallSetVector<int64_t, 4> innerDims;
4180 innerDims.insert(packOp.getInnerDimsPos().begin(),
4181 packOp.getInnerDimsPos().end());
4183 if (!packOp.getOuterDimsPerm().empty())
4185 int srcRank = packOp.getSourceRank();
4186 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
4187 if (innerDims.contains(i))
4190 int64_t destPos = i;
4191 if (!inverseOuterDimsPerm.empty())
4192 destPos = inverseOuterDimsPerm[srcPos];
4193 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4194 ShapedType::isDynamic(destShape[destPos])) {
4197 int64_t size = srcShape[srcPos];
4198 if (ShapedType::isDynamic(size))
4199 size = destShape[destPos];
4200 srcShape[srcPos] = size;
4201 destShape[destPos] = size;
4202 changeNeeded =
true;
4204 return changeNeeded;
4207 LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
4209 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4210 if (unPackOp.getSourceType() != packOp.getDestType())
4212 if (packOp.getPaddingValue() ||
4216 rewriter.
replaceOp(packOp, unPackOp.getSource());
4223 packOp.getPaddingValueMutable().clear();
4232 Value source = packOp.getSource();
4233 if (srcShape != packOp.getSourceType().getShape()) {
4234 auto newSrcType = packOp.getSourceType().clone(srcShape);
4236 rewriter.
create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4238 Value dest = packOp.getDest();
4239 if (destShape != packOp.getDestType().getShape()) {
4240 auto newDestType = packOp.getDestType().clone(destShape);
4242 rewriter.
create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4245 loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4246 packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4248 packOp, packOp.getResult().getType(), newOp);
4255 template <
typename PackOrUnpackOp>
4257 RankedTensorType packedTensorType) {
4258 static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4259 std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4260 "Function meant for pack/unpack");
4265 int64_t numPackedDims = innerDimsPos.size();
4266 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4267 if (orderedDims != innerDimsPos) {
4273 int64_t packedRank = packedTensorType.getRank();
4283 return llvm::all_of(
4284 llvm::seq<int64_t>(0, packedRank - numPackedDims),
4285 [&packedShape](int64_t i) {
return packedShape[i] == 1; });
4288 bool PackOp::isLikePad() {
4289 auto packedTensorType =
4290 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4295 std::optional<Attribute> paddingValue;
4296 if (
auto pad = adaptor.getPaddingValue())
4299 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4300 getDestType(), paddingValue))
4301 return reshapedSource;
4309 void UnPackOp::getAsmResultNames(
4311 setNameFn(getResult(),
"unpack");
4348 assert(innerDimsPos.size() == innerTiles.size() &&
4349 "number of tile sizes specified must match the specified number of "
4350 "original dimensions to be tiled");
4354 build(builder, state, dest.
getType(), source, dest,
4355 outerDimsPerm.empty() ?
nullptr
4373 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
4375 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4376 if (srcType.isDynamicDim(i))
4379 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
4381 if (!outerDimsPerm.empty()) {
4382 applyPermutationToVector<OpFoldResult>(
4386 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4387 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4389 auto elemType = srcType.getElementType();
4390 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4394 Value transposedSource,
4398 *
this, innerPermutation, outerPermutation);
4399 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
4400 metadata.innerDimsPos, metadata.innerTiles,
4401 metadata.outerDimsPerm);
4408 bool changeNeeded =
false;
4409 srcShape.assign(op.getSourceType().getShape().begin(),
4410 op.getSourceType().getShape().end());
4411 destShape.assign(op.getDestType().getShape().begin(),
4412 op.getDestType().getShape().end());
4413 llvm::SmallSetVector<int64_t, 4> innerDims;
4414 innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4416 if (!op.getOuterDimsPerm().empty())
4418 int destRank = op.getDestRank();
4419 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
4420 if (innerDims.contains(i))
4423 int64_t destPos = i;
4424 if (!inverseOuterDimsPerm.empty())
4425 srcPos = inverseOuterDimsPerm[destPos];
4426 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4427 ShapedType::isDynamic(destShape[destPos])) {
4430 int64_t size = srcShape[srcPos];
4431 if (ShapedType::isDynamic(size))
4432 size = destShape[destPos];
4433 srcShape[srcPos] = size;
4434 destShape[destPos] = size;
4435 changeNeeded =
true;
4437 return changeNeeded;
4440 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
4443 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4444 if (packOp.getDestType() != unPackOp.getSourceType())
4446 if (packOp.getPaddingValue() ||
4450 rewriter.
replaceOp(unPackOp, packOp.getSource());
4454 if (
auto dstStyleOp =
4455 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4456 auto destValue = cast<OpResult>(unPackOp.getDest());
4457 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4459 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4467 Value source = unPackOp.getSource();
4468 if (srcShape != unPackOp.getSourceType().getShape()) {
4469 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4470 source = rewriter.
create<tensor::CastOp>(loc, newSrcType,
4471 unPackOp.getSource());
4473 Value dest = unPackOp.getDest();
4474 if (destShape != unPackOp.getDestType().getShape()) {
4475 auto newDestType = unPackOp.getDestType().clone(destShape);
4477 rewriter.
create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4480 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4481 unPackOp.getOuterDimsPerm());
4483 unPackOp, unPackOp.getResult().getType(), newOp);
4490 bool UnPackOp::isLikeUnPad() {
4491 RankedTensorType packedTensorType = getSourceType();
4497 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4499 return reshapedSource;
4531 if (isa<InsertSliceOp>(op.getOperation()))
4536 if (isa<LoopLikeOpInterface>(op.getOperation()))
4540 bool hasTensorCastOperand =
4542 if (llvm::isa<BlockArgument>(opOperand.get()))
4544 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4545 return castOp && canFoldIntoConsumerOp(castOp);
4547 if (!hasTensorCastOperand)
4554 int64_t dpsInitIdx = 0;
4558 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.
get());
4559 if (op.isDpsInit(&opOperand) &&
4560 !llvm::isa<MemRefType>(newOperands.back().getType()))
4561 newResultTypes[dpsInitIdx++] = newOperands.back().getType();
4565 Operation *newOp =
clone(rewriter, op, newResultTypes, newOperands);
4568 for (
auto [oldResult, newResult] :
4570 if (newResult.
getType() != oldResult.getType()) {
4571 replacements.push_back(rewriter.
create<tensor::CastOp>(
4572 op->
getLoc(), oldResult.getType(), newResult));
4574 replacements.push_back(newResult);
4587 void TensorDialect::getCanonicalizationPatterns(
4596 #define GET_OP_CLASSES
4597 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static int64_t getNumElements(ShapedType type)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result, std::optional< Attribute > cst=std::nullopt)
Try to remove a tensor operation if it would only reshape a constant.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(const SmallVectorImpl< OpFoldResult > &mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)