37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/StringSet.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/FormatVariadic.h"
42 #include "llvm/Support/MathExtras.h"
43 #include "llvm/Support/raw_ostream.h"
52 auto type = cast<ShapedType>(v.
getType());
53 if (!type.isDynamicDim(dim))
58 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
59 return builder.create<tensor::DimOp>(loc, v, dim);
61 .Case<MemRefType>([&](MemRefType t) ->
Value {
62 return builder.create<memref::DimOp>(loc, v, dim);
73 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
74 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
77 .Case<MemRefType>([&](MemRefType type) ->
Value {
78 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
81 .Default([&](
Type t) {
return nullptr; });
90 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
92 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
94 llvm_unreachable(
"Expected MemRefType or TensorType");
99 auto shapedType = llvm::cast<ShapedType>(source.
getType());
100 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
122 assert(llvm::all_of(outputTypes,
123 [](
Type t) {
return llvm::isa<ShapedType>(t); }));
127 for (
auto containers : {inputTypes, outputTypes}) {
128 for (
auto t : containers) {
140 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
144 regionBuilder(b, *body, attrs);
156 std::optional<TypeRange> resultTensorTypes,
163 if (!resultTensorTypes)
164 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
165 [](
Type type) { return llvm::isa<RankedTensorType>(type); });
167 state.addOperands(inputs);
168 state.addOperands(outputs);
169 state.addTypes(derivedResultTypes);
170 state.addAttributes(attributes);
172 "operandSegmentSizes",
174 static_cast<int32_t>(outputs.size())}));
177 Region ®ion = *state.addRegion();
179 state.attributes.getAttrs(), regionBuilder);
188 bool addOperandSegmentSizes =
true) {
189 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
218 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
220 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
224 if (addOperandSegmentSizes) {
233 attrs.
append(
"operandSegmentSizes",
235 {static_cast<int32_t>(inputsOperands.size()),
236 static_cast<int32_t>(outputsOperands.size())}));
241 {static_cast<int32_t>(inputsOperands.size()),
242 static_cast<int32_t>(outputsOperands.size())}));
246 std::optional<RegisteredOperationName> info =
250 return parser.emitError(attrsLoc)
251 <<
"'" << result.name.getStringRef() <<
"' op ";
262 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
263 if (!outputs.empty())
264 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
275 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
278 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
279 "region expects {0} args, got {1}",
280 numRegionArgs, inputTypes.size() + outputTypes.size()));
299 unsigned numRegionArgs,
311 result.
addTypes(outputTensorsTypes);
313 std::unique_ptr<Region> region = std::make_unique<Region>();
325 if (resultTypes.empty())
334 {
"operandSegmentSizes",
337 "linalg.memoized_indexing_maps"});
374 class RegionBuilderHelper {
377 : context(context), block(block) {}
380 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
381 if (!isFloatingPoint(arg))
382 llvm_unreachable(
"unsupported non numeric type");
394 return builder.
create<math::FloorOp>(arg.
getLoc(), arg);
396 return builder.
create<arith::NegFOp>(arg.
getLoc(), arg);
398 llvm_unreachable(
"unsupported unary function");
403 bool allComplex = isComplex(arg0) && isComplex(arg1);
404 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
405 bool allInteger = isInteger(arg0) && isInteger(arg1);
408 if (!allComplex && !allFloatingPoint && !allInteger)
409 llvm_unreachable(
"unsupported non numeric type");
414 return builder.
create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
415 if (allFloatingPoint)
416 return builder.
create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
418 return builder.
create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
419 return builder.
create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
422 return builder.
create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
423 if (allFloatingPoint)
424 return builder.
create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
426 llvm_unreachable(
"unsupported operation: sub with bools");
427 return builder.
create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
430 return builder.
create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
431 if (allFloatingPoint)
432 return builder.
create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
434 return builder.
create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
435 return builder.
create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
438 return builder.
create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
439 if (allFloatingPoint)
440 return builder.
create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
442 llvm_unreachable(
"unsupported operation: div with bools");
443 return builder.
create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
444 case BinaryFn::div_unsigned:
445 if (!allInteger || allBool)
446 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
447 return builder.
create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
448 case BinaryFn::max_signed:
450 if (allFloatingPoint)
451 return builder.
create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
452 return builder.
create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
453 case BinaryFn::min_signed:
455 if (allFloatingPoint)
456 return builder.
create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
457 return builder.
create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
458 case BinaryFn::max_unsigned:
460 if (allFloatingPoint)
461 return builder.
create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
462 return builder.
create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
463 case BinaryFn::min_unsigned:
465 if (allFloatingPoint)
466 return builder.
create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
467 return builder.
create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
469 llvm_unreachable(
"unsupported binary function");
475 case TypeFn::cast_signed:
476 return cast(toType, operand,
false);
477 case TypeFn::cast_unsigned:
478 return cast(toType, operand,
true);
480 llvm_unreachable(
"unsupported type conversion function");
486 builder.
create<YieldOp>(loc, values);
489 Value constant(
const std::string &value) {
493 return builder.
create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
496 Value index(int64_t dim) {
501 Type getIntegerType(
unsigned width) {
515 auto loc = operand.
getLoc();
519 bool isComplex(
Value value) {
520 return llvm::isa<ComplexType>(value.
getType());
522 bool isFloatingPoint(
Value value) {
523 return llvm::isa<FloatType>(value.
getType());
525 bool isInteger(
Value value) {
526 return llvm::isa<IntegerType>(value.
getType());
551 if (copyOp.getInputs() != copyOp.getOutputs())
553 if (copyOp.hasBufferSemantics())
556 rewriter.
replaceOp(copyOp, copyOp.getInputs());
566 results.
add<EraseSelfCopy>(context);
579 template <
typename TensorReshapeOp>
584 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
589 auto newInit = rewriter.
create<TensorReshapeOp>(
590 loc, reshapeOp.getResultType(), oldFill.output(),
591 reshapeOp.getReassociation());
606 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
612 Value padValue = padOp.getConstantPaddingValue();
613 if (!padValue || fillOp.value() != padValue)
619 padOp,
"failed to reify tensor.pad op result shape");
621 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
622 padOp.getLoc(), reifiedShape.front(),
623 padOp.getResultType().getElementType());
629 if (replacement.getType() != padOp.getResultType()) {
630 replacement = rewriter.
create<tensor::CastOp>(
631 fillOp.getLoc(), padOp.getResultType(), replacement);
641 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
644 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
646 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
650 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
655 Value firstDest = insertOp.getDest();
656 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
657 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
662 bool disjoint =
false;
663 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
666 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
667 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
668 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
672 int64_t prevStart = prevOp.getStaticOffset(i);
673 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
674 prevOp.getStaticStride(i);
675 int64_t nextStart = insertOp.getStaticOffset(i);
676 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
677 insertOp.getStaticStride(i);
678 if (prevEnd < nextStart || nextEnd < prevStart) {
686 firstDest = prevOp.getDest();
697 Value padValue = srcPadOp.getConstantPaddingValue();
698 if (!padValue || dstFillOp.value() != padValue)
714 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
716 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
720 for (
int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
722 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
727 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
728 newSizes, insertOp.getMixedStrides());
734 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
742 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
747 Value extractedScalar = fillOp.getInputs()[0];
750 rewriter.
replaceOp(extractOp, extractedScalar);
759 tensor::PackOp packOp) {
760 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
764 if (
auto paddingValue = packOp.getPaddingValue())
768 Value packOpDest = packOp.getDest();
772 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
784 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
787 rewriter.
replaceOp(packOp, fillOp.value().result());
798 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
801 copyOp.getOutputs());
804 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
806 fillOp.getOutputs());
818 .
add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
819 FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
820 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
821 FoldInsertPadIntoFill>(context);
834 for (
ValueRange container : {inputs, outputs}) {
835 for (
Value v : container) {
836 Type t = v.getType();
837 blockArgTypes.push_back(
839 blockArgLocs.push_back(v.getLoc());
845 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
849 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
851 for (
Value v : getRegionInputArgs())
853 for (
Value v : getRegionOutputArgs())
857 void GenericOp::build(
860 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
863 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
864 iteratorTypes, doc, libraryCall);
868 inputs, outputs, bodyBuild);
871 void GenericOp::build(
875 StringRef libraryCall,
878 build(builder, result, resultTensorTypes, inputs, outputs,
883 return IteratorTypeAttr::get(builder.getContext(), iter);
886 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
887 bodyBuild, attributes);
890 void GenericOp::build(
894 StringRef libraryCall,
897 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
898 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
901 void GenericOp::build(
907 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
909 "", bodyBuild, attributes);
912 void GenericOp::build(
918 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
921 "", bodyBuild, attributes);
928 auto genericAttrNames = linalgTraitAttrNames();
931 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
933 for (
auto attr : (*this)->getAttrs()) {
934 if (attr.getName() == getIteratorTypesAttrName()) {
936 llvm::cast<ArrayAttr>(attr.getValue())
937 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
943 llvm::to_vector(llvm::map_range(
944 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
948 genericAttrs.emplace_back(
949 getIteratorTypesAttrName(),
951 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
952 genericAttrs.push_back(attr);
955 if (!genericAttrs.empty()) {
957 p << genericDictAttr;
963 genericAttrNames.push_back(
"operandSegmentSizes");
964 genericAttrNamesSet.insert(genericAttrNames.back());
966 bool hasExtraAttrs =
false;
968 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
978 if (!getRegion().empty()) {
988 DictionaryAttr dictAttr;
997 dictAttr.getValue().end());
1003 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1005 if (!iteratorTypes) {
1006 return parser.
emitError(attributeLocation)
1007 <<
"expected " << getIteratorTypesAttrName(result.
name)
1008 <<
" array attribute";
1013 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1014 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1015 if (!maybeIteratorType.has_value())
1017 <<
"unexpected iterator_type (" << s <<
")";
1019 iteratorTypeAttrs.push_back(
1036 std::unique_ptr<Region> region = std::make_unique<Region>();
1048 result.
addTypes(outputTensorsTypes);
1058 for (
auto operand : inputOperands) {
1059 if (!llvm::isa<MemRefType>(operand.
getType()))
1064 for (
auto operand : outputOperands) {
1065 if (!llvm::isa<MemRefType>(operand.
getType()))
1074 void GenericOp::getEffects(
1096 if (llvm::any_of(genericOp.getIndexingMapsArray(),
1097 [](
AffineMap map) { return !map.isIdentity(); }))
1103 if (!llvm::hasSingleElement(body))
1105 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1110 if (genericOp.hasBufferSemantics()) {
1111 if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
1112 genericOp.getDpsInputOperand(0)->get() ==
1113 genericOp.getDpsInitOperand(0)->get()) {
1121 if (!genericOp.hasTensorSemantics())
1128 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1129 if (!yieldArg || yieldArg.getOwner() != &body)
1131 unsigned argumentNumber = yieldArg.getArgNumber();
1132 Value returnedArg = genericOp->getOperand(argumentNumber);
1133 Type resultType = genericOp->getResult(yieldVal.index()).getType();
1137 if (returnType != resultType) {
1142 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1143 genericOp.getLoc(), resultType, returnedArg);
1145 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1148 returnedArg = rewriter.
create<tensor::CastOp>(
1149 genericOp.getLoc(), resultType, returnedArg);
1152 returnedArgs.push_back(returnedArg);
1155 if (returnedArgs.size() != genericOp->getNumResults())
1157 rewriter.
replaceOp(genericOp, returnedArgs);
1166 results.
add<EraseIdentityGenericOp>(context);
1188 for (
Type outputType : outputTypes) {
1189 if (llvm::isa<RankedTensorType>(outputType))
1203 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1205 for (
Value v : getRegionInputArgs())
1210 if (!getResults().empty())
1211 setNameFn(getResults().front(),
"mapped");
1218 build(builder, result,
TypeRange{}, inputs, init);
1223 if (llvm::isa<RankedTensorType>(initType))
1228 inputs, {}, bodyBuild);
1235 bool initFirst =
false) {
1241 for (
auto &operand : operands) {
1243 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1250 payloadOpOperands.push_back(block.
getArguments().back());
1251 for (
const auto &arg : block.
getArguments().drop_back())
1252 payloadOpOperands.push_back(arg);
1261 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1268 std::optional<OperationName> payloadOpName;
1272 if (
failed(operationName))
1276 payloadOpName = operationName.value();
1284 if (payloadOpName.has_value()) {
1318 for (
const auto &[operand, bbArg] :
1320 if (bbArg != operand)
1324 for (
const auto &[operand, bbArg] :
1326 if (bbArg != operand)
1335 std::string attrToElide;
1337 for (
const auto &attr : payloadOp->
getAttrs()) {
1339 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1340 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1341 attrToElide = attr.getName().str();
1342 elidedAttrs.push_back(attrToElide);
1351 Block *mapper = getBody();
1366 [&](
auto arg) { p.printRegionArgument(arg); });
1375 auto *bodyBlock = getBody();
1376 auto blockArgs = bodyBlock->getArguments();
1379 if (getInputs().size() != blockArgs.size())
1380 return emitOpError() <<
"expects number of operands to match the arity of "
1382 << getInputs().size() <<
" and " << blockArgs.size();
1385 for (
const auto &[bbArgType, inputArg] :
1386 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1387 auto inputElemType =
1388 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1389 if (bbArgType != inputElemType) {
1390 return emitOpError() <<
"expected element type of input " << inputElemType
1391 <<
" to match bbArg type " << bbArgType;
1396 auto outputShape = getInit().getType().getShape();
1398 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1399 if (inputElemShape != outputShape) {
1400 return emitOpError() <<
"expected shape of input (" << inputElemShape
1401 <<
") to match shape of output (" << outputShape
1410 int64_t rank = getInit().getType().getRank();
1414 ArrayAttr MapOp::getIndexingMaps() {
1416 int64_t rank = getInit().getType().getRank();
1417 int64_t numIndexingMaps = getOperands().size();
1422 void MapOp::getEffects(
1433 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1435 for (
Value v : getRegionInputArgs())
1437 for (
Value v : getRegionOutputArgs())
1438 setNameFn(v,
"init");
1441 void ReduceOp::getAsmResultNames(
1443 if (!getResults().empty())
1444 setNameFn(getResults().front(),
"reduced");
1447 void ReduceOp::build(
1452 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1456 for (
Value init : inits) {
1458 if (llvm::isa<RankedTensorType>(initType))
1464 inputs, inits, bodyBuild);
1469 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1471 utils::IteratorType::parallel);
1472 for (int64_t reductionDim : getDimensions())
1473 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1474 return iteratorTypes;
1477 ArrayAttr ReduceOp::getIndexingMaps() {
1479 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1486 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1487 affineMaps.push_back(resultMap);
1491 void ReduceOp::getEffects(
1500 StringRef attributeName) {
1509 std::optional<OperationName> payloadOpName;
1513 if (
failed(operationName))
1517 payloadOpName = operationName.value();
1528 if (payloadOpName.has_value()) {
1548 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1552 Block *mapper = getBody();
1567 [&](
auto arg) { p.printRegionArgument(arg); });
1578 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1579 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1580 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1581 return emitOpError() <<
"expects all inputs to have the same shapes. "
1582 "Shape at input-index "
1584 <<
" is not equal to the shape at input-index 0.";
1587 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1588 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1589 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1590 return emitOpError() <<
"expects all outputs to have the same shapes. "
1591 "Shape at output-index "
1593 <<
" is not equal to the shape at output-index 0.";
1596 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1597 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1600 for (int64_t dimension : dimensionsRef) {
1601 if (dimension < 0 || dimension >= inputType.getRank()) {
1602 return emitOpError()
1603 <<
"dimensions for reduction should be in the range [0, "
1604 << inputType.getRank() - 1 <<
"].";
1606 dimensionsToReduce.insert(dimension);
1609 auto inputDims = inputType.getShape();
1610 auto initDims = initType.getShape();
1615 if (!dimensionsToReduce.count(en.index()))
1616 reducedInputDims.push_back(en.value());
1619 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1620 return emitOpError() <<
"number of dimensions after reduction "
1621 << reducedInputDims.size()
1622 <<
" doesn't match the init rank "
1623 << initType.getRank();
1626 if (reducedInputDims != initDims)
1627 return emitOpError() <<
"init dimensions [" << initDims
1628 <<
"] doesn't match input dimensions after reduction ["
1629 << reducedInputDims <<
"]";
1631 Block *block = getBody();
1633 return emitOpError()
1634 <<
"mismatching number of operands and block arguments";
1637 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1638 Type inputElementType =
1639 llvm::cast<ShapedType>(input.getType()).getElementType();
1640 if (inputElementType != bbArg.getType())
1641 return emitOpError()
1642 <<
"input element type " << inputElementType
1643 <<
" does not match corresponding block argument type "
1648 for (
auto [output, bbArg] : llvm::zip(
1649 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1650 auto outputElementType =
1651 llvm::cast<ShapedType>(output.getType()).getElementType();
1652 if (outputElementType != bbArg.getType())
1653 return emitOpError()
1654 <<
"output element type " << outputElementType
1655 <<
" does not match corresponding block argument type "
1670 b.
create<linalg::YieldOp>(loc, args[0]);
1685 if (llvm::isa<RankedTensorType>(initType))
1714 void TransposeOp::getAsmResultNames(
1716 if (!getResults().empty())
1717 setNameFn(getResults().front(),
"transposed");
1730 return emitOpError(
"permutation is not valid");
1732 auto inputType = getInput().getType();
1733 auto initType = getInit().getType();
1735 int64_t rank = inputType.getRank();
1737 if (rank != initType.getRank())
1738 return emitOpError() <<
"input rank " << rank
1739 <<
" does not match init rank " << initType.getRank();
1741 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1742 return emitOpError() <<
"size of permutation " << permutationRef.size()
1743 <<
" does not match the argument rank " << rank;
1745 auto inputDims = inputType.getShape();
1746 auto initDims = initType.getShape();
1748 for (int64_t i = 0; i < rank; ++i) {
1749 int64_t inputDim = inputDims[permutationRef[i]];
1750 int64_t initDim = initDims[i];
1752 if (inputDim != initDim) {
1753 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1754 <<
" doesn't match dim(input, permutation[" << i
1755 <<
"]) = " << inputDim;
1763 int64_t rank = getInit().getType().getRank();
1767 ArrayAttr TransposeOp::getIndexingMaps() {
1769 int64_t rank = getInit().getType().getRank();
1773 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())});
1776 void TransposeOp::getEffects(
1798 if (llvm::isa<RankedTensorType>(initType))
1827 void BroadcastOp::getAsmResultNames(
1829 if (!getResults().empty())
1830 setNameFn(getResults().front(),
"broadcasted");
1842 auto inputType = getInput().getType();
1843 auto initType = getInit().getType();
1845 int64_t inputRank = inputType.getRank();
1846 int64_t initRank = initType.getRank();
1848 auto inputShape = inputType.getShape();
1849 auto initShape = initType.getShape();
1851 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
1852 return emitOpError() <<
"input rank plus added dimensions does not "
1853 "match init rank. input rank: "
1855 <<
", dimensions size: " << dimensionsRef.size()
1856 <<
", init rank: " << initRank;
1859 if (dim < 0 || dim >= initRank)
1860 return emitOpError() <<
"dimension " << idx
1861 <<
" is out of range. expected range: [0, "
1862 << initRank - 1 <<
"], got: " << dim;
1867 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
1868 if (!llvm::is_contained(dimensionsRef, dim))
1869 dimMap.push_back(dim);
1872 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
1875 if (inputShape[inputDimIdx] != initShape[initDimIdx])
1876 return emitOpError() <<
"input dim " << inputDimIdx
1877 <<
" should match init dim " << initDimIdx
1878 <<
". input: " << inputShape[inputDimIdx]
1879 <<
", init: " << initShape[initDimIdx];
1886 int64_t rank = getInit().getType().getRank();
1890 ArrayAttr BroadcastOp::getIndexingMaps() {
1892 int64_t rank = getInit().getType().getRank();
1898 void BroadcastOp::getEffects(
1910 if (getNumOperands() > 0)
1911 p <<
' ' << getOperands();
1913 if (getNumOperands() > 0)
1914 p <<
" : " << getOperandTypes();
1931 return op.
emitOpError(
"expected number of yield values (")
1933 <<
") to match the number of inits / outs operands of the enclosing "
1934 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
1938 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1940 if (isa<MemRefType, RankedTensorType>(elementType))
1942 if (opOperand.get().getType() != elementType)
1944 << (opOperand.getOperandNumber() + 1) <<
" ("
1945 << opOperand.get().getType() <<
") doesn't match "
1946 <<
"the element type of the enclosing linalg.generic op ("
1947 << elementType <<
")";
1953 auto *parentOp = (*this)->getParentOp();
1954 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1955 return emitOpError(
"expected single non-empty parent region");
1957 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1960 return emitOpError(
"expected parent op with LinalgOp interface");
1968 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1970 return emitOpError(
"expected parent op with LinalgOp interface");
1971 if (linalgOp.getNumLoops() <= getDim())
1972 return emitOpError(
"expected dim (")
1973 << getDim() <<
") to be lower than the number of loops ("
1974 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
1980 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1982 #define GET_OP_CLASSES
1983 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1985 #define GET_OP_CLASSES
1986 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2003 for (
unsigned i = 0; i < num; ++i)
2010 auto rangeA = llvm::make_range(a.begin(), a.end());
2011 auto rangeB = llvm::make_range(b.begin(), b.end());
2012 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2013 return llvm::to_vector<4>(concatRanges);
2017 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2019 for (
auto size : memref.getShape())
2026 if (
auto as = memref.getMemorySpace()) {
2027 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2028 ss <<
"as" << attr.getInt();
2034 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2037 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2049 assert(isa<LinalgOp>(op));
2051 std::string fun =
"";
2053 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2054 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2055 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2056 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2060 std::replace(name.begin(), name.end(),
'.',
'_');
2061 llvm::raw_string_ostream ss(name);
2065 return std::string();
2068 std::string res = ss.str();
2087 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2090 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2101 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2109 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2116 if (castOp->getBlock() != linalgOp->getBlock())
2123 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2126 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2132 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2134 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2137 linalgOp.getDpsInits().end());
2138 outputOperands[resultNumber] = newOperand;
2139 newOperands.append(outputOperands.begin(), outputOperands.end());
2142 linalgOp->result_type_end());
2143 resultTypes[resultNumber] = resultType;
2144 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2151 results[resultNumber] = castBack;
2163 if (linalgOp.isScalar(&opOperand))
2165 Value src = opOperand.get();
2166 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2167 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2175 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2176 Value castSource = castOp.getSource();
2177 auto castSourceType =
2178 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2179 if (castSourceType && castSourceType.hasStaticShape())
2180 sourceShape = castSourceType.getShape();
2186 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2187 if (sourceType.isDynamicDim(i))
2189 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2190 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2200 static void createNewOperandWithStaticSizes(
2204 bool &changeNeeded) {
2206 newOperands.push_back(src);
2207 if (linalgOp.isScalar(opOperand))
2209 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2210 Type resultType = sourceType;
2211 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2212 resultTypes.push_back(resultType);
2216 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2220 bool newOperandNeeded =
false;
2221 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2222 int64_t dimShape = sourceShape[i];
2224 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2225 newShape.push_back(dimShape);
2231 newShape.push_back(affineExprToSize[dimExpr]);
2232 newOperandNeeded =
true;
2235 if (newOperandNeeded) {
2236 changeNeeded =
true;
2239 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2241 newOperands[index] = newOperand;
2243 if (linalgOp.isDpsInit(opOperand))
2244 resultTypes.push_back(resultType);
2255 if (!linalgOp.hasTensorSemantics())
2259 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2260 return !map.isProjectedPermutation();
2270 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2277 bool changeNeeded =
false;
2278 newOperands.reserve(linalgOp->getNumOperands());
2279 resultTypes.reserve(linalgOp.getNumDpsInits());
2282 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2283 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2284 affineExprToSize, linalgOp, newOperands,
2285 resultTypes, changeNeeded);
2294 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2297 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2298 Value newResult = std::get<1>(it);
2299 Value oldResult = std::get<0>(it);
2302 replacements.push_back(
2303 (newType != oldType)
2304 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2307 rewriter.
replaceOp(linalgOp, replacements);
2322 ShapedType inputType = getInputOperandType();
2323 ShapedType outputType = getOutputOperandType();
2328 return emitOpError(
"incompatible output shape");
2330 int64_t inputRank = getInputOperandRank();
2331 int64_t dimension = getDimension();
2332 if ((dimension < 0) || (dimension >= inputRank))
2333 return emitOpError(
"incorrect dimension specified");
2339 int64_t operandRank = getInputOperandRank();
2342 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2343 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2344 Value source = getInput();
2345 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2346 loopBounds[dim].offset = zero;
2347 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2348 loopBounds[dim].stride = one;
2355 utils::IteratorType::parallel);
2356 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2357 return iteratorTypes;
2361 SoftmaxOp::getTiledImplementation(
OpBuilder &builder,
2364 int64_t rank = getInputOperandRank();
2368 tiledOperands.emplace_back(
2369 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2370 tiledOperands.emplace_back(
2371 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2375 resultTypes.push_back(tiledOperands[1].getType());
2377 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2386 if (resultNumber == 0) {
2387 resultOffsets.assign(offsets.begin(), offsets.end());
2388 resultSizes.assign(sizes.begin(), sizes.end());
2403 Location loc = getOperation()->getLoc();
2405 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2406 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2407 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2408 if (!outputShapedType.isDynamicDim(dim)) {
2410 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2417 reifiedReturnShapes.emplace_back(std::move(shapes));
2421 void SoftmaxOp::getEffects(
2450 int64_t dim,
bool allParallel =
false) {
2452 utils::IteratorType::parallel);
2454 iteratorTypes[dim] = utils::IteratorType::reduction;
2458 for (
int i = 0; i < inputRank; i++) {
2465 return std::make_tuple(iteratorTypes, indexingMaps);
2470 template <
typename T>
2473 auto inputType = cast<ShapedType>(input.
getType());
2475 int64_t inputRank = inputShape.size();
2476 auto [iteratorTypes, indexingMaps] =
2478 assert(indexingMaps.size() == 2 &&
2479 "We should have two maps: 1 for the input, 1 for the output");
2480 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2482 auto genericOp = builder.
create<linalg::GenericOp>(
2483 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2485 Value result = b.create<T>(loc, args[0], args[1]);
2486 b.create<linalg::YieldOp>(loc, result);
2496 auto inputType = cast<ShapedType>(input.
getType());
2498 int64_t inputRank = inputShape.size();
2500 builder, inputRank, dim,
true);
2501 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2502 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2504 indexingMaps.push_back(indexingMaps[0]);
2505 auto genericOp = builder.
create<linalg::GenericOp>(
2508 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2509 Value result = b.create<math::ExpOp>(loc, diff);
2510 b.create<linalg::YieldOp>(loc, result);
2521 Value denominator,
Value output, int64_t dim) {
2522 auto inputType = cast<ShapedType>(numerator.
getType());
2524 int64_t inputRank = inputShape.size();
2526 builder, inputRank, dim,
true);
2527 assert(indexingMaps.size() == 2 &&
2528 "We should have one map for each input (2)");
2529 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2531 indexingMaps.push_back(indexingMaps[0]);
2532 auto genericOp = builder.
create<linalg::GenericOp>(
2534 indexingMaps, iteratorTypes,
2536 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2537 b.create<linalg::YieldOp>(loc, result);
2565 Value input = getInput();
2566 ShapedType inputType = getInputOperandType();
2567 Type elementType = inputType.getElementType();
2568 int64_t reductionDim = getDimension();
2570 Value output = getOutput();
2571 dims.erase(dims.begin() + reductionDim);
2573 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2575 elementType, b, loc,
2577 Value neutralForMaxFInit =
2578 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2580 Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
2590 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2592 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2596 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2604 void LinalgDialect::getCanonicalizationPatterns(
2606 results.
add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
2613 return arith::ConstantOp::materialize(builder, value, type, loc);
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, const ValueRange inputOperands, ValueRange outputOperands)
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input - max).
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
MPInt ceil(const Fraction &f)
MPInt floor(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.