#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

// ...

// Materialize a constant: delegate to whichever dialect can build a constant
// of the requested type.
  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, value, type);
  if (complex::ConstantOp::isBuildableWith(value, type))
    return builder.create<complex::ConstantOp>(loc, type,
                                               value.cast<ArrayAttr>());
// preservesStaticInformation: true if `target` is a ranked tensor type that
// preserves all the static shape information available in `source`.
  auto sourceType = source.dyn_cast<RankedTensorType>();
  auto targetType = target.dyn_cast<RankedTensorType>();

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires the same element type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires the same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // If a dimension is static in the source but dynamic in the target, static
  // information is lost.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (!ShapedType::isDynamic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }
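// Illustrative examples (added for clarity; not from the original source):
//   preservesStaticInformation(tensor<?x?xf32>, tensor<8x?xf32>) -> true
//     (the target only adds static information)
//   preservesStaticInformation(tensor<8x?xf32>, tensor<?x?xf32>) -> false
//     (the target drops the static size 8 of dimension 0)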
// canFoldIntoConsumerOp: the cast folds away when its source carries at
// least as much static information as its result.
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
// foldTensorCast: replace any operand produced by a foldable tensor.cast
// with the cast's own operand.
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
// CastOp::areCastCompatible: a tensor.cast has exactly one input and one
// output type.
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
// joinShapes: compute a TensorType that carries the combined shape knowledge
// of the two given types; conflicting static sizes yield the null type.
  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
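// Worked example (added; not from the original source): joining
// tensor<?x4xf32> with tensor<2x?xf32> yields tensor<2x4xf32>, while joining
// tensor<2x4xf32> with tensor<3x4xf32> fails on the conflicting static sizes
// 2 and 3, so joinShapes returns the null type.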
// ChainedTensorCast: fold a chain of two tensor.cast ops into a single cast
// when doing so loses no shape knowledge.
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        tensorCastOperand.getOperand().getType().cast<TensorType>();
    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
    // ...
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
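// Illustrative fold (added; hypothetical IR using standard tensor.cast
// assembly syntax):
//
//   %1 = tensor.cast %0 : tensor<2x?xi32> to tensor<?x?xi32>
//   %2 = tensor.cast %1 : tensor<?x?xi32> to tensor<?x8xi32>
//
// The join of the end-point types (tensor<2x8xi32>) is no weaker than the
// join through the intermediate type, so the chain becomes a single cast:
//
//   %2 = tensor.cast %0 : tensor<2x?xi32> to tensor<?x8xi32>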
// TensorCastExtractSlice: fold the static sizes of a tensor.cast back into
// the extract_slice that produces its operand.
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        tensorCast.getType().getShape() == tensorCast.getSource()
                                               .getType()
                                               .cast<RankedTensorType>()
                                               .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractFromI64ArrayAttr(extractOperand.getStaticSizes()),
        extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
        extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
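// Illustrative fold (added; hypothetical IR): the static size carried by the
// cast is pushed back into the slice, e.g.
//
//   %1 = tensor.extract_slice %0[0, 0] [%s, 4] [1, 1]
//        : tensor<128x512xf32> to tensor<?x4xf32>
//   %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<16x4xf32>
//
// becomes
//
//   %2 = tensor.extract_slice %0[0, 0] [16, 4] [1, 1]
//        : tensor<128x512xf32> to tensor<16x4xf32>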
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
  build(builder, result, source, indexValue);

// Return the constant index operand of a tensor.dim, if any.
Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}
// DimOp::verify: assume an unknown index to be in range; reject a constant
// index that is provably out of range for a ranked source.
  Optional<int64_t> index = getConstantIndex();
  if (!index)
    return success();

  auto type = getSource().getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
    if (*index >= tensorType.getRank())
      return emitOpError("index is out of range");
  } else if (type.isa<UnrankedTensorType>()) {
    // Assume the index to be in range.
  } else {
    llvm_unreachable("expected operand with tensor type");
  }
// DimOp::fold.
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = getSource().getType().dyn_cast<RankedTensorType>();
  if (!tensorType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the corresponding dynamic extent of a tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        fromElements.getResult().getType().cast<RankedTensorType>();
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the generate that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
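// Illustrative folds (added; hypothetical IR): with
// %c0 = arith.constant 0 : index,
//
//   %d = tensor.dim %t, %c0 : tensor<4x?xf32>
//
// folds to the constant index 4, and if %t = tensor.generate %sz { ... } :
// tensor<?xf32>, then tensor.dim %t, %c0 folds directly to the dynamic
// extent operand %sz.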
// DimOfCastOp: fold dim of a tensor.cast into dim of the cast's source.
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
// ...
  results.add<DimOfCastOp>(context);
// ExtractOp::verify: the number of indices must match the rank when the
// tensor type is ranked.
  if (auto tensorType = getTensor().getType().dyn_cast<RankedTensorType>())
    if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
      return emitOpError("incorrect number of indices for extract_element");
// ExtractOp::fold.
  // If the tensor is a known splat constant, every element is the splat
  // value.
  if (Attribute tensor = operands.front())
    if (auto splatTensor = tensor.dyn_cast<SplatElementsAttr>())
      return splatTensor.getSplatValue<Attribute>();

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : llvm::drop_begin(operands, 1)) {
    if (!indice || !indice.isa<IntegerAttr>())
      return {};
    indices.push_back(indice.cast<IntegerAttr>().getInt());
  }

  // Fold extract(from_elements(...)) by linearizing the constant indices
  // into the row-major element list.
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never be executed.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If the operand is an elements attribute, query the value at the given
  // indices.
  if (Attribute tensor = operands.front()) {
    auto elementsAttr = tensor.dyn_cast<ElementsAttr>();
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
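// Worked example (added): for a tensor<2x3xi32> built by
// tensor.from_elements, the constant index pair (1, 2) linearizes row-major
// to 1 * 3 + 2 = 5, so the fold returns element #5 of the operand list.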
// FromElementsOp::build: infer a 1-D result type from the element list.
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
// FromElementsOp::fold: fold to a dense constant when all elements are known
// constants.
  if (!llvm::is_contained(operands, nullptr))
    return DenseElementsAttr::get(getType(), operands);
// Pushes an arith.index_cast that feeds a tensor.extract below the extract,
// so the extract operates on the original tensor.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  // ...
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = rewriter.create<tensor::ExtractOp>(
        loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);
  results.add<ExtractElementFromIndexCast>(context);
// InsertOp::verify: the number of indices must match the destination rank.
  if (auto destType = getDest().getType().dyn_cast<RankedTensorType>())
    if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
      return emitOpError("incorrect number of indices");
// InsertOp::fold: inserting a splat tensor's own splat value back into it is
// a no-op.
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
// GenerateOp::reifyResultShapes: dynamic extents come from the operands,
// static extents become constant index ops.
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] = builder.create<arith::ConstantIndexOp>(
          getLoc(), getType().getDimSize(dim));
    }
  }
// GenerateOp::verify: one index operand per dynamic extent in the result
// type.
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  if (getNumOperands() != resultTy.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
// GenerateOp::verifyRegions.
  RankedTensorType resultTy = getType().cast<RankedTensorType>();
  // Ensure that the region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

// GenerateOp::build: build the op, then create and populate its body block
// with one index argument per result dimension.
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = resultTy.cast<RankedTensorType>().getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
// StaticTensorGenerate: replace constant dynamic extents of tensor.generate
// with static extents, casting back to the original type.
    RankedTensorType resultType =
        tensorFromElements.getResult().getType().cast<RankedTensorType>();

    if (resultType.hasStaticShape())
      return failure();

    SmallVector<Value, 4> newOperands;
    SmallVector<int64_t, 4> newShape;
    auto operandsIt = tensorFromElements.getDynamicExtents().begin();

    for (int64_t dim : resultType.getShape()) {
      if (!ShapedType::isDynamic(dim)) {
        newShape.push_back(dim);
        continue;
      }
      APInt index;
      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
        newShape.push_back(ShapedType::kDynamicSize);
        newOperands.push_back(*operandsIt++);
        continue;
      }
      newShape.push_back(index.getSExtValue());
      operandsIt++;
    }

    if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
      return failure();

    auto loc = tensorFromElements.getLoc();
    auto newOp = rewriter.create<GenerateOp>(
        loc, RankedTensorType::get(newShape, resultType.getElementType()),
        newOperands);
    rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(tensorFromElements, resultType,
                                                newOp);
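// Illustrative canonicalization (added; hypothetical IR): if %c8 is a
// constant index,
//
//   %t = tensor.generate %c8 { ... } : tensor<?xindex>
//
// becomes a statically shaped generate followed by a cast to the old type:
//
//   %g = tensor.generate { ... } : tensor<8xindex>
//   %t = tensor.cast %g : tensor<8xindex> to tensor<?xindex>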
// ExtractFromTensorGenerate: fold tensor.extract(tensor.generate) by
// inlining the generator body with the extract indices substituted for the
// block arguments.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  // ...
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    BlockAndValueMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);
// ExtractFromTensorCast: extract directly from the cast's source.
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
  results.add<ExtractFromTensorGenerate, ExtractFromTensorCast,
              StaticTensorGenerate>(context);
// RankOp::fold: constant-fold the rank when the operand type is ranked.
  auto type = getOperand().getType();
  auto shapedType = type.dyn_cast<ShapedType>();
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}

// ReshapeOp::verify.
  // ...
  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      getShape().getType().cast<RankedTensorType>().getDimSize(0);
  auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
  auto operandRankedType = operandType.dyn_cast<RankedTensorType>();

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
// ExpandShapeOp / CollapseShapeOp: reassociation accessors built from the
// stored indices.
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
// ...
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
static RankedTensorType
computeTensorReshapeCollapsedType(RankedTensorType type,
                                  ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Each reassociation map collapses a contiguous band of `dim` source
  // dimensions into a single result dimension.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamicSize))
      size = ShapedType::kDynamicSize;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}
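// Worked example (added): collapsing tensor<2x3x4xf32> with reassociation
// [[0, 1], [2]] yields tensor<6x4xf32>; a dynamic size anywhere in a band
// makes the whole collapsed dimension dynamic, so tensor<2x?x4xf32> with the
// same reassociation yields tensor<?x4xf32>.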
  // CollapseShapeOp::build: infer the collapsed result type, then delegate.
  build(b, result, resultType, src, attrs);
// Compare two types while ignoring any encoding attribute on ranked tensors.
static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2;
}
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType,
                                    isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      computeTensorReshapeCollapsedType(expandedType, maps);
  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
// Fold a reshape of a splat constant by rebuilding the constant with the
// result type's shape over the same raw data.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  // ...
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
// Fold reshape(from_elements(...)) into a from_elements with the reshaped
// static type.
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  // ...
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = reshapeOp.getType().template cast<ShapedType>();

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
  results.add<// ...
              FoldReshapeWithConstant<ExpandShapeOp>,
              FoldReshapeWithFromElements<ExpandShapeOp>>(context);
// ...
  results.add<// ...
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
// ...
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
// ...
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
// Infer the result type of an extract_slice: the static sizes with the
// source element type.
RankedTensorType ExtractSliceOp::inferResultType(
    ShapedType sourceShapedTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceShapedTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes,
                               sourceShapedTensorType.getElementType());
}

RankedTensorType ExtractSliceOp::inferResultType(
    ShapedType sourceShapedTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferResultType(sourceShapedTensorType, staticOffsets,
                                         staticSizes, staticStrides);
}
// Infer the canonical rank-reduced result type: start from the fully
// inferred type, then drop as many unit dimensions as needed to reach the
// desired rank.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  auto inferredType =
      inferResultType(sourceRankedTensorType, offsets, sizes, strides)
          .cast<RankedTensorType>();
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
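// Worked example (added): extracting sizes [1, 4, 1] from tensor<8x16x4xf32>
// with desiredResultRank = 1 first infers tensor<1x4x1xf32>, then projects
// away two unit dimensions, yielding tensor<4xf32>.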
// Build an extract_slice with mixed static/dynamic entries and a custom
// result type. If the type passed is null, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  if (!resultType) {
    resultType =
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticOffsets,
                                        staticSizes, staticStrides)
            .cast<RankedTensorType>();
  }
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

// Build an extract_slice with mixed entries and an inferred result type.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

// Build an extract_slice from Value offsets/sizes/strides by wrapping them
// as OpFoldResults.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  // ...
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

// Build an extract_slice with dynamic entries and an inferred result type.
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
// Emit a diagnostic matching the verification failure kind.
template <typename OpTy>
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          OpTy op, Type expectedType) {
  auto memrefType = expectedType.cast<ShapedType>();
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op.emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op.emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op.emitError("expected element type to be ")
           << memrefType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
// ExtractSliceOp::verify: check the result type against the inferred type.
  auto expectedType = ExtractSliceOp::inferResultType(
      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
  // ...
// Return a bit vector marking the unit dimensions that a rank-reducing
// extract_slice drops from its result type.
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  ArrayRef<int64_t> resultShape = getType().getShape();
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  unsigned shapePos = 0;
  for (const auto &size : enumerate(mixedSizes)) {
    Optional<int64_t> sizeVal = getConstantIntValue(size.value());
    // If the size is not 1, or if the matched result dimension is itself a
    // static 1, the dimension is preserved.
    if (!sizeVal || *sizeVal != 1 ||
        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
      shapePos++;
      continue;
    }
    droppedDims.set(size.index());
  }
  return droppedDims;
}
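// Worked example (added; hypothetical IR): for
//
//   %s = tensor.extract_slice %t[0, 0, 0] [1, 4, 1] [1, 1, 1]
//        : tensor<8x16x4xf32> to tensor<4xf32>
//
// getDroppedDims() returns {1, 0, 1}: the unit sizes at positions 0 and 2
// have no counterpart in the rank-reduced result type.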
// ExtractSliceOp::reifyResultShapes: one entry per non-dropped size.
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  Location loc = getLoc();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    if (auto attr = size.value().dyn_cast<Attribute>()) {
      reifiedReturnShapes[0].push_back(builder.create<arith::ConstantIndexOp>(
          loc, attr.cast<IntegerAttr>().getInt()));
      continue;
    }
    reifiedReturnShapes[0].push_back(size.value().get<Value>());
  }
// Fold a tensor.cast feeding an extract_slice by re-deriving the slice's
// canonical rank-reduced result type from the cast's source.
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    // Deduce the type of the result to use for the canonicalized operation.
    RankedTensorType resultType =
        ExtractSliceOp::inferCanonicalRankReducedResultType(
            sliceOp.getType().getRank(), sliceOp.getSourceType(),
            sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
            sliceOp.getMixedStrides());
    Value newSlice = rewriter.create<ExtractSliceOp>(
        sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
        sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
                                                newSlice);
    return success();
  }
};
// Recursively copy the elements selected by (offsets, sizes, strides) from a
// flat, row-major element iterator into outValues. `counts` holds, for each
// dimension, the number of flat elements spanned by one step along it.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    // Innermost dimension: copy the selected elements directly.
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
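// Worked example (added): for a 4x4 source, counts is {4, 1}; slicing with
// offsets = {1, 0}, sizes = {2, 2}, strides = {1, 2} visits rows 1 and 2 and
// copies columns 0 and 2 of each, appending the elements at flat positions
// 4, 6, 8, and 10 to outValues.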
// Fold extract_slice(constant) by materializing the sliced values as a new
// constant, subject to a user-provided control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    // Dynamic shapes are not supported.
    auto sourceType = op.getSource().getType().cast<ShapedType>();
    auto resultType = op.getResult().getType().cast<ShapedType>();
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // ...
    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Dynamic offsets/sizes/strides are not supported.
    auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets());
    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
      return failure();
    auto sizes = extractFromI64ArrayAttr(op.getStaticSizes());
    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
      return failure();
    auto strides = extractFromI64ArrayAttr(op.getStaticStrides());
    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
      return failure();

    // Compute the flat element count spanned by each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // Slice the underlying attribute data.
    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      // ...
    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      // ...
    }
    // ...
  }
  // ...
};
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
// Return the canonical type of the result of an extract_slice op.
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
        mixedStrides);

// A canonicalizer wrapper to replace ExtractSliceOps: cast the new result
// back to the original type if it differs.
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                    replacement);
    rewriter.replaceOp(op, replacement);
  }

// ...
  results.add<// ...
              ExtractSliceOpCastFolder>(context);
// Succeed only when the op is a no-op slice: zero offsets, sizes equal to
// the full shape, and unit strides.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  // ...
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  // ...
}
// If an extract_slice consumes an insert_slice over the very same slice,
// return the insert_slice's source directly.
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

  return {};
}
// ExtractSliceOp::fold.
  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
    auto resultType = getResult().getType().cast<ShapedType>();
    if (resultType.hasStaticShape())
      return splat.resizeSplat(resultType);
  }
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
// createCanonicalRankReducingExtractSliceOp: extract the whole tensor at
// offsets 0 and strides 1, letting the target type perform rank reduction.
  auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
  unsigned rank = rankedTensorType.getRank();
  auto shape = rankedTensorType.getShape();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  for (unsigned i = 0, e = rank; i < e; ++i) {
    OpFoldResult dim;
    if (rankedTensorType.isDynamicDim(i))
      dim = b.createOrFold<tensor::DimOp>(loc, tensor, i);
    else
      dim = b.getIndexAttr(shape[i]);
    sizes.push_back(dim);
  }
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}
// InsertSliceOp::build with mixed static/dynamic entries.
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
                             ShapedType::kDynamicStrideOrOffset);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
                             ShapedType::kDynamicStrideOrOffset);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getI64ArrayAttr(staticOffsets),
        b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));

// InsertSliceOp::build with Value entries, wrapped as OpFoldResults.
  // ...
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
// insert_slice is the inverse of extract_slice: verify the source type
// against the type inferred from the destination and the static triples.
static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
                    ArrayAttr staticOffsets, ArrayAttr staticSizes,
                    ArrayAttr staticStrides,
                    ShapedType *expectedType = nullptr) {
  auto expected = ExtractSliceOp::inferResultType(
                      dstType, extractFromI64ArrayAttr(staticOffsets),
                      extractFromI64ArrayAttr(staticSizes),
                      extractFromI64ArrayAttr(staticStrides))
                      .cast<ShapedType>();
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}

// InsertSliceOp::verify.
  ShapedType expectedType;
  auto result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
// If two consecutive insert_slice ops write the exact same slice, the second
// one can write directly into the first one's destination.
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());
  return success();
}

// InsertSliceOp::fold: a full-tensor identity insert is a no-op.
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
// InsertSliceOp::reifyResultShapes: the result shape is the destination's.
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    reifiedReturnShapes[0][dim] =
        builder.createOrFold<tensor::DimOp>(getLoc(), getDest(), dim);
  }
// Pattern to rewrite an insert_slice op with constant offset/size/stride
// arguments folded into the static attributes.
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertSliceOp> {
public:
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // ...

    // Create the new op in canonical form, casting the source if its type
    // no longer matches the deduced rank-reduced type.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType())
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                 sourceType, toInsert);
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
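// Illustrative canonicalization (added; hypothetical IR): with
// %c4 = arith.constant 4 : index,
//
//   %r = tensor.insert_slice %src into %dst[0, %c4] [2, 2] [1, 1]
//        : tensor<2x2xf32> into tensor<8x8xf32>
//
// is rewritten so the offset becomes part of the static operand list:
//
//   %r = tensor.insert_slice %src into %dst[0, 4] [2, 2] [1, 1]
//        : tensor<2x2xf32> into tensor<8x8xf32>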
// Fold tensor.cast ops feeding an insert_slice's source or destination.
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> Optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return llvm::None;
      return castOp.getSource();
    };
    Optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    Optional<Value> destCastSource = getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = src.getType().cast<ShapedType>();
    auto dstType = dst.getType().cast<ShapedType>();
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            insertSliceOp.getStaticSizes(),
                            insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();

    Value replacement = rewriter.create<InsertSliceOp>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());

    if (replacement.getType() != insertSliceOp.getType()) {
      replacement = rewriter.create<tensor::CastOp>(
          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
    }
    rewriter.replaceOp(insertSliceOp, replacement);
    return success();
  }
};
// If additional static information can be deduced from an insert_slice's
// size operands, insert an explicit cast of the source operand that
// materializes the deduced type.
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                     srcType.getShape().end());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (Optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
        newSrcShape[i] = *constInt;
    }

    RankedTensorType newSrcType =
        RankedTensorType::get(newSrcShape, srcType.getElementType());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is different from, more static than, and cast-compatible
    // with srcType: insert the cast.
    Value cast = rewriter.create<tensor::CastOp>(
        insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertSliceOp>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
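// Illustrative canonicalization (added; hypothetical IR): the static sizes
// imply a static source shape, so
//
//   %r = tensor.insert_slice %src into %dst[0, 0] [4, 4] [1, 1]
//        : tensor<?x?xf32> into tensor<8x8xf32>
//
// gains an explicit cast that materializes the deduced source type:
//
//   %c = tensor.cast %src : tensor<?x?xf32> to tensor<4x4xf32>
//   %r = tensor.insert_slice %c into %dst[0, 0] [4, 4] [1, 1]
//        : tensor<4x4xf32> into tensor<8x8xf32>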
  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
              InsertSliceOpSourceCastInserter>(context);
// createCanonicalRankReducingInsertSliceOp: insert over the full destination
// at offsets 0 and strides 1.
  auto rankedTensorType = dest.getType().cast<RankedTensorType>();
  unsigned rank = rankedTensorType.getRank();
  auto shape = rankedTensorType.getShape();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  for (unsigned i = 0, e = rank; i < e; ++i) {
    OpFoldResult dim;
    if (rankedTensorType.isDynamicDim(i))
      dim = b.createOrFold<tensor::DimOp>(loc, dest, i);
    else
      dim = b.getIndexAttr(shape[i]);
    sizes.push_back(dim);
  }
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
// Custom assembly hooks: the inferred type is a copy of another type, so
// nothing is printed and parsing simply copies the type over.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
                    Type typeToInfer, Type typeToInferFrom) {}

ParseResult
parseInferType(OpAsmParser &parser,
               Optional<OpAsmParser::UnresolvedOperand> optOperand,
               Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
// PadOp::verify: each result dimension must either match the inferred size
// or correspond to a dynamic inferred size.
  auto sourceType = getSource().getType().cast<RankedTensorType>();
  auto resultType = getResult().getType().cast<RankedTensorType>();
  auto expectedType = PadOp::inferResultType(
      sourceType, extractFromI64ArrayAttr(getStaticLow()),
      extractFromI64ArrayAttr(getStaticHigh()));
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

// PadOp::verifyRegions.
  auto &region = getRegion();
  unsigned rank = getResult().getType().cast<RankedTensorType>().getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // All block arguments must be index.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      getType().cast<ShapedType>().getElementType())
    return emitOpError("expected yield type to match shape element type");
// Infer the result type of a pad: a dimension is static only if the source
// size and both padding amounts along it are static.
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  assert(staticLow.size() == rank && "unexpected staticLow size mismatch");
  assert(staticHigh.size() == rank && "unexpected staticHigh size mismatch");
  assert((resultShape.empty() || resultShape.size() == rank) &&
         "unexpected resultShape size mismatch");

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) ||
        staticLow[i] == ShapedType::kDynamicSize ||
        staticHigh[i] == ShapedType::kDynamicSize) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamicSize
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamicSize) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
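// Worked example (added): padding tensor<4x?xf32> with staticLow = {0, 1}
// and staticHigh = {2, 1} infers tensor<6x?xf32>: dimension 0 becomes
// 4 + 0 + 2 = 6, while dimension 1 stays dynamic because the source size is
// unknown.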
// PadOp::build with static low/high padding amounts.
  auto sourceType = source.getType().cast<RankedTensorType>();
  auto resultType = inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
        b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr());

// PadOp::build with all-dynamic padding amounts.
  auto sourceType = source.getType().cast<RankedTensorType>();
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamicSize);
  build(b, result, source, staticVector, staticVector, low, high, nofold,
        attrs);

// PadOp::build from OpFoldResult padding amounts.
  assert(resultType.isa<RankedTensorType>());
  auto sourceType = source.getType().cast<RankedTensorType>();
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
                             ShapedType::kDynamicSize);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
                             ShapedType::kDynamicSize);
  // ...
  resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
// Return the dimensions padded by a nonzero (or dynamic) low or high
// padding amount.
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
// FoldStaticZeroPadding: a pad with all-zero static padding is just a cast.
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
// FoldSourceTensorCast: fold a tensor.cast feeding the pad's source.
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!castOp || !tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        castOp.getSource().getType().cast<RankedTensorType>(),
        extractFromI64ArrayAttr(padTensorOp.getStaticLow()),
        extractFromI64ArrayAttr(padTensorOp.getStaticHigh()),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.updateRootInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<PadOp>(
          padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
          padTensorOp.getLow(), padTensorOp.getHigh(),
          padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
          padTensorOp.getNofold());
      BlockAndValueMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
// FoldTargetTensorCast: absorb a cast of the pad's result when the cast
// only adds static information.
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = rewriter.create<PadOp>(
        padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getLow(), padTensorOp.getHigh(),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getNofold());
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
// FoldOrthogonalPaddings: fold a pad(extract_slice(pad(extract_slice(x))))
// chain in which the two pads affect disjoint dimensions into a single
// extract_slice + pad pair.
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the padding values differ.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two ExtractSliceOps: along every
    // dimension, one of the two must have a zero offset and no padding.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes: on outer-padded dimensions the inner slice must
    // cover the full padded source.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(!ShapedType::isDynamic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // 8) Combine the high paddings of the two PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // 9) Create a new ExtractSliceOp / PadOp pair that performs both padding
    // operations in one step.
    auto newSliceOp = rewriter.create<ExtractSliceOp>(
        padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
        innerSliceOp.getMixedStrides());
    auto newPadOp = rewriter.create<PadOp>(
        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
        padOp.getMixedLowPad(), newHighPad, padOp.getNofold());
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
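// Illustrative fold (added; hypothetical IR, shapes elided): two pad /
// extract_slice pairs that pad disjoint dimensions, e.g.
//
//   %p0 = tensor.pad %0 low[0, 0] high[0, 3] { ... }   // pads dim 1
//   %e  = tensor.extract_slice %p0[%o, 0] [8, 16] [1, 1] ...
//   %p1 = tensor.pad %e low[0, 0] high[5, 0] { ... }   // pads dim 0
//
// collapse into one extract_slice on %0 followed by a single pad whose high
// padding combines both, high[5, 3].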
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings>(context);
// Return the pad's constant padding value: either an actual constant or a
// value defined outside of the pad body; otherwise return null.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  return padValue;
}
// PadOp::fold: a pad that provably changes nothing folds to its source.
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
// SplatOp::fold: a splat of a constant scalar folds to a splat attribute.
  auto constOperand = operands.front();
  if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
    return {};
  return SplatElementsAttr::get(getType(), {constOperand});
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"