40 #include "llvm/ADT/DenseMap.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SetOperations.h"
43 #include "llvm/ADT/SmallSet.h"
44 #include "llvm/ADT/SmallVector.h"
45 #include "llvm/ADT/StringSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include "llvm/Support/MathExtras.h"
50 #include "llvm/Support/raw_ostream.h"
60 auto type = cast<ShapedType>(v.
getType());
61 if (!type.isDynamicDim(dim))
66 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
67 return builder.create<tensor::DimOp>(loc, v, dim);
69 .Case<MemRefType>([&](MemRefType t) ->
Value {
70 return builder.create<memref::DimOp>(loc, v, dim);
81 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
82 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
85 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
86 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
98 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
100 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
102 llvm_unreachable(
"Expected MemRefType or TensorType");
107 auto shapedType = llvm::cast<ShapedType>(source.
getType());
108 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
131 for (
auto containers : {inputTypes, outputTypes}) {
132 for (
auto t : containers) {
144 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
148 regionBuilder(b, *body, attrs);
160 std::optional<TypeRange> resultTensorTypes,
167 if (!resultTensorTypes)
168 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
169 llvm::IsaPred<RankedTensorType>);
171 state.addOperands(inputs);
172 state.addOperands(outputs);
173 state.addTypes(derivedResultTypes);
175 state.addAttributes(attributes);
177 "operandSegmentSizes",
179 static_cast<int32_t>(outputs.size())}));
182 Region ®ion = *state.addRegion();
184 state.attributes.getAttrs(), regionBuilder);
188 std::optional<TypeRange> resultTensorTypes,
195 indexingMapsAttrVal = llvm::map_to_vector(
196 MatmulOp::getDefaultIndexingMaps(b.
getContext()),
198 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
200 attributes, regionBuilder);
204 std::optional<TypeRange> resultTensorTypes,
211 indexingMapsAttrVal =
215 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
217 attributes, regionBuilder);
226 bool addOperandSegmentSizes =
true) {
227 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
256 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
258 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
262 if (addOperandSegmentSizes) {
271 attrs.
append(
"operandSegmentSizes",
273 {static_cast<int32_t>(inputsOperands.size()),
274 static_cast<int32_t>(outputsOperands.size())}));
279 {static_cast<int32_t>(inputsOperands.size()),
280 static_cast<int32_t>(outputsOperands.size())}));
284 std::optional<RegisteredOperationName> info =
287 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
288 return parser.emitError(attrsLoc)
289 <<
"'" << result.name.getStringRef() <<
"' op ";
300 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
301 if (!outputs.empty())
302 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
313 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
316 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
317 "region expects {0} args, got {1}",
318 numRegionArgs, inputTypes.size() + outputTypes.size()));
337 unsigned numRegionArgs,
353 result.
addTypes(outputTensorsTypes);
355 std::unique_ptr<Region> region = std::make_unique<Region>();
367 if (resultTypes.empty())
412 class RegionBuilderHelper {
415 : builder(builder), block(block) {}
418 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
419 if (!isFloatingPoint(arg))
420 llvm_unreachable(
"unsupported non numeric type");
422 builder.setInsertionPointToEnd(&block);
425 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
427 return builder.create<math::LogOp>(arg.
getLoc(), arg);
429 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
431 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
433 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
435 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
436 case UnaryFn::reciprocal: {
438 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
439 ::cast<TypedAttr>(oneAttr));
440 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
443 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
445 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
447 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
448 case UnaryFn::square:
449 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
451 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
453 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
455 llvm_unreachable(
"unsupported unary function");
460 bool allComplex = isComplex(arg0) && isComplex(arg1);
461 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
462 bool allInteger = isInteger(arg0) && isInteger(arg1);
465 if (!allComplex && !allFloatingPoint && !allInteger)
466 llvm_unreachable(
"unsupported non numeric type");
468 builder.setInsertionPointToEnd(&block);
472 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
473 if (allFloatingPoint)
474 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
476 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
477 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
480 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
481 if (allFloatingPoint)
482 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
484 llvm_unreachable(
"unsupported operation: sub with bools");
485 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
488 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
489 if (allFloatingPoint)
490 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
492 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
493 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
496 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
497 if (allFloatingPoint)
498 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
500 llvm_unreachable(
"unsupported operation: div with bools");
501 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
502 case BinaryFn::div_unsigned:
503 if (!allInteger || allBool)
504 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
505 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
506 case BinaryFn::max_signed:
508 if (allFloatingPoint)
509 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
510 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
511 case BinaryFn::min_signed:
513 if (allFloatingPoint)
514 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
515 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
516 case BinaryFn::max_unsigned:
518 if (allFloatingPoint)
519 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
520 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
521 case BinaryFn::min_unsigned:
523 if (allFloatingPoint)
524 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
525 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
527 assert(allFloatingPoint);
528 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
530 llvm_unreachable(
"unsupported binary function");
538 bool tailFloatingPoint =
539 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
540 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
542 builder.setInsertionPointToEnd(&block);
544 case TernaryFn::select:
545 if (!headBool && !(tailFloatingPoint || tailInteger))
546 llvm_unreachable(
"unsupported non numeric type");
547 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
549 llvm_unreachable(
"unsupported ternary function");
555 case TypeFn::cast_signed:
556 return cast(toType, operand,
false);
557 case TypeFn::cast_unsigned:
558 return cast(toType, operand,
true);
560 llvm_unreachable(
"unsupported type conversion function");
565 builder.setInsertionPointToEnd(&block);
566 Location loc = builder.getUnknownLoc();
567 builder.create<YieldOp>(loc, values);
570 Value constant(
const std::string &value) {
572 builder.setInsertionPointToEnd(&block);
573 Location loc = builder.getUnknownLoc();
575 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
578 Value index(int64_t dim) {
580 builder.setInsertionPointToEnd(&block);
581 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
584 Type getIntegerType(
unsigned width) {
598 builder.setInsertionPointToEnd(&block);
599 auto loc = operand.
getLoc();
603 bool isComplex(
Value value) {
604 return llvm::isa<ComplexType>(value.
getType());
606 bool isFloatingPoint(
Value value) {
607 return llvm::isa<FloatType>(value.
getType());
609 bool isInteger(
Value value) {
610 return llvm::isa<IntegerType>(value.
getType());
627 LogicalResult matchAndRewrite(CopyOp copyOp,
629 if (copyOp.getInputs() != copyOp.getOutputs())
631 if (copyOp.hasPureBufferSemantics())
634 rewriter.
replaceOp(copyOp, copyOp.getInputs());
644 results.
add<EraseSelfCopy>(context);
657 template <
typename TensorReshapeOp>
660 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
662 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
667 TensorReshapeOp newInit;
668 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
670 newInit = rewriter.
create<TensorReshapeOp>(
671 loc, reshapeOp.getResultType(), oldFill.output(),
672 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
673 reshapeOp.getStaticOutputShape());
675 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
677 reshapeOp.getReassociation());
690 LogicalResult matchAndRewrite(tensor::PadOp padOp,
692 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
698 Value padValue = padOp.getConstantPaddingValue();
699 if (!padValue || fillOp.value() != padValue)
705 padOp,
"failed to reify tensor.pad op result shape");
707 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
708 padOp.getLoc(), reifiedShape.front(),
709 padOp.getResultType().getElementType());
715 if (replacement.getType() != padOp.getResultType()) {
716 replacement = rewriter.
create<tensor::CastOp>(
717 fillOp.getLoc(), padOp.getResultType(), replacement);
727 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
730 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
732 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
736 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
741 Value firstDest = insertOp.getDest();
742 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
743 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
748 bool disjoint =
false;
749 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
752 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
753 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
754 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
758 int64_t prevStart = prevOp.getStaticOffset(i);
759 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
760 prevOp.getStaticStride(i);
761 int64_t nextStart = insertOp.getStaticOffset(i);
762 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
763 insertOp.getStaticStride(i);
764 if (prevEnd < nextStart || nextEnd < prevStart) {
772 firstDest = prevOp.getDest();
783 Value padValue = srcPadOp.getConstantPaddingValue();
784 if (!padValue || dstFillOp.value() != padValue)
800 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
802 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
805 RankedTensorType srcPadType = srcPadOp.getSourceType();
807 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
808 if (srcPadType.isDynamicDim(i)) {
810 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
813 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
818 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
819 newSizes, insertOp.getMixedStrides());
825 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
829 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
833 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
838 Value extractedScalar = fillOp.getInputs()[0];
841 rewriter.
replaceOp(extractOp, extractedScalar);
849 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
850 tensor::PackOp packOp) {
851 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
855 if (
auto paddingValue = packOp.getPaddingValue())
859 Value packOpDest = packOp.getDest();
863 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
873 LogicalResult matchAndRewrite(tensor::PackOp packOp,
875 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
878 rewriter.
replaceOp(packOp, fillOp.value().result());
887 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
889 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
892 copyOp.getOutputs());
895 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
897 fillOp.getOutputs());
908 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
910 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
912 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
913 transposeOp.getDpsInitOperand(0)->get());
925 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
927 auto concatOperands = concatOp.getInputs();
928 if (concatOperands.empty()) {
932 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
941 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
943 auto isDefinedByCompatibleFillOp = [&](
Value v) ->
bool {
944 auto fillOp = v.getDefiningOp<linalg::FillOp>();
951 if (fillVal != firstFillVal)
954 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
957 if (!llvm::all_of(concatOperands.drop_front(),
958 isDefinedByCompatibleFillOp)) {
960 concatOp,
"not all operands are defined by a compatible fill op");
963 Value outsConcat = rewriter.
create<tensor::ConcatOp>(
964 concatOp.getLoc(), concatOp.getDim(), allOuts);
966 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
975 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
976 FoldFillWithPack, FoldFillWithPad,
977 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
978 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
979 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
992 for (
ValueRange container : {inputs, outputs}) {
993 for (
Value v : container) {
994 Type t = v.getType();
995 blockArgTypes.push_back(
997 blockArgLocs.push_back(v.getLoc());
1003 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1007 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
1009 for (
Value v : getRegionInputArgs())
1011 for (
Value v : getRegionOutputArgs())
1012 setNameFn(v,
"out");
1015 void GenericOp::build(
1018 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1021 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1022 iteratorTypes, doc, libraryCall);
1026 inputs, outputs, bodyBuild);
1029 void GenericOp::build(
1033 StringRef libraryCall,
1036 build(builder, result, resultTensorTypes, inputs, outputs,
1041 return IteratorTypeAttr::get(builder.getContext(), iter);
1044 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1045 bodyBuild, attributes);
1048 void GenericOp::build(
1052 StringRef libraryCall,
1055 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
1056 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1059 void GenericOp::build(
1065 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1067 "", bodyBuild, attributes);
1070 void GenericOp::build(
1076 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1079 "", bodyBuild, attributes);
1086 auto genericAttrNames = linalgTraitAttrNames();
1089 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1091 for (
auto attr : (*this)->getAttrs()) {
1092 if (attr.getName() == getIteratorTypesAttrName()) {
1093 auto iteratorTypes =
1094 llvm::cast<ArrayAttr>(attr.getValue())
1095 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1101 llvm::to_vector(llvm::map_range(
1102 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1106 genericAttrs.emplace_back(
1107 getIteratorTypesAttrName(),
1109 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1110 genericAttrs.push_back(attr);
1113 if (!genericAttrs.empty()) {
1115 p << genericDictAttr;
1121 genericAttrNames.push_back(
"operandSegmentSizes");
1122 genericAttrNamesSet.insert(genericAttrNames.back());
1124 bool hasExtraAttrs =
false;
1126 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1129 if (hasExtraAttrs) {
1136 if (!getRegion().empty()) {
1146 DictionaryAttr dictAttr;
1155 dictAttr.getValue().end());
1161 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1163 if (!iteratorTypes) {
1164 return parser.
emitError(attributeLocation)
1165 <<
"expected " << getIteratorTypesAttrName(result.
name)
1166 <<
" array attribute";
1171 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1172 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1173 if (!maybeIteratorType.has_value())
1175 <<
"unexpected iterator_type (" << s <<
")";
1177 iteratorTypeAttrs.push_back(
1194 std::unique_ptr<Region> region = std::make_unique<Region>();
1206 result.
addTypes(outputTensorsTypes);
1214 LinalgOp linalgOp) {
1215 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1216 if (!llvm::isa<MemRefType>(operand.
getType()))
1218 effects.emplace_back(
1223 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1224 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1226 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1237 void GenericOp::getEffects(
1247 if (!linalgOp.hasPureTensorSemantics())
1266 template <
typename OpTy>
1270 LogicalResult matchAndRewrite(OpTy linalgOp,
1273 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1278 Block &body = linalgOp->getRegion(0).
front();
1279 if (!llvm::hasSingleElement(body))
1281 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1286 if (linalgOp.hasPureBufferSemantics()) {
1287 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1288 linalgOp.getDpsInputOperand(0)->get() ==
1289 linalgOp.getDpsInitOperand(0)->get()) {
1297 if (!linalgOp.hasPureTensorSemantics())
1304 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1305 if (!yieldArg || yieldArg.getOwner() != &body)
1307 unsigned argumentNumber = yieldArg.getArgNumber();
1308 Value returnedArg = linalgOp->getOperand(argumentNumber);
1309 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1313 if (returnType != resultType) {
1318 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1319 linalgOp.getLoc(), resultType, returnedArg);
1321 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1324 returnedArg = rewriter.
create<tensor::CastOp>(
1325 linalgOp.getLoc(), resultType, returnedArg);
1328 returnedArgs.push_back(returnedArg);
1331 if (returnedArgs.size() != linalgOp->getNumResults())
1333 rewriter.
replaceOp(linalgOp, returnedArgs);
1342 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1364 for (
Type outputType : outputTypes) {
1365 if (llvm::isa<RankedTensorType>(outputType))
1370 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1379 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1381 for (
Value v : getRegionInputArgs())
1386 if (!getResults().empty())
1387 setNameFn(getResults().front(),
"mapped");
1394 build(builder, result,
TypeRange{}, inputs, init);
1399 if (llvm::isa<RankedTensorType>(initType))
1404 inputs, {}, bodyBuild);
1411 bool initFirst =
false) {
1417 for (
auto &operand : operands) {
1419 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1426 payloadOpOperands.push_back(block.
getArguments().back());
1427 for (
const auto &arg : block.
getArguments().drop_back())
1428 payloadOpOperands.push_back(arg);
1437 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1444 std::optional<OperationName> payloadOpName;
1448 if (failed(operationName))
1452 payloadOpName = operationName.value();
1460 if (payloadOpName.has_value()) {
1498 for (
const auto &[operand, bbArg] :
1500 if (bbArg != operand)
1504 for (
const auto &[operand, bbArg] :
1506 if (bbArg != operand)
1515 std::string attrToElide;
1517 for (
const auto &attr : payloadOp->
getAttrs()) {
1519 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1520 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1521 attrToElide = attr.getName().str();
1522 elidedAttrs.push_back(attrToElide);
1531 Block *mapper = getBody();
1546 [&](
auto arg) { p.printRegionArgument(arg); });
1555 auto *bodyBlock = getBody();
1556 auto blockArgs = bodyBlock->getArguments();
1559 if (getInputs().size() != blockArgs.size())
1560 return emitOpError() <<
"expects number of operands to match the arity of "
1562 << getInputs().size() <<
" and " << blockArgs.size();
1565 for (
const auto &[bbArgType, inputArg] :
1566 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1567 auto inputElemType =
1568 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1569 if (bbArgType != inputElemType) {
1570 return emitOpError() <<
"expected element type of input " << inputElemType
1571 <<
" to match bbArg type " << bbArgType;
1576 auto outputShape = getInit().getType().getShape();
1578 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1579 if (inputElemShape != outputShape) {
1580 return emitOpError() <<
"expected shape of input (" << inputElemShape
1581 <<
") to match shape of output (" << outputShape
1590 int64_t rank = getInit().getType().getRank();
1594 ArrayAttr MapOp::getIndexingMaps() {
1596 int64_t rank = getInit().getType().getRank();
1597 int64_t numIndexingMaps = getOperands().size();
1602 void MapOp::getEffects(
1616 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1618 for (
Value v : getRegionInputArgs())
1620 for (
Value v : getRegionOutputArgs())
1621 setNameFn(v,
"init");
1624 void ReduceOp::getAsmResultNames(
1626 if (!getResults().empty())
1627 setNameFn(getResults().front(),
"reduced");
1630 void ReduceOp::build(
1635 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1639 for (
Value init : inits) {
1641 if (llvm::isa<RankedTensorType>(initType))
1647 inputs, inits, bodyBuild);
1652 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1654 utils::IteratorType::parallel);
1655 for (int64_t reductionDim : getDimensions())
1656 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1657 return iteratorTypes;
1660 ArrayAttr ReduceOp::getIndexingMaps() {
1662 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1669 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1670 affineMaps.push_back(resultMap);
1674 void ReduceOp::getEffects(
1686 StringRef attributeName) {
1695 std::optional<OperationName> payloadOpName;
1699 if (failed(operationName))
1703 payloadOpName = operationName.value();
1714 if (payloadOpName.has_value()) {
1734 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1738 Block *mapper = getBody();
1753 [&](
auto arg) { p.printRegionArgument(arg); });
1764 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1765 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1766 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1767 return emitOpError() <<
"expects all inputs to have the same shapes. "
1768 "Shape at input-index "
1770 <<
" is not equal to the shape at input-index 0.";
1773 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1774 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1775 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1776 return emitOpError() <<
"expects all outputs to have the same shapes. "
1777 "Shape at output-index "
1779 <<
" is not equal to the shape at output-index 0.";
1782 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1783 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1786 for (int64_t dimension : dimensionsRef) {
1787 if (dimension < 0 || dimension >= inputType.getRank()) {
1788 return emitOpError()
1789 <<
"dimensions for reduction should be in the range [0, "
1790 << inputType.getRank() - 1 <<
"].";
1792 dimensionsToReduce.insert(dimension);
1795 auto inputDims = inputType.getShape();
1796 auto initDims = initType.getShape();
1801 if (!dimensionsToReduce.count(en.index()))
1802 reducedInputDims.push_back(en.value());
1805 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1806 return emitOpError() <<
"number of dimensions after reduction "
1807 << reducedInputDims.size()
1808 <<
" doesn't match the init rank "
1809 << initType.getRank();
1812 if (reducedInputDims != initDims)
1813 return emitOpError() <<
"init dimensions [" << initDims
1814 <<
"] doesn't match input dimensions after reduction ["
1815 << reducedInputDims <<
"]";
1817 Block *block = getBody();
1819 return emitOpError()
1820 <<
"mismatching number of operands and block arguments";
1823 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1824 Type inputElementType =
1825 llvm::cast<ShapedType>(input.getType()).getElementType();
1826 if (inputElementType != bbArg.getType())
1827 return emitOpError()
1828 <<
"input element type " << inputElementType
1829 <<
" does not match corresponding block argument type "
1834 for (
auto [output, bbArg] : llvm::zip(
1835 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1836 auto outputElementType =
1837 llvm::cast<ShapedType>(output.getType()).getElementType();
1838 if (outputElementType != bbArg.getType())
1839 return emitOpError()
1840 <<
"output element type " << outputElementType
1841 <<
" does not match corresponding block argument type "
1857 b.
create<linalg::YieldOp>(loc, args[0]);
1872 if (llvm::isa<RankedTensorType>(initType))
1901 void TransposeOp::getAsmResultNames(
1903 if (!getResults().empty())
1904 setNameFn(getResults().front(),
"transposed");
1917 return emitOpError(
"permutation is not valid");
1919 auto inputType = getInput().getType();
1920 auto initType = getInit().getType();
1922 int64_t rank = inputType.getRank();
1924 if (rank != initType.getRank())
1925 return emitOpError() <<
"input rank " << rank
1926 <<
" does not match init rank " << initType.getRank();
1928 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1929 return emitOpError() <<
"size of permutation " << permutationRef.size()
1930 <<
" does not match the argument rank " << rank;
1932 auto inputDims = inputType.getShape();
1933 auto initDims = initType.getShape();
1935 for (int64_t i = 0; i < rank; ++i) {
1936 int64_t inputDim = inputDims[permutationRef[i]];
1937 int64_t initDim = initDims[i];
1939 if (inputDim != initDim) {
1940 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1941 <<
" doesn't match dim(input, permutation[" << i
1942 <<
"]) = " << inputDim;
1950 int64_t rank = getInit().getType().getRank();
1954 ArrayAttr TransposeOp::getIndexingMaps() {
1956 int64_t rank = getInit().getType().getRank();
1959 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1963 void TransposeOp::getEffects(
1973 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1976 if (!isa<TensorType>(getInput().
getType()))
1980 if (getPermutation().size() == 0) {
1981 result.push_back(getInput());
1986 result.push_back(getInput());
1999 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2000 if (!defTransposeOp)
2005 foldedPerms.reserve(perms.size());
2006 for (int64_t perm : perms)
2007 foldedPerms.push_back(defPerms[perm]);
2010 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2024 Value input = transposeOp.getInput();
2025 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2036 unsigned dimensionSize = dimensions.size();
2037 for (
unsigned i = 0; i < dimensionSize; ++i)
2038 resultDimensions.push_back(invertPerm[dimensions[i]]);
2041 Value broadcastInput = broadcastOp.getInput();
2042 Location loc = transposeOp.getLoc();
2045 auto broadcastInputTy =
2046 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2047 unsigned inputRank = broadcastInputTy.getRank();
2048 for (
unsigned i = 0; i < inputRank; ++i) {
2049 if (broadcastInputTy.isDynamicDim(i)) {
2050 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
2054 broadcastInputTy.getDimSize(i)));
2059 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
2060 transposeOp.getLoc(), transposeResultShapes,
2061 broadcastInputTy.getElementType());
2064 Value transposeResult =
2066 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2070 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2095 if (llvm::isa<RankedTensorType>(initType))
2124 void BroadcastOp::getAsmResultNames(
2126 if (!getResults().empty())
2127 setNameFn(getResults().front(),
"broadcasted");
2139 auto inputType = getInput().getType();
2140 auto initType = getInit().getType();
2142 int64_t inputRank = inputType.getRank();
2143 int64_t initRank = initType.getRank();
2145 auto inputShape = inputType.getShape();
2146 auto initShape = initType.getShape();
2148 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2149 return emitOpError() <<
"input rank plus added dimensions does not "
2150 "match init rank. input rank: "
2152 <<
", dimensions size: " << dimensionsRef.size()
2153 <<
", init rank: " << initRank;
2156 if (dim < 0 || dim >= initRank)
2157 return emitOpError() <<
"dimension " << idx
2158 <<
" is out of range. expected range: [0, "
2159 << initRank - 1 <<
"], got: " << dim;
2164 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2165 if (!llvm::is_contained(dimensionsRef, dim))
2166 dimMap.push_back(dim);
2169 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2172 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2173 return emitOpError() <<
"input dim " << inputDimIdx
2174 <<
" should match init dim " << initDimIdx
2175 <<
". input: " << inputShape[inputDimIdx]
2176 <<
", init: " << initShape[initDimIdx];
2183 int64_t rank = getInit().getType().getRank();
2187 ArrayAttr BroadcastOp::getIndexingMaps() {
2189 int64_t rank = getInit().getType().getRank();
2195 void BroadcastOp::getEffects(
2207 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2215 if (getNumOperands() > 0)
2216 p <<
' ' << getOperands();
2218 if (getNumOperands() > 0)
2219 p <<
" : " << getOperandTypes();
2234 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2235 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2236 return op.emitOpError(
"expected number of yield values (")
2237 << op.getNumOperands()
2238 <<
") to match the number of inits / outs operands of the enclosing "
2239 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2241 for (
OpOperand &opOperand : op->getOpOperands()) {
2243 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2245 if (isa<MemRefType, RankedTensorType>(elementType))
2247 if (opOperand.get().getType() != elementType)
2248 return op.emitOpError(
"type of yield operand ")
2249 << (opOperand.getOperandNumber() + 1) <<
" ("
2250 << opOperand.get().getType() <<
") doesn't match "
2251 <<
"the element type of the enclosing linalg.generic op ("
2252 << elementType <<
")";
2258 auto *parentOp = (*this)->getParentOp();
2259 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2260 return emitOpError(
"expected single non-empty parent region");
2262 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2265 return emitOpError(
"expected parent op with LinalgOp interface");
2273 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2275 return emitOpError(
"expected parent op with LinalgOp interface");
2276 if (linalgOp.getNumLoops() <= getDim())
2277 return emitOpError(
"expected dim (")
2278 << getDim() <<
") to be lower than the number of loops ("
2279 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2285 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2287 #define GET_OP_CLASSES
2288 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2290 #define GET_OP_CLASSES
2291 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2308 for (
unsigned i = 0; i < num; ++i)
2315 auto rangeA = llvm::make_range(a.begin(), a.end());
2316 auto rangeB = llvm::make_range(b.begin(), b.end());
2317 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2318 return llvm::to_vector<4>(concatRanges);
2322 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2324 for (
auto size : memref.getShape())
2331 if (
auto as = memref.getMemorySpace()) {
2332 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2333 ss <<
"as" << attr.getInt();
2339 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2342 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2355 assert(isa<LinalgOp>(op));
2357 std::string fun =
"";
2359 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2360 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2361 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2362 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2366 std::replace(name.begin(), name.end(),
'.',
'_');
2367 llvm::raw_string_ostream ss(name);
2371 return std::string();
2386 LogicalResult matchAndRewrite(LinalgOp op,
2388 for (
OpOperand &opOperand : op->getOpOperands()) {
2392 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2395 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2406 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2409 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2414 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2421 if (castOp->getBlock() != linalgOp->getBlock())
2428 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2431 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2437 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2439 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2442 linalgOp.getDpsInits().end());
2443 outputOperands[resultNumber] = newOperand;
2444 newOperands.append(outputOperands.begin(), outputOperands.end());
2447 linalgOp->result_type_end());
2448 resultTypes[resultNumber] = resultType;
2449 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2456 results[resultNumber] = castBack;
2468 if (linalgOp.isScalar(&opOperand))
2470 Value src = opOperand.get();
2471 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2472 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2480 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2481 Value castSource = castOp.getSource();
2482 auto castSourceType =
2483 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2484 if (castSourceType && castSourceType.hasStaticShape())
2485 sourceShape = castSourceType.getShape();
2491 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2492 if (sourceType.isDynamicDim(i))
2494 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2495 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2505 static void createNewOperandWithStaticSizes(
2509 bool &changeNeeded) {
2511 newOperands.push_back(src);
2512 if (linalgOp.isScalar(opOperand))
2514 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2515 Type resultType = sourceType;
2516 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2517 resultTypes.push_back(resultType);
2521 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2525 bool newOperandNeeded =
false;
2526 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2527 int64_t dimShape = sourceShape[i];
2529 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2530 newShape.push_back(dimShape);
2536 newShape.push_back(affineExprToSize[dimExpr]);
2537 newOperandNeeded =
true;
2540 if (newOperandNeeded) {
2541 changeNeeded =
true;
2544 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2546 newOperands[index] = newOperand;
2548 if (linalgOp.isDpsInit(opOperand))
2549 resultTypes.push_back(resultType);
2558 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2560 if (!linalgOp.hasPureTensorSemantics())
2564 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2565 return !map.isProjectedPermutation();
2575 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2582 bool changeNeeded =
false;
2583 newOperands.reserve(linalgOp->getNumOperands());
2584 resultTypes.reserve(linalgOp.getNumDpsInits());
2587 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2588 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2589 affineExprToSize, linalgOp, newOperands,
2590 resultTypes, changeNeeded);
2599 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2602 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2603 Value newResult = std::get<1>(it);
2604 Value oldResult = std::get<0>(it);
2607 replacements.push_back(
2608 (newType != oldType)
2609 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2612 rewriter.
replaceOp(linalgOp, replacements);
2627 ShapedType inputType = getInputOperandType();
2628 ShapedType outputType = getOutputOperandType();
2633 return emitOpError(
"incompatible output shape");
2635 int64_t inputRank = getInputOperandRank();
2636 int64_t dimension = getDimension();
2637 if ((dimension < 0) || (dimension >= inputRank))
2638 return emitOpError(
"incorrect dimension specified");
2644 int64_t operandRank = getInputOperandRank();
2647 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2648 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2649 Value source = getInput();
2650 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2651 loopBounds[dim].offset = zero;
2652 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2653 loopBounds[dim].stride = one;
2660 utils::IteratorType::parallel);
2661 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2662 return iteratorTypes;
2665 FailureOr<TilingResult>
2669 int64_t rank = getInputOperandRank();
2674 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2676 return emitOpError(
"failed to compute input slice");
2678 tiledOperands.emplace_back(inputSlice->
getResult(0));
2680 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2682 return emitOpError(
"failed to compute output slice");
2684 tiledOperands.emplace_back(outputSlice->
getResult(0));
2687 if (hasPureTensorSemantics())
2688 resultTypes.push_back(tiledOperands[1].
getType());
2690 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2702 if (resultNumber == 0) {
2703 resultOffsets.assign(offsets.begin(), offsets.end());
2704 resultSizes.assign(sizes.begin(), sizes.end());
2719 Location loc = getOperation()->getLoc();
2721 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2722 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2723 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2724 if (!outputShapedType.isDynamicDim(dim)) {
2726 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2733 reifiedReturnShapes.emplace_back(std::move(shapes));
2737 void SoftmaxOp::getEffects(
2741 if (!llvm::isa<MemRefType>(operand.
getType()))
2744 &getOperation()->getOpOperand(index), 0,
2749 for (
OpOperand &operand : getDpsInitsMutable()) {
2750 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2783 int64_t dim,
bool allParallel =
false) {
2785 utils::IteratorType::parallel);
2787 iteratorTypes[dim] = utils::IteratorType::reduction;
2791 for (
int i = 0; i < inputRank; i++) {
2798 return std::make_tuple(iteratorTypes, indexingMaps);
2803 template <
typename T>
2806 auto inputType = cast<ShapedType>(input.
getType());
2808 int64_t inputRank = inputShape.size();
2809 auto [iteratorTypes, indexingMaps] =
2811 assert(indexingMaps.size() == 2 &&
2812 "We should have two maps: 1 for the input, 1 for the output");
2813 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2815 auto genericOp = builder.
create<linalg::GenericOp>(
2816 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2818 Value result = b.create<T>(loc, args[0], args[1]);
2819 b.create<linalg::YieldOp>(loc, result);
2829 auto inputType = cast<ShapedType>(input.
getType());
2831 int64_t inputRank = inputShape.size();
2833 builder, inputRank, dim,
true);
2834 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2835 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2837 indexingMaps.push_back(indexingMaps[0]);
2838 auto genericOp = builder.
create<linalg::GenericOp>(
2841 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2842 Value result = b.create<math::ExpOp>(loc, diff);
2843 b.create<linalg::YieldOp>(loc, result);
2854 Value denominator,
Value output, int64_t dim) {
2855 auto inputType = cast<ShapedType>(numerator.
getType());
2857 int64_t inputRank = inputShape.size();
2859 builder, inputRank, dim,
true);
2860 assert(indexingMaps.size() == 2 &&
2861 "We should have one map for each input (2)");
2862 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2864 indexingMaps.push_back(indexingMaps[0]);
2865 auto genericOp = builder.
create<linalg::GenericOp>(
2867 indexingMaps, iteratorTypes,
2869 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2870 b.create<linalg::YieldOp>(loc, result);
2894 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2898 Value input = getInput();
2899 ShapedType inputType = getInputOperandType();
2900 Type elementType = inputType.getElementType();
2901 int64_t reductionDim = getDimension();
2903 Value output = getOutput();
2904 dims.erase(dims.begin() + reductionDim);
2906 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2908 elementType, b, loc,
2910 Value neutralForMaxFInit =
2911 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2914 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2923 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2925 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2929 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2938 auto filterType = cast<ShapedType>(getFilter().
getType());
2940 int64_t filterH = filterShape[getFilterHDim()];
2941 int64_t filterW = filterShape[getFilterWDim()];
2945 if (filterH != r && filterH != 1)
2946 return emitOpError(
"expect filter height either equals to r or 1");
2947 if (filterW != r && filterW != 1)
2948 return emitOpError(
"expect filter width either equals to r or 1");
2949 if (filterH == 1 && filterW == 1)
2950 return emitOpError(
"expect either filter height or width equals to r");
2953 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2954 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2955 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2956 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2958 auto outputType = cast<ShapedType>(getOutput().
getType());
2961 return emitOpError(
"the output shape is not expected");
2967 WinogradFilterTransformOp::getIterationDomain(
OpBuilder &builder) {
2971 Value filter = getFilter();
2972 int64_t filterRank = getFilterOperandRank();
2974 for (
unsigned dim = 0; dim < filterRank; ++dim) {
2975 loopBounds[dim].offset = zeroAttr;
2976 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
2977 loopBounds[dim].stride = oneAttr;
2983 WinogradFilterTransformOp::getLoopIteratorTypes() {
2984 int64_t filterRank = getFilterOperandRank();
2986 utils::IteratorType::parallel);
2987 return iteratorTypes;
2995 ShapedType filterType = getFilterOperandType();
2997 int64_t filterH = filterShape[getFilterHDim()];
2998 int64_t filterW = filterShape[getFilterWDim()];
3001 int64_t alpha = m + r - 1;
3002 int64_t alphaH = filterH != 1 ? alpha : 1;
3003 int64_t alphaW = filterW != 1 ? alpha : 1;
3007 resultOffsets.append(
3008 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3010 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3026 ShapedType filterType = getFilterOperandType();
3028 int64_t filterH = filterShape[getFilterHDim()];
3029 int64_t filterW = filterShape[getFilterWDim()];
3035 sliceOffsets.append(
3036 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3037 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3038 sizes[getFilterCDim()]});
3039 int64_t filterRank = getFilterOperandRank();
3042 auto filterSlice = builder.
create<tensor::ExtractSliceOp>(
3043 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3044 tiledOperands.emplace_back(filterSlice);
3051 int64_t outputRank = getOutputOperandRank();
3053 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3054 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3055 tiledOperands.emplace_back(outputSlice);
3058 resultTypes.push_back(tiledOperands[1].
getType());
3060 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3073 auto inputType = cast<ShapedType>(getInput().
getType());
3075 int64_t inputH = inputShape[getInputHDim()];
3076 int64_t inputW = inputShape[getInputWDim()];
3079 int64_t tileSize = m + r - 1;
3081 auto outputType = cast<ShapedType>(getOutput().
getType());
3083 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3084 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3087 if (ShapedType::isDynamic(inputH)) {
3088 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3089 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3091 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3092 expectedOutputShape[getOutputTileHDim()] =
3093 leftTransform ? (inputH - (r - 1)) / m : inputH;
3095 if (ShapedType::isDynamic(inputW)) {
3096 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3097 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3099 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3100 expectedOutputShape[getOutputTileWDim()] =
3101 rightTransform ? (inputW - (r - 1)) / m : inputW;
3103 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3104 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3107 return emitOpError(
"the output shape is not expected");
3113 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3117 Value output = getOutput();
3118 int64_t outputRank = getOutputOperandRank();
3120 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3121 loopBounds[dim].offset = zeroAttr;
3123 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3124 loopBounds[dim].stride = oneAttr;
3130 WinogradInputTransformOp::getLoopIteratorTypes() {
3131 int64_t outputRank = getOutputOperandRank();
3133 utils::IteratorType::parallel);
3134 return iteratorTypes;
3142 ShapedType outputType = getOutputOperandType();
3144 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3145 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3149 int64_t alpha = m + r - 1;
3150 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3151 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3156 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3157 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3158 offsets[getOutputCDim()]});
3159 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3160 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3161 sizes[getOutputCDim()]});
3172 FailureOr<TilingResult>
3180 ShapedType outputType = getOutputOperandType();
3182 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3183 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3187 auto identityAffineMap =
3189 auto offsetAffineMap =
3192 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3193 offsets[getOutputTileHDim()]);
3195 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3196 offsets[getOutputTileWDim()]);
3200 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3202 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3209 sliceOffsets.append(
3210 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3216 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3217 int64_t inputRank = getInputOperandRank();
3219 auto inputSlice = builder.
create<tensor::ExtractSliceOp>(
3220 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3221 tiledOperands.emplace_back(inputSlice);
3228 int64_t outputRank = getOutputOperandRank();
3230 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3231 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3232 tiledOperands.emplace_back(outputSlice);
3235 resultTypes.push_back(tiledOperands[1].
getType());
3237 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3250 auto valueType = cast<ShapedType>(getValue().
getType());
3252 int64_t valueH = valueShape[getValueAlphaHDim()];
3253 int64_t valueW = valueShape[getValueAlphaWDim()];
3254 int64_t valueTileH = valueShape[getValueTileHDim()];
3255 int64_t valueTileW = valueShape[getValueTileWDim()];
3258 bool leftTransform = valueH != 1;
3259 bool rightTransform = valueW != 1;
3261 int64_t outputRank = getOutputOperandRank();
3263 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3264 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3266 if (valueH != (leftTransform ? m + r - 1 : 1))
3267 return emitOpError(
"expect input height equals to input tile size");
3268 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3270 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3271 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3273 if (valueW != (rightTransform ? m + r - 1 : 1))
3274 return emitOpError(
"expect input width equals to input tile size");
3275 expectedOutputShape[getOutputWDim()] =
3276 (rightTransform ? m : 1) * valueTileW;
3278 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3279 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3281 auto outputType = cast<ShapedType>(getOutput().
getType());
3284 return emitOpError(
"the output shape is not expected");
3290 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3294 Value value = getValue();
3295 int64_t valueRank = getValueOperandRank();
3297 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3298 loopBounds[dim].offset = zeroAttr;
3300 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3301 loopBounds[dim].stride = oneAttr;
3307 WinogradOutputTransformOp::getLoopIteratorTypes() {
3308 int64_t valueRank = getValueOperandRank();
3310 utils::IteratorType::parallel);
3311 return iteratorTypes;
3322 auto identityAffineMap =
3327 ShapedType valueType = getValueOperandType();
3329 int64_t valueH = valueShape[0];
3330 int64_t valueW = valueShape[1];
3332 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3333 offsets[getValueTileHDim()]);
3335 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3336 offsets[getValueTileWDim()]);
3338 builder, loc, affineMap, sizes[getValueTileHDim()]);
3340 builder, loc, affineMap, sizes[getValueTileWDim()]);
3350 resultOffsets.append(
3351 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3353 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3372 ShapedType valueType = getValueOperandType();
3374 int64_t alphaH = valueShape[getValueAlphaHDim()];
3375 int64_t alphaW = valueShape[getValueAlphaWDim()];
3379 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3380 offsets[getValueTileWDim()], offsets[getValueNDim()],
3381 offsets[getValueFDim()]});
3382 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3383 sizes[getValueTileWDim()], sizes[getValueNDim()],
3384 sizes[getValueFDim()]});
3385 int64_t valueRank = getValueOperandRank();
3387 auto valueSlice = builder.
create<tensor::ExtractSliceOp>(
3388 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3389 tiledOperands.emplace_back(valueSlice);
3396 int64_t outputRank = getOutputOperandRank();
3398 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3399 loc, getOutput(), resultOffsets, resultSizes, strides);
3400 tiledOperands.emplace_back(outputSlice);
3403 resultTypes.push_back(tiledOperands[1].
getType());
3405 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3417 void LinalgDialect::getCanonicalizationPatterns(
3419 results.
add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3426 return arith::ConstantOp::materialize(builder, value, type, loc);
3435 llvm::set_union(explicitSet, defaultSet);
3436 return explicitSet == defaultSet;
3456 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3458 auto opIndexingMap = opIndexingMaps[opIndex];
3459 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3462 return matmulOp->emitOpError()
3463 <<
"Unexpected dim expression in map result.";
3466 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3467 return matmulOp->emitOpError()
3468 <<
"Invalid broadcast requested, should be (d2).";
3478 AffineMap defaultIndexingMap,
bool isLHS) {
3481 return batchMatmulOp->emitOpError()
3482 <<
"Unexpected result dim expression (outside the set of default "
3487 return batchMatmulOp->emitOpError()
3488 <<
"no. of result dim expressions exceeds 3.";
3490 auto hasValidBatchDim = [](
AffineMap map) {
3497 if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3498 return batchMatmulOp->emitOpError() <<
"Invalid broadcast requested.";
3499 }
else if (!hasValidBatchDim(opIndexingMap)) {
3500 return batchMatmulOp->emitOpError()
3501 <<
"Invalid batch dimension expression.";
3512 return batchMatmulOp->emitOpError()
3513 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3516 auto areValidOutputResultDim = [](
AffineMap outputMap) {
3517 return outputMap.getResult(0).isFunctionOfDim(0) &&
3518 outputMap.getResult(1).isFunctionOfDim(1) &&
3519 outputMap.getResult(2).isFunctionOfDim(2);
3522 if (!areValidOutputResultDim(opIndexingMap))
3523 return batchMatmulOp->emitOpError()
3524 <<
"Invalid output map result dimension.";
3531 static LogicalResult
3535 batchMatmulOp.getIndexingMapsArray();
3537 batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3539 if (opIndexingMaps.size() != 3)
3540 return batchMatmulOp->emitOpError()
3541 <<
"Indexing_map attribute must have 3 affine maps.";
3543 auto opIndexingMap = opIndexingMaps[opIndex];
3544 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3546 if (opIndex == 2 && failed(
verifyOutputMap(batchMatmulOp, opIndexingMap)))
3549 if (failed(
verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3571 return indexingMaps;
3576 utils::IteratorType::parallel,
3577 utils::IteratorType::reduction};
3580 unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3582 std::string MatmulOp::getLibraryCallName() {
3586 bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3590 bool MatmulOp::hasUserDefinedMaps() {
3594 return defaultMaps != explicitMaps;
3602 "MatmulOp regionBuilder expects 3 (>=0) args");
3603 RegionBuilderHelper helper(b, block);
3606 TypeFn castVal = TypeFn::cast_signed;
3607 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3608 return attr.
getName() ==
"cast";
3610 if (castIter != attrs.end()) {
3611 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3619 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3621 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), value3);
3622 yields.push_back(value4);
3623 helper.yieldOutputs(yields);
3627 bool MatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap) {
3628 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3639 ArrayAttr arrayAttr;
3643 if (llvm::any_of(arrayAttr,
3644 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3646 <<
"element of indexing_maps array is not an affine_map";
3653 if (failed(indexingMapsAttr))
3656 if (*indexingMapsAttr ==
nullptr) {
3657 auto indexingMapAttrs = llvm::map_to_vector(
3658 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3663 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3665 MatmulOp::getRegionBuilder());
3670 MatmulOp::getDefaultIndexingMaps(
getContext()),
3672 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3673 p <<
" indexing_maps = [";
3674 llvm::interleaveComma(getIndexingMaps(), p,
3680 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3688 if (!hasUserDefinedMaps())
3691 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3702 void MatmulOp::getEffects(
3705 if (hasPureTensorSemantics())
3719 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3729 for (
auto result : outAffineMap.
getResults()) {
3730 auto dimExpr = dyn_cast<AffineDimExpr>(result);
3731 assert(dimExpr &&
"affine_map is a projected permutation");
3732 dimsInOutput[dimExpr.getPosition()] =
true;
3736 for (
auto dimOccursInOutput : dimsInOutput)
3737 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3738 : utils::IteratorType::reduction);
3740 return iteratorTypes;
3743 unsigned ContractOp::getNumRegionArgs() {
return 3; }
3749 "ContractOp regionBuilder expects 3 args");
3750 RegionBuilderHelper helper(b, block);
3752 TypeFn castSignedness = TypeFn::cast_signed;
3753 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3754 return attr.
getName() ==
"cast";
3756 if (castIter != attrs.end()) {
3757 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3763 Value lhsAtOutType =
3764 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
3765 Value rhsAtOutType =
3766 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
3767 Value productAtOutType =
3768 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3771 helper.yieldOutputs({result});
3776 if (failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
3778 "expected 'indexing_maps' attribute");
3779 result.
addAttribute(
"indexing_maps", *indexingMapsAttr);
3786 p <<
" indexing_maps = [";
3787 llvm::interleaveComma(getIndexingMaps(), p,
3791 p, getOperation(), getInputs(), getOutputs(),
3792 {
"indexing_maps",
"operandSegmentSizes"});
3796 int iterationSpaceDims = -1;
3805 auto checkAffineMapAndType = [&](
AffineMap affineMap,
Type operandType,
3806 bool isInput) -> LogicalResult {
3809 return emitError(
"provided affine_map is not a projected permutation");
3812 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
3814 return emitError(
"ranks of shaped operand and results of corresponding "
3815 "affine_map differ");
3817 return emitError(
"affine_map specifies shaped access while operand has "
3822 if (iterationSpaceDims == -1) {
3826 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
3827 return emitError(
"iteration spaces of provided affine_maps differ");
3832 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3834 llvm_unreachable(
"affine_map is a projected permutation");
3837 inOccurrences[affineDimExpr.getPosition()] += 1;
3839 outOccurrences[affineDimExpr.getPosition()] += 1;
3845 for (
auto &&[affineMap, operandType, isInput] :
3846 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3848 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3852 bool hasContractingDim =
false;
3853 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3854 size_t inOccCount = inOccurrences[dimIndex];
3855 size_t outOccCount = outOccurrences[dimIndex];
3858 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3860 if (inOccCount == 0 && outOccCount == 0)
3861 return emitError() <<
"iteration space dim at index " << dimIndex
3862 <<
" not used to access any operand";
3873 if (inOccCount == 1 && outOccCount != 1)
3875 <<
"iteration space dim at index " << dimIndex
3876 <<
" is neither a contracting dim nor of parallel iteration type";
3879 if (!hasContractingDim)
3880 return emitError(
"'indexing_maps' do not specify a contracting dimension");
3889 void ContractOp::getEffects(
3892 if (hasPureTensorSemantics())
3905 BatchMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
3909 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
3910 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
3911 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
3912 return indexingMaps;
3917 utils::IteratorType::parallel, utils::IteratorType::parallel,
3918 utils::IteratorType::parallel, utils::IteratorType::reduction};
3921 unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
3923 std::string BatchMatmulOp::getLibraryCallName() {
3929 bool BatchMatmulOp::hasUserDefinedMaps() {
3933 return defaultMaps != explicitMaps;
3937 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
bool isLHS) {
3939 "Expected less than 3 result dim expr.");
3940 bool isValid =
false;
3941 enum Indices { batchPos, mPos, nPos, kPos };
3958 "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3959 RegionBuilderHelper helper(b, block);
3964 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
3966 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
3967 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3969 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
3970 yields.push_back(addVal);
3971 helper.yieldOutputs(yields);
3987 if (!isa<AffineMapAttr>(mapAttr)) {
3989 "expected affine map attribute");
3991 indexingMapsAttr.push_back(mapAttr);
4001 if (indexingMapsAttr.empty()) {
4002 indexingMapsAttr = llvm::map_to_vector(
4003 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4010 BatchMatmulOp::getNumRegionArgs(),
4011 BatchMatmulOp::getRegionBuilder());
4016 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4021 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4023 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4024 p <<
" indexing_maps = [";
4025 llvm::interleaveComma(getIndexingMaps(), p,
4035 if (!hasUserDefinedMaps())
4038 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4045 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4050 void BatchMatmulOp::getEffects(
4053 if (hasPureTensorSemantics())
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user-defined map is a valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.