40 #include "llvm/ADT/DenseMap.h"
41 #include "llvm/ADT/STLExtras.h"
42 #include "llvm/ADT/SetOperations.h"
43 #include "llvm/ADT/SmallSet.h"
44 #include "llvm/ADT/SmallVector.h"
45 #include "llvm/ADT/StringSet.h"
46 #include "llvm/ADT/TypeSwitch.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include "llvm/Support/LogicalResult.h"
49 #include "llvm/Support/MathExtras.h"
50 #include "llvm/Support/raw_ostream.h"
60 auto type = cast<ShapedType>(v.
getType());
61 if (!type.isDynamicDim(dim))
66 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
67 return builder.create<tensor::DimOp>(loc, v, dim);
69 .Case<MemRefType>([&](MemRefType t) ->
Value {
70 return builder.create<memref::DimOp>(loc, v, dim);
81 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
82 return b.
create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
85 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
86 return b.
create<memref::SubViewOp>(loc, source, offsets, sizes,
98 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
100 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
102 llvm_unreachable(
"Expected MemRefType or TensorType");
107 auto shapedType = llvm::cast<ShapedType>(source.
getType());
108 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
130 assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
134 for (
auto containers : {inputTypes, outputTypes}) {
135 for (
auto t : containers) {
147 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
151 regionBuilder(b, *body, attrs);
163 std::optional<TypeRange> resultTensorTypes,
170 if (!resultTensorTypes)
171 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
174 state.addOperands(inputs);
175 state.addOperands(outputs);
176 state.addTypes(derivedResultTypes);
178 state.addAttributes(attributes);
180 "operandSegmentSizes",
182 static_cast<int32_t>(outputs.size())}));
185 Region ®ion = *state.addRegion();
187 state.attributes.getAttrs(), regionBuilder);
191 std::optional<TypeRange> resultTensorTypes,
198 indexingMapsAttrVal = llvm::map_to_vector(
199 MatmulOp::getDefaultIndexingMaps(b.
getContext()),
201 state.addAttribute(
"indexing_maps", b.
getArrayAttr(indexingMapsAttrVal));
203 attributes, regionBuilder);
212 bool addOperandSegmentSizes =
true) {
213 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
242 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
244 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
248 if (addOperandSegmentSizes) {
257 attrs.
append(
"operandSegmentSizes",
259 {static_cast<int32_t>(inputsOperands.size()),
260 static_cast<int32_t>(outputsOperands.size())}));
265 {static_cast<int32_t>(inputsOperands.size()),
266 static_cast<int32_t>(outputsOperands.size())}));
270 std::optional<RegisteredOperationName> info =
273 if (failed(info->verifyInherentAttrs(result.
attributes, [&]() {
274 return parser.emitError(attrsLoc)
275 <<
"'" << result.name.getStringRef() <<
"' op ";
286 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
287 if (!outputs.empty())
288 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
299 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
302 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
303 "region expects {0} args, got {1}",
304 numRegionArgs, inputTypes.size() + outputTypes.size()));
323 unsigned numRegionArgs,
339 result.
addTypes(outputTensorsTypes);
341 std::unique_ptr<Region> region = std::make_unique<Region>();
353 if (resultTypes.empty())
398 class RegionBuilderHelper {
401 : builder(builder), block(block) {}
404 Value buildUnaryFn(UnaryFn unaryFn,
Value arg) {
405 if (!isFloatingPoint(arg))
406 llvm_unreachable(
"unsupported non numeric type");
408 builder.setInsertionPointToEnd(&block);
411 return builder.create<math::ExpOp>(arg.
getLoc(), arg);
413 return builder.create<math::LogOp>(arg.
getLoc(), arg);
415 return builder.create<math::AbsFOp>(arg.
getLoc(), arg);
417 return builder.create<math::CeilOp>(arg.
getLoc(), arg);
419 return builder.create<math::FloorOp>(arg.
getLoc(), arg);
421 return builder.create<arith::NegFOp>(arg.
getLoc(), arg);
422 case UnaryFn::reciprocal: {
424 auto one = builder.create<arith::ConstantOp>(arg.
getLoc(),
425 ::cast<TypedAttr>(oneAttr));
426 return builder.create<arith::DivFOp>(arg.
getLoc(), one, arg);
429 return builder.create<math::RoundOp>(arg.
getLoc(), arg);
431 return builder.create<math::SqrtOp>(arg.
getLoc(), arg);
433 return builder.create<math::RsqrtOp>(arg.
getLoc(), arg);
434 case UnaryFn::square:
435 return builder.create<arith::MulFOp>(arg.
getLoc(), arg, arg);
437 return builder.create<math::TanhOp>(arg.
getLoc(), arg);
439 return builder.create<math::ErfOp>(arg.
getLoc(), arg);
441 llvm_unreachable(
"unsupported unary function");
446 bool allComplex = isComplex(arg0) && isComplex(arg1);
447 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
448 bool allInteger = isInteger(arg0) && isInteger(arg1);
451 if (!allComplex && !allFloatingPoint && !allInteger)
452 llvm_unreachable(
"unsupported non numeric type");
454 builder.setInsertionPointToEnd(&block);
458 return builder.create<complex::AddOp>(arg0.
getLoc(), arg0, arg1);
459 if (allFloatingPoint)
460 return builder.create<arith::AddFOp>(arg0.
getLoc(), arg0, arg1);
462 return builder.create<arith::OrIOp>(arg0.
getLoc(), arg0, arg1);
463 return builder.create<arith::AddIOp>(arg0.
getLoc(), arg0, arg1);
466 return builder.create<complex::SubOp>(arg0.
getLoc(), arg0, arg1);
467 if (allFloatingPoint)
468 return builder.create<arith::SubFOp>(arg0.
getLoc(), arg0, arg1);
470 llvm_unreachable(
"unsupported operation: sub with bools");
471 return builder.create<arith::SubIOp>(arg0.
getLoc(), arg0, arg1);
474 return builder.create<complex::MulOp>(arg0.
getLoc(), arg0, arg1);
475 if (allFloatingPoint)
476 return builder.create<arith::MulFOp>(arg0.
getLoc(), arg0, arg1);
478 return builder.create<arith::AndIOp>(arg0.
getLoc(), arg0, arg1);
479 return builder.create<arith::MulIOp>(arg0.
getLoc(), arg0, arg1);
482 return builder.create<complex::DivOp>(arg0.
getLoc(), arg0, arg1);
483 if (allFloatingPoint)
484 return builder.create<arith::DivFOp>(arg0.
getLoc(), arg0, arg1);
486 llvm_unreachable(
"unsupported operation: div with bools");
487 return builder.create<arith::DivSIOp>(arg0.
getLoc(), arg0, arg1);
488 case BinaryFn::div_unsigned:
489 if (!allInteger || allBool)
490 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
491 return builder.create<arith::DivUIOp>(arg0.
getLoc(), arg0, arg1);
492 case BinaryFn::max_signed:
494 if (allFloatingPoint)
495 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
496 return builder.create<arith::MaxSIOp>(arg0.
getLoc(), arg0, arg1);
497 case BinaryFn::min_signed:
499 if (allFloatingPoint)
500 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
501 return builder.create<arith::MinSIOp>(arg0.
getLoc(), arg0, arg1);
502 case BinaryFn::max_unsigned:
504 if (allFloatingPoint)
505 return builder.create<arith::MaximumFOp>(arg0.
getLoc(), arg0, arg1);
506 return builder.create<arith::MaxUIOp>(arg0.
getLoc(), arg0, arg1);
507 case BinaryFn::min_unsigned:
509 if (allFloatingPoint)
510 return builder.create<arith::MinimumFOp>(arg0.
getLoc(), arg0, arg1);
511 return builder.create<arith::MinUIOp>(arg0.
getLoc(), arg0, arg1);
513 assert(allFloatingPoint);
514 return builder.create<math::PowFOp>(arg0.
getLoc(), arg0, arg1);
516 llvm_unreachable(
"unsupported binary function");
524 bool tailFloatingPoint =
525 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
526 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
528 builder.setInsertionPointToEnd(&block);
530 case TernaryFn::select:
531 if (!headBool && !(tailFloatingPoint || tailInteger))
532 llvm_unreachable(
"unsupported non numeric type");
533 return builder.create<arith::SelectOp>(arg0.
getLoc(), arg0, arg1, arg2);
535 llvm_unreachable(
"unsupported ternary function");
541 case TypeFn::cast_signed:
542 return cast(toType, operand,
false);
543 case TypeFn::cast_unsigned:
544 return cast(toType, operand,
true);
546 llvm_unreachable(
"unsupported type conversion function");
551 builder.setInsertionPointToEnd(&block);
552 Location loc = builder.getUnknownLoc();
553 builder.create<YieldOp>(loc, values);
556 Value constant(
const std::string &value) {
558 builder.setInsertionPointToEnd(&block);
559 Location loc = builder.getUnknownLoc();
561 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
564 Value index(int64_t dim) {
566 builder.setInsertionPointToEnd(&block);
567 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
570 Type getIntegerType(
unsigned width) {
584 builder.setInsertionPointToEnd(&block);
585 auto loc = operand.
getLoc();
589 bool isComplex(
Value value) {
590 return llvm::isa<ComplexType>(value.
getType());
592 bool isFloatingPoint(
Value value) {
593 return llvm::isa<FloatType>(value.
getType());
595 bool isInteger(
Value value) {
596 return llvm::isa<IntegerType>(value.
getType());
613 LogicalResult matchAndRewrite(CopyOp copyOp,
615 if (copyOp.getInputs() != copyOp.getOutputs())
617 if (copyOp.hasPureBufferSemantics())
620 rewriter.
replaceOp(copyOp, copyOp.getInputs());
630 results.
add<EraseSelfCopy>(context);
643 template <
typename TensorReshapeOp>
646 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
648 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
653 TensorReshapeOp newInit;
654 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
656 newInit = rewriter.
create<TensorReshapeOp>(
657 loc, reshapeOp.getResultType(), oldFill.output(),
658 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
659 reshapeOp.getStaticOutputShape());
661 newInit = rewriter.
create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
663 reshapeOp.getReassociation());
676 LogicalResult matchAndRewrite(tensor::PadOp padOp,
678 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
684 Value padValue = padOp.getConstantPaddingValue();
685 if (!padValue || fillOp.value() != padValue)
691 padOp,
"failed to reify tensor.pad op result shape");
693 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
694 padOp.getLoc(), reifiedShape.front(),
695 padOp.getResultType().getElementType());
701 if (replacement.getType() != padOp.getResultType()) {
702 replacement = rewriter.
create<tensor::CastOp>(
703 fillOp.getLoc(), padOp.getResultType(), replacement);
713 struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
716 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
718 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
722 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
727 Value firstDest = insertOp.getDest();
728 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
729 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
734 bool disjoint =
false;
735 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
738 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
739 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
740 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
744 int64_t prevStart = prevOp.getStaticOffset(i);
745 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
746 prevOp.getStaticStride(i);
747 int64_t nextStart = insertOp.getStaticOffset(i);
748 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
749 insertOp.getStaticStride(i);
750 if (prevEnd < nextStart || nextEnd < prevStart) {
758 firstDest = prevOp.getDest();
769 Value padValue = srcPadOp.getConstantPaddingValue();
770 if (!padValue || dstFillOp.value() != padValue)
786 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
788 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
791 RankedTensorType srcPadType = srcPadOp.getSourceType();
793 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
794 if (srcPadType.isDynamicDim(i)) {
796 rewriter.
create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
799 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
804 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
805 newSizes, insertOp.getMixedStrides());
811 struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
815 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
819 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
824 Value extractedScalar = fillOp.getInputs()[0];
827 rewriter.
replaceOp(extractOp, extractedScalar);
835 static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
836 tensor::PackOp packOp) {
837 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
841 if (
auto paddingValue = packOp.getPaddingValue())
845 Value packOpDest = packOp.getDest();
849 return rewriter.
create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
859 LogicalResult matchAndRewrite(tensor::PackOp packOp,
861 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
864 rewriter.
replaceOp(packOp, fillOp.value().result());
873 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
875 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
878 copyOp.getOutputs());
881 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
883 fillOp.getOutputs());
894 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
896 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
898 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
899 transposeOp.getDpsInitOperand(0)->get());
911 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
913 auto concatOperands = concatOp.getInputs();
914 if (concatOperands.empty()) {
918 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
927 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
929 auto isDefinedByCompatibleFillOp = [&](
Value v) ->
bool {
930 auto fillOp = v.getDefiningOp<linalg::FillOp>();
937 if (fillVal != firstFillVal)
940 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
943 if (!llvm::all_of(concatOperands.drop_front(),
944 isDefinedByCompatibleFillOp)) {
946 concatOp,
"not all operands are defined by a compatible fill op");
949 Value outsConcat = rewriter.
create<tensor::ConcatOp>(
950 concatOp.getLoc(), concatOp.getDim(), allOuts);
952 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
961 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
962 FoldFillWithPack, FoldFillWithPad,
963 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
964 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
965 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
978 for (
ValueRange container : {inputs, outputs}) {
979 for (
Value v : container) {
980 Type t = v.getType();
981 blockArgTypes.push_back(
983 blockArgLocs.push_back(v.getLoc());
989 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
993 void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
995 for (
Value v : getRegionInputArgs())
997 for (
Value v : getRegionOutputArgs())
1001 void GenericOp::build(
1004 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1007 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1008 iteratorTypes, doc, libraryCall);
1012 inputs, outputs, bodyBuild);
1015 void GenericOp::build(
1019 StringRef libraryCall,
1022 build(builder, result, resultTensorTypes, inputs, outputs,
1027 return IteratorTypeAttr::get(builder.getContext(), iter);
1030 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1031 bodyBuild, attributes);
1034 void GenericOp::build(
1038 StringRef libraryCall,
1041 build(builder, result,
TypeRange{}, inputs, outputs, indexingMaps,
1042 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1045 void GenericOp::build(
1051 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1053 "", bodyBuild, attributes);
1056 void GenericOp::build(
1062 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1065 "", bodyBuild, attributes);
1072 auto genericAttrNames = linalgTraitAttrNames();
1075 genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1077 for (
auto attr : (*this)->getAttrs()) {
1078 if (attr.getName() == getIteratorTypesAttrName()) {
1079 auto iteratorTypes =
1080 llvm::cast<ArrayAttr>(attr.getValue())
1081 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1087 llvm::to_vector(llvm::map_range(
1088 iteratorTypes, [&](utils::IteratorType t) ->
Attribute {
1092 genericAttrs.emplace_back(
1093 getIteratorTypesAttrName(),
1095 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1096 genericAttrs.push_back(attr);
1099 if (!genericAttrs.empty()) {
1101 p << genericDictAttr;
1107 genericAttrNames.push_back(
"operandSegmentSizes");
1108 genericAttrNamesSet.insert(genericAttrNames.back());
1110 bool hasExtraAttrs =
false;
1112 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1115 if (hasExtraAttrs) {
1122 if (!getRegion().empty()) {
1132 DictionaryAttr dictAttr;
1141 dictAttr.getValue().end());
1147 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1149 if (!iteratorTypes) {
1150 return parser.
emitError(attributeLocation)
1151 <<
"expected " << getIteratorTypesAttrName(result.
name)
1152 <<
" array attribute";
1157 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1158 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1159 if (!maybeIteratorType.has_value())
1161 <<
"unexpected iterator_type (" << s <<
")";
1163 iteratorTypeAttrs.push_back(
1180 std::unique_ptr<Region> region = std::make_unique<Region>();
1192 result.
addTypes(outputTensorsTypes);
1200 LinalgOp linalgOp) {
1201 for (
auto [index, operand] :
llvm::enumerate(linalgOp.getDpsInputs())) {
1202 if (!llvm::isa<MemRefType>(operand.
getType()))
1204 effects.emplace_back(
1209 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1210 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1212 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1223 void GenericOp::getEffects(
1233 if (!linalgOp.hasPureTensorSemantics())
1252 template <
typename OpTy>
1256 LogicalResult matchAndRewrite(OpTy linalgOp,
1259 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1264 Block &body = linalgOp->getRegion(0).
front();
1265 if (!llvm::hasSingleElement(body))
1267 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1272 if (linalgOp.hasPureBufferSemantics()) {
1273 if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1274 linalgOp.getDpsInputOperand(0)->get() ==
1275 linalgOp.getDpsInitOperand(0)->get()) {
1283 if (!linalgOp.hasPureTensorSemantics())
1290 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1291 if (!yieldArg || yieldArg.getOwner() != &body)
1293 unsigned argumentNumber = yieldArg.getArgNumber();
1294 Value returnedArg = linalgOp->getOperand(argumentNumber);
1295 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1299 if (returnType != resultType) {
1304 returnedArg = rewriter.
create<sparse_tensor::ConvertOp>(
1305 linalgOp.getLoc(), resultType, returnedArg);
1307 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1310 returnedArg = rewriter.
create<tensor::CastOp>(
1311 linalgOp.getLoc(), resultType, returnedArg);
1314 returnedArgs.push_back(returnedArg);
1317 if (returnedArgs.size() != linalgOp->getNumResults())
1319 rewriter.
replaceOp(linalgOp, returnedArgs);
1328 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1350 for (
Type outputType : outputTypes) {
1351 if (llvm::isa<RankedTensorType>(outputType))
1356 if (parseAttrsFn && failed(parseAttrsFn(parser, result.
attributes)))
1365 void MapOp::getAsmBlockArgumentNames(
Region ®ion,
1367 for (
Value v : getRegionInputArgs())
1372 if (!getResults().empty())
1373 setNameFn(getResults().front(),
"mapped");
1380 build(builder, result,
TypeRange{}, inputs, init);
1385 if (llvm::isa<RankedTensorType>(initType))
1390 inputs, {}, bodyBuild);
1397 bool initFirst =
false) {
1403 for (
auto &operand : operands) {
1405 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1412 payloadOpOperands.push_back(block.
getArguments().back());
1413 for (
const auto &arg : block.
getArguments().drop_back())
1414 payloadOpOperands.push_back(arg);
1423 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1430 std::optional<OperationName> payloadOpName;
1434 if (failed(operationName))
1438 payloadOpName = operationName.value();
1446 if (payloadOpName.has_value()) {
1484 for (
const auto &[operand, bbArg] :
1486 if (bbArg != operand)
1490 for (
const auto &[operand, bbArg] :
1492 if (bbArg != operand)
1502 for (
const auto &attr : payloadOp->
getAttrs()) {
1503 if (
auto fastAttr = dyn_cast<arith::FastMathFlagsAttr>(attr.getValue())) {
1504 if (fastAttr.getValue() == arith::FastMathFlags::none) {
1505 elidedAttrs.push_back(attr.getName());
1508 if (
auto denormAttr = dyn_cast<arith::DenormalModeAttr>(attr.getValue())) {
1509 if (denormAttr.getValue() == arith::DenormalMode::ieee) {
1510 elidedAttrs.push_back(attr.getName());
1519 Block *mapper = getBody();
1534 [&](
auto arg) { p.printRegionArgument(arg); });
1543 auto *bodyBlock = getBody();
1544 auto blockArgs = bodyBlock->getArguments();
1547 if (getInputs().size() != blockArgs.size())
1548 return emitOpError() <<
"expects number of operands to match the arity of "
1550 << getInputs().size() <<
" and " << blockArgs.size();
1553 for (
const auto &[bbArgType, inputArg] :
1554 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1555 auto inputElemType =
1556 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1557 if (bbArgType != inputElemType) {
1558 return emitOpError() <<
"expected element type of input " << inputElemType
1559 <<
" to match bbArg type " << bbArgType;
1564 auto outputShape = getInit().getType().getShape();
1566 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1567 if (inputElemShape != outputShape) {
1568 return emitOpError() <<
"expected shape of input (" << inputElemShape
1569 <<
") to match shape of output (" << outputShape
1578 int64_t rank = getInit().getType().getRank();
1582 ArrayAttr MapOp::getIndexingMaps() {
1584 int64_t rank = getInit().getType().getRank();
1585 int64_t numIndexingMaps = getOperands().size();
1590 void MapOp::getEffects(
1604 void ReduceOp::getAsmBlockArgumentNames(
Region ®ion,
1606 for (
Value v : getRegionInputArgs())
1608 for (
Value v : getRegionOutputArgs())
1609 setNameFn(v,
"init");
1612 void ReduceOp::getAsmResultNames(
1614 if (!getResults().empty())
1615 setNameFn(getResults().front(),
"reduced");
1618 void ReduceOp::build(
1623 build(builder, result,
TypeRange{}, inputs, inits, dimensions);
1627 for (
Value init : inits) {
1629 if (llvm::isa<RankedTensorType>(initType))
1635 inputs, inits, bodyBuild);
1640 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1642 utils::IteratorType::parallel);
1643 for (int64_t reductionDim : getDimensions())
1644 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1645 return iteratorTypes;
1648 ArrayAttr ReduceOp::getIndexingMaps() {
1650 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1657 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1658 affineMaps.push_back(resultMap);
1662 void ReduceOp::getEffects(
1674 StringRef attributeName) {
1683 std::optional<OperationName> payloadOpName;
1687 if (failed(operationName))
1691 payloadOpName = operationName.value();
1702 if (payloadOpName.has_value()) {
1722 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1726 Block *mapper = getBody();
1741 [&](
auto arg) { p.printRegionArgument(arg); });
1752 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1753 if (llvm::cast<ShapedType>(getInputs()[i].
getType()).getShape() !=
1754 llvm::cast<ShapedType>(getInputs()[0].
getType()).getShape()) {
1755 return emitOpError() <<
"expects all inputs to have the same shapes. "
1756 "Shape at input-index "
1758 <<
" is not equal to the shape at input-index 0.";
1761 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1762 if (llvm::cast<ShapedType>(getInits()[i].
getType()).getShape() !=
1763 llvm::cast<ShapedType>(getInits()[0].
getType()).getShape()) {
1764 return emitOpError() <<
"expects all outputs to have the same shapes. "
1765 "Shape at output-index "
1767 <<
" is not equal to the shape at output-index 0.";
1770 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1771 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1774 for (int64_t dimension : dimensionsRef) {
1775 if (dimension < 0 || dimension >= inputType.getRank()) {
1776 return emitOpError()
1777 <<
"dimensions for reduction should be in the range [0, "
1778 << inputType.getRank() - 1 <<
"].";
1780 dimensionsToReduce.insert(dimension);
1783 auto inputDims = inputType.getShape();
1784 auto initDims = initType.getShape();
1789 if (!dimensionsToReduce.count(en.index()))
1790 reducedInputDims.push_back(en.value());
1793 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1794 return emitOpError() <<
"number of dimensions after reduction "
1795 << reducedInputDims.size()
1796 <<
" doesn't match the init rank "
1797 << initType.getRank();
1800 if (reducedInputDims != initDims)
1801 return emitOpError() <<
"init dimensions [" << initDims
1802 <<
"] doesn't match input dimensions after reduction ["
1803 << reducedInputDims <<
"]";
1805 Block *block = getBody();
1807 return emitOpError()
1808 <<
"mismatching number of operands and block arguments";
1811 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1812 Type inputElementType =
1813 llvm::cast<ShapedType>(input.getType()).getElementType();
1814 if (inputElementType != bbArg.getType())
1815 return emitOpError()
1816 <<
"input element type " << inputElementType
1817 <<
" does not match corresponding block argument type "
1822 for (
auto [output, bbArg] : llvm::zip(
1823 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1824 auto outputElementType =
1825 llvm::cast<ShapedType>(output.getType()).getElementType();
1826 if (outputElementType != bbArg.getType())
1827 return emitOpError()
1828 <<
"output element type " << outputElementType
1829 <<
" does not match corresponding block argument type "
1845 b.
create<linalg::YieldOp>(loc, args[0]);
1860 if (llvm::isa<RankedTensorType>(initType))
1889 void TransposeOp::getAsmResultNames(
1891 if (!getResults().empty())
1892 setNameFn(getResults().front(),
"transposed");
1905 return emitOpError(
"permutation is not valid");
1907 auto inputType = getInput().getType();
1908 auto initType = getInit().getType();
1910 int64_t rank = inputType.getRank();
1912 if (rank != initType.getRank())
1913 return emitOpError() <<
"input rank " << rank
1914 <<
" does not match init rank " << initType.getRank();
1916 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
1917 return emitOpError() <<
"size of permutation " << permutationRef.size()
1918 <<
" does not match the argument rank " << rank;
1920 auto inputDims = inputType.getShape();
1921 auto initDims = initType.getShape();
1923 for (int64_t i = 0; i < rank; ++i) {
1924 int64_t inputDim = inputDims[permutationRef[i]];
1925 int64_t initDim = initDims[i];
1927 if (inputDim != initDim) {
1928 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
1929 <<
" doesn't match dim(input, permutation[" << i
1930 <<
"]) = " << inputDim;
1938 int64_t rank = getInit().getType().getRank();
1942 ArrayAttr TransposeOp::getIndexingMaps() {
1944 int64_t rank = getInit().getType().getRank();
1947 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
1951 void TransposeOp::getEffects(
1961 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1964 if (!isa<TensorType>(getInput().
getType()))
1968 if (getPermutation().size() == 0) {
1969 result.push_back(getInput());
1974 result.push_back(getInput());
1987 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1988 if (!defTransposeOp)
1993 foldedPerms.reserve(perms.size());
1994 for (int64_t perm : perms)
1995 foldedPerms.push_back(defPerms[perm]);
1998 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2012 Value input = transposeOp.getInput();
2013 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2024 unsigned dimensionSize = dimensions.size();
2025 for (
unsigned i = 0; i < dimensionSize; ++i)
2026 resultDimensions.push_back(invertPerm[dimensions[i]]);
2029 Value broadcastInput = broadcastOp.getInput();
2030 Location loc = transposeOp.getLoc();
2033 auto broadcastInputTy =
2034 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2035 unsigned inputRank = broadcastInputTy.getRank();
2036 for (
unsigned i = 0; i < inputRank; ++i) {
2037 if (broadcastInputTy.isDynamicDim(i)) {
2038 dims.push_back(rewriter.
create<tensor::DimOp>(loc, broadcastInput, i)
2042 broadcastInputTy.getDimSize(i)));
2047 Value transposeInit = rewriter.
create<tensor::EmptyOp>(
2048 transposeOp.getLoc(), transposeResultShapes,
2049 broadcastInputTy.getElementType());
2052 Value transposeResult =
2054 .
create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2058 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2083 if (llvm::isa<RankedTensorType>(initType))
2112 void BroadcastOp::getAsmResultNames(
2114 if (!getResults().empty())
2115 setNameFn(getResults().front(),
"broadcasted");
2127 auto inputType = getInput().getType();
2128 auto initType = getInit().getType();
2130 int64_t inputRank = inputType.getRank();
2131 int64_t initRank = initType.getRank();
2133 auto inputShape = inputType.getShape();
2134 auto initShape = initType.getShape();
2136 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2137 return emitOpError() <<
"input rank plus added dimensions does not "
2138 "match init rank. input rank: "
2140 <<
", dimensions size: " << dimensionsRef.size()
2141 <<
", init rank: " << initRank;
2144 if (dim < 0 || dim >= initRank)
2145 return emitOpError() <<
"dimension " << idx
2146 <<
" is out of range. expected range: [0, "
2147 << initRank - 1 <<
"], got: " << dim;
2152 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2153 if (!llvm::is_contained(dimensionsRef, dim))
2154 dimMap.push_back(dim);
2157 for (
const auto &[inputDimIdx, initDimIdx] :
llvm::enumerate(dimMap)) {
2160 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2161 return emitOpError() <<
"input dim " << inputDimIdx
2162 <<
" should match init dim " << initDimIdx
2163 <<
". input: " << inputShape[inputDimIdx]
2164 <<
", init: " << initShape[initDimIdx];
2171 int64_t rank = getInit().getType().getRank();
2175 ArrayAttr BroadcastOp::getIndexingMaps() {
2177 int64_t rank = getInit().getType().getRank();
2183 void BroadcastOp::getEffects(
2195 results.
add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2203 if (getNumOperands() > 0)
2204 p <<
' ' << getOperands();
2206 if (getNumOperands() > 0)
2207 p <<
" : " << getOperandTypes();
2222 static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2223 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2224 return op.emitOpError(
"expected number of yield values (")
2225 << op.getNumOperands()
2226 <<
") to match the number of inits / outs operands of the enclosing "
2227 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2229 for (
OpOperand &opOperand : op->getOpOperands()) {
2231 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2233 if (isa<MemRefType, RankedTensorType>(elementType))
2235 if (opOperand.get().getType() != elementType)
2236 return op.emitOpError(
"type of yield operand ")
2237 << (opOperand.getOperandNumber() + 1) <<
" ("
2238 << opOperand.get().getType() <<
") doesn't match "
2239 <<
"the element type of the enclosing linalg.generic op ("
2240 << elementType <<
")";
2246 auto *parentOp = (*this)->getParentOp();
2247 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2248 return emitOpError(
"expected single non-empty parent region");
2250 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2253 return emitOpError(
"expected parent op with LinalgOp interface");
2261 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2263 return emitOpError(
"expected parent op with LinalgOp interface");
2264 if (linalgOp.getNumLoops() <= getDim())
2265 return emitOpError(
"expected dim (")
2266 << getDim() <<
") to be lower than the number of loops ("
2267 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2273 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2275 #define GET_OP_CLASSES
2276 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2278 #define GET_OP_CLASSES
2279 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2296 for (
unsigned i = 0; i < num; ++i)
2303 auto rangeA = llvm::make_range(a.begin(), a.end());
2304 auto rangeB = llvm::make_range(b.begin(), b.end());
2305 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2306 return llvm::to_vector<4>(concatRanges);
2310 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2312 for (
auto size : memref.getShape())
2319 if (
auto as = memref.getMemorySpace()) {
2320 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2321 ss <<
"as" << attr.getInt();
2327 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2330 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2343 assert(isa<LinalgOp>(op));
2345 std::string fun =
"";
2347 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2348 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2349 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2350 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2354 std::replace(name.begin(), name.end(),
'.',
'_');
2355 llvm::raw_string_ostream ss(name);
2359 return std::string();
2374 LogicalResult matchAndRewrite(LinalgOp op,
2376 for (
OpOperand &opOperand : op->getOpOperands()) {
2380 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2383 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2394 struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2397 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2402 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2409 if (castOp->getBlock() != linalgOp->getBlock())
2416 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2419 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2425 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2427 rewriter.
create<tensor::CastOp>(loc, resultType, outOperand->
get());
2430 linalgOp.getDpsInits().end());
2431 outputOperands[resultNumber] = newOperand;
2432 newOperands.append(outputOperands.begin(), outputOperands.end());
2435 linalgOp->result_type_end());
2436 resultTypes[resultNumber] = resultType;
2437 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2444 results[resultNumber] = castBack;
2456 if (linalgOp.isScalar(&opOperand))
2458 Value src = opOperand.get();
2459 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2460 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2468 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2469 Value castSource = castOp.getSource();
2470 auto castSourceType =
2471 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2472 if (castSourceType && castSourceType.hasStaticShape())
2473 sourceShape = castSourceType.getShape();
2479 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2480 if (sourceType.isDynamicDim(i))
2482 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2483 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2493 static void createNewOperandWithStaticSizes(
2497 bool &changeNeeded) {
2499 newOperands.push_back(src);
2500 if (linalgOp.isScalar(opOperand))
2502 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2503 Type resultType = sourceType;
2504 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2505 resultTypes.push_back(resultType);
2509 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2513 bool newOperandNeeded =
false;
2514 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2515 int64_t dimShape = sourceShape[i];
2517 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2518 newShape.push_back(dimShape);
2524 newShape.push_back(affineExprToSize[dimExpr]);
2525 newOperandNeeded =
true;
2528 if (newOperandNeeded) {
2529 changeNeeded =
true;
2532 Value newOperand = rewriter.
create<tensor::CastOp>(loc, resultType, src);
2534 newOperands[index] = newOperand;
2536 if (linalgOp.isDpsInit(opOperand))
2537 resultTypes.push_back(resultType);
2546 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2548 if (!linalgOp.hasPureTensorSemantics())
2552 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](
AffineMap map) {
2553 return !map.isProjectedPermutation();
2563 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2570 bool changeNeeded =
false;
2571 newOperands.reserve(linalgOp->getNumOperands());
2572 resultTypes.reserve(linalgOp.getNumDpsInits());
2575 for (
OpOperand &opOperand : linalgOp->getOpOperands()) {
2576 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2577 affineExprToSize, linalgOp, newOperands,
2578 resultTypes, changeNeeded);
2587 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2590 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2591 Value newResult = std::get<1>(it);
2592 Value oldResult = std::get<0>(it);
2595 replacements.push_back(
2596 (newType != oldType)
2597 ? rewriter.
create<tensor::CastOp>(loc, oldType, newResult)
2600 rewriter.
replaceOp(linalgOp, replacements);
2615 ShapedType inputType = getInputOperandType();
2616 ShapedType outputType = getOutputOperandType();
2621 return emitOpError(
"incompatible output shape");
2623 int64_t inputRank = getInputOperandRank();
2624 int64_t dimension = getDimension();
2625 if ((dimension < 0) || (dimension >= inputRank))
2626 return emitOpError(
"incorrect dimension specified");
2632 int64_t operandRank = getInputOperandRank();
2635 Value zero = builder.
create<arith::ConstantIndexOp>(loc, 0);
2636 Value one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2637 Value source = getInput();
2638 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2639 loopBounds[dim].offset = zero;
2640 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2641 loopBounds[dim].stride = one;
2648 utils::IteratorType::parallel);
2649 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2650 return iteratorTypes;
2653 FailureOr<TilingResult>
2654 SoftmaxOp::getTiledImplementation(
OpBuilder &builder,
2657 int64_t rank = getInputOperandRank();
2662 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2664 return emitOpError(
"failed to compute input slice");
2666 tiledOperands.emplace_back(inputSlice->
getResult(0));
2668 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2670 return emitOpError(
"failed to compute output slice");
2672 tiledOperands.emplace_back(outputSlice->
getResult(0));
2675 if (hasPureTensorSemantics())
2676 resultTypes.push_back(tiledOperands[1].
getType());
2678 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2686 LogicalResult SoftmaxOp::getResultTilePosition(
2690 if (resultNumber == 0) {
2691 resultOffsets.assign(offsets.begin(), offsets.end());
2692 resultSizes.assign(sizes.begin(), sizes.end());
2707 Location loc = getOperation()->getLoc();
2709 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2710 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2711 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2712 if (!outputShapedType.isDynamicDim(dim)) {
2714 shapes.push_back(b.
getIndexAttr(inputShapedType.getDimSize(dim)));
2721 reifiedReturnShapes.emplace_back(std::move(shapes));
2725 void SoftmaxOp::getEffects(
2729 if (!llvm::isa<MemRefType>(operand.
getType()))
2732 &getOperation()->getOpOperand(index), 0,
2737 for (
OpOperand &operand : getDpsInitsMutable()) {
2738 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2771 int64_t dim,
bool allParallel =
false) {
2773 utils::IteratorType::parallel);
2775 iteratorTypes[dim] = utils::IteratorType::reduction;
2779 for (
int i = 0; i < inputRank; i++) {
2786 return std::make_tuple(iteratorTypes, indexingMaps);
2791 template <
typename T>
2794 auto inputType = cast<ShapedType>(input.
getType());
2796 int64_t inputRank = inputShape.size();
2797 auto [iteratorTypes, indexingMaps] =
2799 assert(indexingMaps.size() == 2 &&
2800 "We should have two maps: 1 for the input, 1 for the output");
2801 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2803 auto genericOp = builder.
create<linalg::GenericOp>(
2804 loc, output.
getType(), input, output, indexingMaps, iteratorTypes,
2806 Value result = b.create<T>(loc, args[0], args[1]);
2807 b.create<linalg::YieldOp>(loc, result);
2817 auto inputType = cast<ShapedType>(input.
getType());
2819 int64_t inputRank = inputShape.size();
2821 builder, inputRank, dim,
true);
2822 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
2823 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2825 indexingMaps.push_back(indexingMaps[0]);
2826 auto genericOp = builder.
create<linalg::GenericOp>(
2829 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2830 Value result = b.create<math::ExpOp>(loc, diff);
2831 b.create<linalg::YieldOp>(loc, result);
2842 Value denominator,
Value output, int64_t dim) {
2843 auto inputType = cast<ShapedType>(numerator.
getType());
2845 int64_t inputRank = inputShape.size();
2847 builder, inputRank, dim,
true);
2848 assert(indexingMaps.size() == 2 &&
2849 "We should have one map for each input (2)");
2850 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
2852 indexingMaps.push_back(indexingMaps[0]);
2853 auto genericOp = builder.
create<linalg::GenericOp>(
2855 indexingMaps, iteratorTypes,
2857 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2858 b.create<linalg::YieldOp>(loc, result);
2882 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(
OpBuilder &b) {
2886 Value input = getInput();
2887 ShapedType inputType = getInputOperandType();
2888 Type elementType = inputType.getElementType();
2889 int64_t reductionDim = getDimension();
2891 Value output = getOutput();
2892 dims.erase(dims.begin() + reductionDim);
2894 Value outputReduce = b.
create<tensor::EmptyOp>(loc, dims, elementType);
2896 elementType, b, loc,
2898 Value neutralForMaxFInit =
2899 b.
create<linalg::FillOp>(loc,
Value{neutralForMaxF}, outputReduce)
2902 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2911 b.
create<linalg::FillOp>(loc,
Value{zero}, outputReduce).result();
2913 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2917 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2926 auto filterType = cast<ShapedType>(getFilter().
getType());
2928 int64_t filterH = filterShape[getFilterHDim()];
2929 int64_t filterW = filterShape[getFilterWDim()];
2933 if (filterH != r && filterH != 1)
2934 return emitOpError(
"expect filter height either equals to r or 1");
2935 if (filterW != r && filterW != 1)
2936 return emitOpError(
"expect filter width either equals to r or 1");
2937 if (filterH == 1 && filterW == 1)
2938 return emitOpError(
"expect either filter height or width equals to r");
2941 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2942 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2943 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2944 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2946 auto outputType = cast<ShapedType>(getOutput().
getType());
2949 return emitOpError(
"the output shape is not expected");
2955 WinogradFilterTransformOp::getIterationDomain(
OpBuilder &builder) {
2959 Value filter = getFilter();
2960 int64_t filterRank = getFilterOperandRank();
2962 for (
unsigned dim = 0; dim < filterRank; ++dim) {
2963 loopBounds[dim].offset = zeroAttr;
2964 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
2965 loopBounds[dim].stride = oneAttr;
2971 WinogradFilterTransformOp::getLoopIteratorTypes() {
2972 int64_t filterRank = getFilterOperandRank();
2974 utils::IteratorType::parallel);
2975 return iteratorTypes;
2978 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2983 ShapedType filterType = getFilterOperandType();
2985 int64_t filterH = filterShape[getFilterHDim()];
2986 int64_t filterW = filterShape[getFilterWDim()];
2989 int64_t alpha = m + r - 1;
2990 int64_t alphaH = filterH != 1 ? alpha : 1;
2991 int64_t alphaW = filterW != 1 ? alpha : 1;
2995 resultOffsets.append(
2996 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2998 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3009 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3014 ShapedType filterType = getFilterOperandType();
3016 int64_t filterH = filterShape[getFilterHDim()];
3017 int64_t filterW = filterShape[getFilterWDim()];
3023 sliceOffsets.append(
3024 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3025 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3026 sizes[getFilterCDim()]});
3027 int64_t filterRank = getFilterOperandRank();
3030 auto filterSlice = builder.
create<tensor::ExtractSliceOp>(
3031 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3032 tiledOperands.emplace_back(filterSlice);
3035 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3039 int64_t outputRank = getOutputOperandRank();
3041 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3042 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3043 tiledOperands.emplace_back(outputSlice);
3046 resultTypes.push_back(tiledOperands[1].
getType());
3048 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3061 auto inputType = cast<ShapedType>(getInput().
getType());
3063 int64_t inputH = inputShape[getInputHDim()];
3064 int64_t inputW = inputShape[getInputWDim()];
3067 int64_t tileSize = m + r - 1;
3068 bool leftTransform = inputH != 1;
3069 bool rightTransform = inputW != 1;
3072 if (ShapedType::isDynamic(inputH)) {
3073 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3074 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3076 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3077 expectedOutputShape[getOutputTileHDim()] =
3078 leftTransform ? (inputH - (r - 1)) / m : 1;
3080 if (ShapedType::isDynamic(inputW)) {
3081 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3082 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3084 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3085 expectedOutputShape[getOutputTileWDim()] =
3086 rightTransform ? (inputW - (r - 1)) / m : 1;
3088 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3089 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3091 auto outputType = cast<ShapedType>(getOutput().
getType());
3094 return emitOpError(
"the output shape is not expected");
3100 WinogradInputTransformOp::getIterationDomain(
OpBuilder &builder) {
3104 Value output = getOutput();
3105 int64_t outputRank = getOutputOperandRank();
3107 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3108 loopBounds[dim].offset = zeroAttr;
3110 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3111 loopBounds[dim].stride = oneAttr;
3117 WinogradInputTransformOp::getLoopIteratorTypes() {
3118 int64_t outputRank = getOutputOperandRank();
3120 utils::IteratorType::parallel);
3121 return iteratorTypes;
3124 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3129 ShapedType inputType = getInputOperandType();
3131 int64_t inputH = inputShape[getInputHDim()];
3132 int64_t inputW = inputShape[getInputWDim()];
3135 int64_t alpha = m + r - 1;
3136 int64_t alphaH = inputH != 1 ? alpha : 1;
3137 int64_t alphaW = inputW != 1 ? alpha : 1;
3141 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3142 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3143 offsets[getOutputCDim()]});
3144 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3145 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3146 sizes[getOutputCDim()]});
3157 FailureOr<TilingResult>
3158 WinogradInputTransformOp::getTiledImplementation(
OpBuilder &builder,
3163 ShapedType inputType = getInputOperandType();
3165 int64_t inputH = inputShape[getInputHDim()];
3166 int64_t inputW = inputShape[getInputWDim()];
3172 auto offsetAffineMap =
3175 builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3177 builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3181 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3183 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3192 sliceOffsets.append(
3193 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3199 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3200 int64_t inputRank = getInputOperandRank();
3202 auto inputSlice = builder.
create<tensor::ExtractSliceOp>(
3203 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3204 tiledOperands.emplace_back(inputSlice);
3207 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3211 int64_t outputRank = getOutputOperandRank();
3213 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3214 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3215 tiledOperands.emplace_back(outputSlice);
3218 resultTypes.push_back(tiledOperands[1].
getType());
3220 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3233 auto valueType = cast<ShapedType>(getValue().
getType());
3235 int64_t valueH = valueShape[getValueAlphaHDim()];
3236 int64_t valueW = valueShape[getValueAlphaWDim()];
3237 int64_t valueTileH = valueShape[getValueTileHDim()];
3238 int64_t valueTileW = valueShape[getValueTileWDim()];
3241 bool leftTransform = valueH != 1;
3242 bool rightTransform = valueW != 1;
3244 int64_t outputRank = getOutputOperandRank();
3246 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3247 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3249 if (valueH != (leftTransform ? m + r - 1 : 1))
3250 return emitOpError(
"expect input height equals to input tile size");
3251 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3253 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3254 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3256 if (valueW != (rightTransform ? m + r - 1 : 1))
3257 return emitOpError(
"expect input width equals to input tile size");
3258 expectedOutputShape[getOutputWDim()] =
3259 (rightTransform ? m : 1) * valueTileW;
3261 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3262 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3264 auto outputType = cast<ShapedType>(getOutput().
getType());
3267 return emitOpError(
"the output shape is not expected");
3273 WinogradOutputTransformOp::getIterationDomain(
OpBuilder &builder) {
3277 Value value = getValue();
3278 int64_t valueRank = getValueOperandRank();
3280 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3281 loopBounds[dim].offset = zeroAttr;
3283 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3284 loopBounds[dim].stride = oneAttr;
3290 WinogradOutputTransformOp::getLoopIteratorTypes() {
3291 int64_t valueRank = getValueOperandRank();
3293 utils::IteratorType::parallel);
3294 return iteratorTypes;
3297 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3309 builder, loc, affineMap, offsets[getValueTileHDim()]);
3311 builder, loc, affineMap, offsets[getValueTileWDim()]);
3313 builder, loc, affineMap, sizes[getValueTileHDim()]);
3315 builder, loc, affineMap, sizes[getValueTileWDim()]);
3317 ShapedType valueType = getValueOperandType();
3319 int64_t valueH = valueShape[0];
3320 int64_t valueW = valueShape[1];
3332 resultOffsets.append(
3333 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3335 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3345 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3354 ShapedType valueType = getValueOperandType();
3356 int64_t alphaH = valueShape[getValueAlphaHDim()];
3357 int64_t alphaW = valueShape[getValueAlphaWDim()];
3361 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3362 offsets[getValueTileWDim()], offsets[getValueNDim()],
3363 offsets[getValueFDim()]});
3364 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3365 sizes[getValueTileWDim()], sizes[getValueNDim()],
3366 sizes[getValueFDim()]});
3367 int64_t valueRank = getValueOperandRank();
3369 auto valueSlice = builder.
create<tensor::ExtractSliceOp>(
3370 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3371 tiledOperands.emplace_back(valueSlice);
3374 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3378 int64_t outputRank = getOutputOperandRank();
3380 auto outputSlice = builder.
create<tensor::ExtractSliceOp>(
3381 loc, getOutput(), resultOffsets, resultSizes, strides);
3382 tiledOperands.emplace_back(outputSlice);
3385 resultTypes.push_back(tiledOperands[1].
getType());
3387 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3399 void LinalgDialect::getCanonicalizationPatterns(
3401 results.
add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3408 return arith::ConstantOp::materialize(builder, value, type, loc);
3414 auto explicitRange = explictMap.
getResults();
3418 llvm::set_union(explicitSet, defaultSet);
3419 return explicitSet == defaultSet;
3435 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3437 auto opIndexingMap = opIndexingMaps[opIndex];
3438 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3441 return matmulOp->emitOpError()
3442 <<
"Unexpected dim expression in map result.";
3446 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3447 return matmulOp->emitOpError()
3448 <<
"Invalid broadcast requested, should be (d2).";
3470 return indexingMaps;
3475 utils::IteratorType::parallel,
3476 utils::IteratorType::reduction};
3479 unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3481 std::string MatmulOp::getLibraryCallName() {
3485 bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3489 bool MatmulOp::hasUserDefinedMaps() {
3493 return defaultMaps != explicitMaps;
3501 "MatmulOp regionBuilder expects 3 (>=0) args");
3502 RegionBuilderHelper helper(b, block);
3505 TypeFn castVal = TypeFn::cast_signed;
3506 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3507 return attr.
getName() ==
"cast";
3509 if (castIter != attrs.end()) {
3510 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3518 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3520 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), value3);
3521 yields.push_back(value4);
3522 helper.yieldOutputs(yields);
3526 bool MatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap) {
3527 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3546 if (!isa<AffineMapAttr>(mapAttr)) {
3548 "expected affine map attribute");
3550 indexingMapsAttr.push_back(mapAttr);
3560 if (indexingMapsAttr.empty()) {
3561 indexingMapsAttr = llvm::map_to_vector(
3562 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3569 MatmulOp::getRegionBuilder());
3573 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3578 MatmulOp::getDefaultIndexingMaps(
getContext()),
3580 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3581 p <<
" indexing_maps = [";
3582 llvm::interleaveComma(getIndexingMaps(), p,
3591 if (!hasUserDefinedMaps())
3594 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3604 void MatmulOp::getEffects(
3607 if (hasPureTensorSemantics())
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Returns true if the explictMap is broadcasted with respect to the defaultMap.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap)
Returns true if the result AffineExpr of the explicitMap is same as defaultMap.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.