41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
62 auto type = cast<ShapedType>(v.
getType());
63 if (!type.isDynamicDim(dim))
68 .Case([&](RankedTensorType t) ->
Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
71 .Case([&](MemRefType t) ->
Value {
72 return memref::DimOp::create(builder, loc, v, dim);
83 .Case([&](RankedTensorType t) ->
Operation * {
84 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
87 .Case([&](MemRefType type) ->
Operation * {
88 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
94static std::optional<TypedAttr>
98 if (!splatAttr || !splatAttr.
isSplat())
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable(
"Expected MemRefType or TensorType");
119 auto shapedType = llvm::cast<ShapedType>(source.
getType());
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
122 return b.getIndexAttr(shapedType.getDimSize(dim));
145 for (
auto containers : {inputTypes, outputTypes}) {
146 for (
auto t : containers) {
158 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
174 std::optional<TypeRange> resultTensorTypes,
181 if (!resultTensorTypes)
182 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
203 std::optional<TypeRange> resultTensorTypes,
210 return attr.
getName() ==
"indexing_maps";
213 indexingMapsAttrVal = llvm::map_to_vector(
216 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
219 attributes, regionBuilder);
223 std::optional<TypeRange> resultTensorTypes,
230 return attr.
getName() ==
"indexing_maps";
233 indexingMapsAttrVal = llvm::map_to_vector(
236 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
239 attributes, regionBuilder);
243 std::optional<TypeRange> resultTensorTypes,
250 indexingMapsAttrVal =
252 return AffineMapAttr::get(map);
254 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
256 attributes, regionBuilder);
265 bool addOperandSegmentSizes =
true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
295 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
297 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
301 if (addOperandSegmentSizes) {
308 if (
result.propertiesAttr) {
310 attrs.
append(
"operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
316 result.addAttribute(
"operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
322 if (!
result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
326 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 <<
"'" << result.name.getStringRef() <<
"' op ";
339 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
340 if (!outputs.empty())
341 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
355 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
363 opBuilder, region, inputTypes, outputTypes, attrs,
382 unsigned numRegionArgs,
399 result.addTypes(outputTensorsTypes);
401 std::unique_ptr<Region> region = std::make_unique<Region>();
403 outputTypes,
result.attributes.getAttrs(),
406 result.addRegion(std::move(region));
413 if (resultTypes.empty())
458class RegionBuilderHelper {
// Binds the helper to the caller's builder and block. The build* helpers of
// this class point `builder` at the end of `block` before creating ops.
460 RegionBuilderHelper(OpBuilder &builder,
Block &block)
461 : builder(builder), block(block) {}
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
466 if (!isFloatingPoint(arg)) {
468 emitError() <<
"unsupported non numeric type";
471 llvm_unreachable(
"unsupported non numeric type");
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
477 return math::ExpOp::create(builder, arg.
getLoc(), arg);
479 return math::LogOp::create(builder, arg.
getLoc(), arg);
481 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
483 return math::CeilOp::create(builder, arg.
getLoc(), arg);
485 return math::FloorOp::create(builder, arg.
getLoc(), arg);
487 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.
getType());
490 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
495 return math::RoundOp::create(builder, arg.
getLoc(), arg);
497 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
499 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
503 return math::TanhOp::create(builder, arg.
getLoc(), arg);
505 return math::ErfOp::create(builder, arg.
getLoc(), arg);
507 return math::SinOp::create(builder, arg.
getLoc(), arg);
509 return math::CosOp::create(builder, arg.
getLoc(), arg);
511 return math::TanOp::create(builder, arg.
getLoc(), arg);
513 return math::AcosOp::create(builder, arg.
getLoc(), arg);
515 return math::AcoshOp::create(builder, arg.
getLoc(), arg);
517 return math::AsinOp::create(builder, arg.
getLoc(), arg);
519 return math::AsinhOp::create(builder, arg.
getLoc(), arg);
521 return math::AtanOp::create(builder, arg.
getLoc(), arg);
523 return math::AtanhOp::create(builder, arg.
getLoc(), arg);
525 return math::Log10Op::create(builder, arg.
getLoc(), arg);
527 return math::Log1pOp::create(builder, arg.
getLoc(), arg);
529 return math::Log2Op::create(builder, arg.
getLoc(), arg);
532 emitError() <<
"unsupported unary function";
535 llvm_unreachable(
"unsupported unary function");
542 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
544 bool allComplex = isComplex(arg0) && isComplex(arg1);
545 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
546 bool allInteger = isInteger(arg0) && isInteger(arg1);
549 if (!allComplex && !allFloatingPoint && !allInteger) {
552 <<
"Cannot build binary Linalg operation: expects allComplex, "
553 "allFloatingPoint, or allInteger, got "
557 llvm_unreachable(
"unsupported non numeric type");
559 OpBuilder::InsertionGuard g(builder);
560 builder.setInsertionPointToEnd(&block);
564 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
565 if (allFloatingPoint)
566 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
568 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
569 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
572 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
573 if (allFloatingPoint)
574 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
577 emitError() <<
"unsupported operation: sub with bools";
580 llvm_unreachable(
"unsupported operation: sub with bools");
582 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
585 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
586 if (allFloatingPoint)
587 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
589 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
590 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
593 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
594 if (allFloatingPoint)
595 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
598 emitError() <<
"unsupported operation: div with bools";
601 llvm_unreachable(
"unsupported operation: div with bools");
603 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
604 case BinaryFn::div_unsigned:
605 if (!allInteger || allBool) {
607 emitError() <<
"unsupported operation: unsigned div not on uint";
610 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
612 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
613 case BinaryFn::max_signed:
615 if (allFloatingPoint)
616 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
617 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
618 case BinaryFn::min_signed:
620 if (allFloatingPoint)
621 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
622 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
623 case BinaryFn::max_unsigned:
625 if (!allInteger || allBool) {
627 emitError() <<
"unsupported operation: unsigned max not on uint";
630 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
632 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
633 case BinaryFn::min_unsigned:
635 if (!allInteger || allBool) {
637 emitError() <<
"unsupported operation: unsigned min not on uint";
640 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
642 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
644 assert(allFloatingPoint);
645 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
648 emitError() <<
"unsupported binary function";
651 llvm_unreachable(
"unsupported binary function");
655 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
657 OpBuilder::InsertionGuard g(builder);
658 builder.setInsertionPointToEnd(&block);
660 case TernaryFn::select:
661 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
664 emitError() <<
"unsupported ternary function";
667 llvm_unreachable(
"unsupported ternary function");
671 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
674 case TypeFn::cast_signed:
675 return cast(toType, operand,
false);
676 case TypeFn::cast_unsigned:
677 return cast(toType, operand,
true);
680 emitError() <<
"unsupported type conversion function";
683 llvm_unreachable(
"unsupported type conversion function");
687 OpBuilder::InsertionGuard g(builder);
688 builder.setInsertionPointToEnd(&block);
689 Location loc = builder.getUnknownLoc();
690 YieldOp::create(builder, loc, values);
// Parses `value` as an attribute in the builder's context and creates an
// arith::ConstantOp holding it at the end of the block. The op is given an
// unknown location.
693 Value constant(
const std::string &value) {
694 OpBuilder::InsertionGuard g(builder);
695 builder.setInsertionPointToEnd(&block);
696 Location loc = builder.getUnknownLoc();
697 Attribute valueAttr =
parseAttribute(value, builder.getContext());
// The parsed attribute is cast (not dyn_cast) to TypedAttr, so `value` must
// parse to a typed attribute.
698 return arith::ConstantOp::create(builder, loc,
699 ::cast<TypedAttr>(valueAttr));
// Creates an IndexOp for `dim` at the end of the block, using an unknown
// location, and returns the created op as a Value.
702 Value index(int64_t dim) {
703 OpBuilder::InsertionGuard g(builder);
704 builder.setInsertionPointToEnd(&block);
705 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
// Returns the IntegerType of the given bit `width` from the builder's context.
708 Type getIntegerType(
unsigned width) {
709 return IntegerType::get(builder.getContext(), width);
712 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
713 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
720 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
721 OpBuilder::InsertionGuard g(builder);
722 builder.setInsertionPointToEnd(&block);
723 auto loc = operand.
getLoc();
724 if (isa<UnknownLoc>(loc)) {
// True when the type of `value` itself is a ComplexType.
734 bool isComplex(Value value) {
735 return llvm::isa<ComplexType>(value.
getType());
// True when the type of `value` itself is a FloatType.
737 bool isFloatingPoint(Value value) {
738 return llvm::isa<FloatType>(value.
getType());
// True when the type of `value` itself is an IntegerType.
740 bool isInteger(Value value) {
741 return llvm::isa<IntegerType>(value.
getType());
757 using OpRewritePattern<CopyOp>::OpRewritePattern;
758 LogicalResult matchAndRewrite(CopyOp copyOp,
759 PatternRewriter &rewriter)
const override {
760 if (copyOp.getInputs() != copyOp.getOutputs())
762 if (copyOp.hasPureBufferSemantics())
765 rewriter.
replaceOp(copyOp, copyOp.getInputs());
775 results.
add<EraseSelfCopy>(context);
788template <
typename TensorReshapeOp>
789struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
790 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
791 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
792 PatternRewriter &rewriter)
const override {
793 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
797 Location loc = oldFill.getLoc();
798 TensorReshapeOp newInit;
799 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
801 newInit = TensorReshapeOp::create(
802 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
803 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
804 reshapeOp.getStaticOutputShape());
806 newInit = TensorReshapeOp::create(
807 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
808 reshapeOp.getReassociation());
818struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
821 LogicalResult matchAndRewrite(tensor::PadOp padOp,
822 PatternRewriter &rewriter)
const override {
823 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
829 Value padValue = padOp.getConstantPaddingValue();
830 if (!padValue || fillOp.value() != padValue)
836 padOp,
"failed to reify tensor.pad op result shape");
839 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
840 padOp.getResultType().getElementType());
842 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
845 if (
replacement.getType() != padOp.getResultType()) {
846 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
857struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
860 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
861 PatternRewriter &rewriter)
const override {
862 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
866 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
871 Value firstDest = insertOp.getDest();
872 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
873 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
878 bool disjoint =
false;
879 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
882 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
883 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
884 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
888 int64_t prevStart = prevOp.getStaticOffset(i);
889 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
890 prevOp.getStaticStride(i);
891 int64_t nextStart = insertOp.getStaticOffset(i);
892 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
893 insertOp.getStaticStride(i);
894 if (prevEnd < nextStart || nextEnd < prevStart) {
902 firstDest = prevOp.getDest();
913 Value padValue = srcPadOp.getConstantPaddingValue();
914 if (!padValue || dstFillOp.value() != padValue)
917 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
918 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
920 Location loc = insertOp.getLoc();
923 AffineExpr sym0, sym1;
929 SmallVector<OpFoldResult, 4> newOffsets;
930 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
932 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
935 RankedTensorType srcPadType = srcPadOp.getSourceType();
936 SmallVector<OpFoldResult, 4> newSizes;
937 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
938 if (srcPadType.isDynamicDim(i)) {
940 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
943 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
948 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
949 newSizes, insertOp.getMixedStrides());
955struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
957 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
959 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
960 PatternRewriter &rewriter)
const override {
963 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
968 Value extractedScalar = fillOp.getInputs()[0];
971 rewriter.
replaceOp(extractOp, extractedScalar);
979static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
980 linalg::PackOp packOp) {
981 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
985 if (
auto paddingValue = packOp.getPaddingValue())
989 Value packOpDest = packOp.getDest();
993 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
998struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
// Constructs the pattern by forwarding `context` to the
// OpRewritePattern<linalg::PackOp> base.
1000 FoldFillWithPack(MLIRContext *context)
1001 : OpRewritePattern<linalg::PackOp>(context) {}
1003 LogicalResult matchAndRewrite(linalg::PackOp packOp,
1004 PatternRewriter &rewriter)
const override {
1005 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
1008 rewriter.
replaceOp(packOp, fillOp.value().result());
1014struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
1015 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
1017 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
1018 PatternRewriter &rewriter)
const override {
1019 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
1022 copyOp.getOutputs());
1025 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1027 fillOp.getOutputs());
1035struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1036 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1038 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1039 PatternRewriter &rewriter)
const override {
1040 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1042 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1043 transposeOp.getDpsInitOperand(0)->get());
1052struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1055 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1056 PatternRewriter &rewriter)
const override {
1057 auto concatOperands = concatOp.getInputs();
1058 if (concatOperands.empty()) {
1062 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1067 OpFoldResult firstFillVal =
1070 SmallVector<Value> allOuts;
1071 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1073 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1074 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1079 OpFoldResult fillVal =
1081 if (fillVal != firstFillVal)
1084 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1087 if (!llvm::all_of(concatOperands.drop_front(),
1088 isDefinedByCompatibleFillOp)) {
1090 concatOp,
"not all operands are defined by a compatible fill op");
1093 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1094 concatOp.getDim(), allOuts);
1096 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
// Registers the fill-folding rewrite patterns for FillOp into `results`,
// each constructed with `context`. FoldFillWithTensorReshape is instantiated
// twice, once for tensor::CollapseShapeOp and once for tensor::ExpandShapeOp.
1103void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1104 MLIRContext *context) {
1105 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1106 FoldFillWithPack, FoldFillWithPad,
1107 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1108 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1109 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1122 for (
ValueRange container : {inputs, outputs}) {
1123 for (
Value v : container) {
1124 Type t = v.getType();
1125 blockArgTypes.push_back(
1127 blockArgLocs.push_back(v.getLoc());
1133 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1137void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1139 for (Value v : getRegionInputArgs())
1141 for (Value v : getRegionOutputArgs())
1142 setNameFn(v,
"out");
1145void GenericOp::build(
1146 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1148 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1150 ArrayRef<NamedAttribute> attributes) {
1151 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1152 iteratorTypes, doc, libraryCall);
1153 result.addAttributes(attributes);
1156 inputs, outputs, bodyBuild);
1159void GenericOp::build(
1160 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1162 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1163 StringRef libraryCall,
1165 ArrayRef<NamedAttribute> attributes) {
1166 build(builder,
result, resultTensorTypes, inputs, outputs,
1170 [&](utils::IteratorType iter) -> mlir::Attribute {
1171 return IteratorTypeAttr::get(builder.getContext(), iter);
1174 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1175 bodyBuild, attributes);
1178void GenericOp::build(
1180 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1181 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1182 StringRef libraryCall,
1184 ArrayRef<NamedAttribute> attributes) {
1186 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1189void GenericOp::build(
1191 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1192 ArrayRef<utils::IteratorType> iteratorTypes,
1194 ArrayRef<NamedAttribute> attributes) {
1195 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1197 "", bodyBuild, attributes);
1200void GenericOp::build(
1201 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1203 ArrayRef<utils::IteratorType> iteratorTypes,
1205 ArrayRef<NamedAttribute> attributes) {
1206 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1209 "", bodyBuild, attributes);
1212void GenericOp::print(OpAsmPrinter &p) {
1216 auto genericAttrNames = linalgTraitAttrNames();
1218 llvm::StringSet<> genericAttrNamesSet;
1219 genericAttrNamesSet.insert_range(genericAttrNames);
1220 SmallVector<NamedAttribute, 8> genericAttrs;
1221 for (
auto attr : (*this)->getAttrs()) {
1222 if (attr.getName() == getIteratorTypesAttrName()) {
1223 auto iteratorTypes =
1224 llvm::cast<ArrayAttr>(attr.getValue())
1225 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1230 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1231 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1232 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1235 genericAttrs.emplace_back(
1236 getIteratorTypesAttrName(),
1237 ArrayAttr::get(
getContext(), iteratorTypeNames));
1238 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1239 genericAttrs.push_back(attr);
1242 if (!genericAttrs.empty()) {
1243 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1244 p << genericDictAttr;
1250 genericAttrNames.push_back(
"operandSegmentSizes");
1251 genericAttrNamesSet.insert(genericAttrNames.back());
1253 bool hasExtraAttrs =
false;
1254 for (NamedAttribute n : (*this)->getAttrs()) {
1255 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1258 if (hasExtraAttrs) {
1265 if (!getRegion().empty()) {
1274ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1275 DictionaryAttr dictAttr;
1283 result.attributes.assign(dictAttr.getValue().begin(),
1284 dictAttr.getValue().end());
1290 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1291 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1292 if (!iteratorTypes) {
1293 return parser.
emitError(attributeLocation)
1294 <<
"expected " << getIteratorTypesAttrName(
result.name)
1295 <<
" array attribute";
1298 SmallVector<Attribute> iteratorTypeAttrs;
1300 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1301 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1302 if (!maybeIteratorType.has_value())
1304 <<
"unexpected iterator_type (" << s <<
")";
1306 iteratorTypeAttrs.push_back(
1307 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1309 result.attributes.set(getIteratorTypesAttrName(
result.name),
1313 SmallVector<Type, 1> inputTypes, outputTypes;
1323 std::unique_ptr<Region> region = std::make_unique<Region>();
1326 result.addRegion(std::move(region));
1332 SmallVector<Type, 1> outputTensorsTypes;
1335 result.addTypes(outputTensorsTypes);
1343 LinalgOp linalgOp) {
1344 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1345 if (!llvm::isa<MemRefType>(operand.
getType()))
1347 effects.emplace_back(
1352 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1353 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1355 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1366void GenericOp::getEffects(
1367 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1376 if (!linalgOp.hasPureTensorSemantics())
1394template <
typename OpTy>
1395struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1396 using OpRewritePattern<OpTy>::OpRewritePattern;
1398 LogicalResult matchAndRewrite(OpTy linalgOp,
1399 PatternRewriter &rewriter)
const override {
1401 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1406 Block &body = linalgOp->getRegion(0).front();
1407 if (!llvm::hasSingleElement(body))
1409 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1414 if (linalgOp.hasPureBufferSemantics()) {
1415 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1416 linalgOp.getDpsInputOperand(0)->get() !=
1417 linalgOp.getDpsInitOperand(0)->get()) {
1419 linalgOp,
"expected single input and output to be the same value");
1422 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1423 if (!yieldArg || yieldArg.getOwner() != &body) {
1425 "cannot fold fill-like op");
1432 if (!linalgOp.hasPureTensorSemantics()) {
1434 linalgOp,
"mixed semantics is not supported yet");
1439 SmallVector<Value> returnedArgs;
1440 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1441 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1442 if (!yieldArg || yieldArg.getOwner() != &body)
1444 unsigned argumentNumber = yieldArg.getArgNumber();
1445 Value returnedArg = linalgOp->getOperand(argumentNumber);
1446 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1449 Type returnType = returnedArg.
getType();
1450 if (returnType != resultType) {
1455 returnedArg = sparse_tensor::ConvertOp::create(
1456 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1458 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1461 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1462 resultType, returnedArg);
1465 returnedArgs.push_back(returnedArg);
1468 if (returnedArgs.size() != linalgOp->getNumResults())
1470 rewriter.
replaceOp(linalgOp, returnedArgs);
// Registers the GenericOp instantiation of EraseIdentityLinalgOp into
// `results`; no other pattern is added in the visible body.
1477void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1478 MLIRContext *context) {
1479 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1482LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1501 for (
Type outputType : outputTypes) {
1502 if (llvm::isa<RankedTensorType>(outputType))
1503 result.addTypes(outputType);
1507 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1516void MapOp::getAsmBlockArgumentNames(Region ®ion,
1518 for (Value v : getRegionInputArgs())
1520 for (Value v : getRegionOutputArgs())
1521 setNameFn(v,
"init");
// Passes the name "mapped" for the op's first result to `setNameFn`.
// Does nothing when the op has no results.
1524void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1525 if (!getResults().empty())
1526 setNameFn(getResults().front(),
"mapped");
1532 ArrayRef<NamedAttribute> attributes) {
1534 result.addAttributes(attributes);
1537 Type initType = init.
getType();
1538 if (llvm::isa<RankedTensorType>(initType))
1539 result.addTypes(initType);
1543 inputs, {init}, bodyBuild);
1550 bool initFirst =
false,
bool mapInit =
true) {
1554 b.setInsertionPointToStart(&block);
1555 for (
auto &operand : operands) {
1557 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1565 payloadOpOperands.push_back(block.
getArguments().back());
1566 for (
const auto &arg : block.
getArguments().drop_back())
1567 payloadOpOperands.push_back(arg);
1576 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1582ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1583 std::optional<OperationName> payloadOpName;
1584 NamedAttrList payloadOpAttrs;
1587 if (
failed(operationName))
1591 payloadOpName = operationName.value();
1599 if (payloadOpName.has_value()) {
1600 if (!
result.operands.empty())
1602 payloadOpAttrs, ArrayRef(
result.operands),
false,
1607 SmallVector<OpAsmParser::Argument> regionArgs;
1612 Region *body =
result.addRegion();
1620 bool mapInit =
true) {
1622 if (initFirst && !mapInit)
1646 for (
const auto &[operand, bbArg] :
1648 if (bbArg != operand)
1652 for (
const auto &[operand, bbArg] :
1655 if (bbArg != operand)
1662 return yieldOp.getNumOperands() == 1 &&
1663 yieldOp.getOperand(0).getDefiningOp() &&
1664 yieldOp.getOperand(0).getDefiningOp() == &payload;
1669 std::string attrToElide;
1671 for (
const auto &attr : payloadOp->
getAttrs()) {
1673 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1674 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1675 attrToElide = attr.getName().str();
1676 elidedAttrs.push_back(attrToElide);
1684void MapOp::print(OpAsmPrinter &p) {
1685 Block *mapper = getBody();
1695 if (!useShortForm) {
1701 [&](
auto arg) { p.printRegionArgument(arg); });
1709LogicalResult MapOp::verify() {
1710 auto *bodyBlock = getBody();
1711 auto blockArgs = bodyBlock->getArguments();
1715 if (getInputs().size() + 1 != blockArgs.size())
1716 return emitOpError() <<
"expects number of operands to match the arity of "
1718 << getInputs().size() + 1 <<
" and "
1719 << blockArgs.size();
1722 for (
const auto &[bbArgType, inputArg] :
1723 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1724 auto inputElemType =
1725 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1726 if (bbArgType != inputElemType) {
1727 return emitOpError() <<
"expected element type of input " << inputElemType
1728 <<
" to match bbArg type " << bbArgType;
1733 auto outputShape = getInit().getType().getShape();
1734 for (Type inputArgType :
TypeRange{getInputs()}) {
1735 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1736 if (inputElemShape != outputShape) {
1737 return emitOpError() <<
"expected shape of input (" << inputElemShape
1738 <<
") to match shape of output (" << outputShape
// Returns one `parallel` iterator type per dimension of the init operand:
// the vector length equals the rank of getInit()'s type.
1746SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1747 int64_t rank = getInit().getType().getRank();
1748 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1753 int64_t rank = getInit().getType().getRank();
1754 int64_t numIndexingMaps = getOperands().size();
1759void MapOp::getEffects(
1760 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1773void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1775 for (Value v : getRegionInputArgs())
1777 for (Value v : getRegionOutputArgs())
1778 setNameFn(v,
"init");
1781void ReduceOp::getAsmResultNames(
1783 if (!getResults().empty())
1784 setNameFn(getResults().front(),
"reduced");
1787void ReduceOp::build(
1789 ValueRange inits, ArrayRef<int64_t> dimensions,
1791 ArrayRef<NamedAttribute> attributes) {
1793 result.addAttributes(attributes);
1796 for (Value init : inits) {
1797 Type initType = init.
getType();
1798 if (llvm::isa<RankedTensorType>(initType))
1799 result.addTypes(initType);
1804 inputs, inits, bodyBuild);
1807SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1809 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1810 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1811 utils::IteratorType::parallel);
1812 for (int64_t reductionDim : getDimensions())
1813 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1814 return iteratorTypes;
1819 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1820 SmallVector<AffineMap> affineMaps(
1823 AffineMap resultMap =
1826 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1827 affineMaps.push_back(resultMap);
1828 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1831void ReduceOp::getEffects(
1832 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1843 StringRef attributeName) {
1851ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1852 std::optional<OperationName> payloadOpName;
1853 NamedAttrList payloadOpAttrs;
1856 if (
failed(operationName))
1860 payloadOpName = operationName.value();
1866 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1871 if (payloadOpName.has_value()) {
1873 ArrayRef(
result.operands),
true);
1875 SmallVector<OpAsmParser::Argument> regionArgs;
1881 Region *body =
result.addRegion();
1891 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1894void ReduceOp::print(OpAsmPrinter &p) {
1895 Block *mapper = getBody();
1904 if (!useShortForm) {
1910 [&](
auto arg) { p.printRegionArgument(arg); });
1918LogicalResult ReduceOp::verify() {
1919 ArrayRef<int64_t> dimensionsRef = getDimensions();
1926 if (getInputs().size() !=
static_cast<size_t>(getNumDpsInputs()))
1928 <<
"expected equal number of inputs and outputs (required by "
1929 "SameVariadicOperandSize), got "
1930 << getNumDpsInputs() <<
" input(s) and " << getNumDpsInits()
1933 if (getInputs().empty())
1934 return emitOpError() <<
"expected at least one input";
1936 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1939 return emitOpError() <<
"expects all inputs to have the same shapes. "
1940 "Shape at input-index "
1942 <<
" is not equal to the shape at input-index 0.";
1945 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1948 return emitOpError() <<
"expects all outputs to have the same shapes. "
1949 "Shape at output-index "
1951 <<
" is not equal to the shape at output-index 0.";
1954 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1955 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1958 for (int64_t dimension : dimensionsRef) {
1959 if (dimension < 0 || dimension >= inputType.getRank()) {
1961 <<
"dimensions for reduction should be in the range [0, "
1962 << inputType.getRank() - 1 <<
"].";
1964 dimensionsToReduce.insert(dimension);
1967 auto inputDims = inputType.getShape();
1968 auto initDims = initType.getShape();
1971 SmallVector<int64_t> reducedInputDims;
1972 for (
const auto &en : llvm::enumerate(inputDims)) {
1973 if (!dimensionsToReduce.count(en.index()))
1974 reducedInputDims.push_back(en.value());
1977 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1978 return emitOpError() <<
"number of dimensions after reduction "
1979 << reducedInputDims.size()
1980 <<
" doesn't match the init rank "
1981 << initType.getRank();
1984 if (reducedInputDims != initDims)
1985 return emitOpError() <<
"init dimensions [" << initDims
1986 <<
"] doesn't match input dimensions after reduction ["
1987 << reducedInputDims <<
"]";
1989 Block *block = getBody();
1992 <<
"mismatching number of operands and block arguments";
1995 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1996 Type inputElementType =
1997 llvm::cast<ShapedType>(input.getType()).getElementType();
1998 if (inputElementType != bbArg.getType())
2000 <<
"input element type " << inputElementType
2001 <<
" does not match corresponding block argument type "
2006 for (
auto [output, bbArg] : llvm::zip(
2007 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
2008 auto outputElementType =
2009 llvm::cast<ShapedType>(output.getType()).getElementType();
2010 if (outputElementType != bbArg.getType())
2012 <<
"output element type " << outputElementType
2013 <<
" does not match corresponding block argument type "
2029 linalg::YieldOp::create(
b, loc, args[0]);
2033void TransposeOp::build(::mlir::OpBuilder &builder,
2034 ::mlir::OperationState &
result, Value input, Value init,
2036 ArrayRef<NamedAttribute> attributes) {
2037 result.addOperands(input);
2038 result.addOperands(init);
2039 result.addAttribute(getPermutationAttrName(
result.name), permutation);
2040 result.addAttributes(attributes);
2043 Type initType = init.
getType();
2044 if (llvm::isa<RankedTensorType>(initType))
2045 result.addTypes(initType);
2051void TransposeOp::build(::mlir::OpBuilder &builder,
2052 ::mlir::OperationState &
result, Value input, Value init,
2053 ArrayRef<int64_t> permutation,
2054 ArrayRef<NamedAttribute> attributes) {
2059ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2061 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2073void TransposeOp::getAsmResultNames(
2075 if (!getResults().empty())
2076 setNameFn(getResults().front(),
"transposed");
2079void TransposeOp::print(OpAsmPrinter &p) {
2085LogicalResult TransposeOp::verify() {
2086 ArrayRef<int64_t> permutationRef = getPermutation();
2091 auto inputType = getInput().getType();
2092 auto initType = getInit().getType();
2094 int64_t rank = inputType.getRank();
2100 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2101 return emitOpError() <<
"size of permutation " << permutationRef.size()
2102 <<
" does not match the argument rank " << rank;
2104 auto inputDims = inputType.getShape();
2105 auto initDims = initType.getShape();
2107 for (int64_t i = 0; i < rank; ++i) {
2108 int64_t inputDim = inputDims[permutationRef[i]];
2109 int64_t initDim = initDims[i];
2111 if (inputDim != initDim) {
2112 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2113 <<
" doesn't match dim(input, permutation[" << i
2114 <<
"]) = " << inputDim;
// Returns the iterator types of the transpose: one `parallel` entry per
// dimension, where the dimension count is the rank of the init operand's type.
2121SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2122 int64_t rank = getInit().getType().getRank();
2123 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2126ArrayAttr TransposeOp::getIndexingMaps() {
2128 int64_t rank = getInit().getType().getRank();
2131 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2135void TransposeOp::getEffects(
2136 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2145LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2146 SmallVectorImpl<OpFoldResult> &
result) {
2148 if (!isa<TensorType>(getInput().
getType()))
2152 if (getPermutation().empty()) {
2153 result.push_back(getInput());
2158 result.push_back(getInput());
2171 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2172 if (!defTransposeOp)
2177 foldedPerms.reserve(perms.size());
2179 foldedPerms.push_back(defPerms[perm]);
2182 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2195 if (!transposeOp.hasPureTensorSemantics())
2200 if (!splatValue.has_value())
2204 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2221 Value input = transposeOp.getInput();
2222 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2233 unsigned dimensionSize = dimensions.size();
2234 for (
unsigned i = 0; i < dimensionSize; ++i)
2235 resultDimensions.push_back(invertPerm[dimensions[i]]);
2238 Value broadcastInput = broadcastOp.getInput();
2239 Location loc = transposeOp.getLoc();
2242 auto broadcastInputTy =
2243 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2244 unsigned inputRank = broadcastInputTy.getRank();
2245 for (
unsigned i = 0; i < inputRank; ++i) {
2246 if (broadcastInputTy.isDynamicDim(i)) {
2247 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2250 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2251 broadcastInputTy.getDimSize(i)));
2256 Value transposeInit = tensor::EmptyOp::create(
2257 rewriter, transposeOp.getLoc(), transposeResultShapes,
2258 broadcastInputTy.getElementType());
2261 Value transposeResult =
2262 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2263 transposeInit, resultPerms)
2266 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
// Adds the canonicalization patterns of this op to `results`:
// FoldTransposeWithTranspose, FoldTransposeSplatConstant and
// SwapTransposeWithBroadcast, each constructed with `context`.
2271void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2272 MLIRContext *context) {
2273 results.
add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2274 SwapTransposeWithBroadcast>(context);
2281void BroadcastOp::build(::mlir::OpBuilder &builder,
2282 ::mlir::OperationState &
result, Value input, Value init,
2284 ArrayRef<NamedAttribute> attributes) {
2285 result.addOperands(input);
2286 result.addOperands(init);
2287 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2288 result.addAttributes(attributes);
2291 Type initType = init.
getType();
2292 if (llvm::isa<RankedTensorType>(initType))
2293 result.addTypes(initType);
2299void BroadcastOp::build(::mlir::OpBuilder &builder,
2300 ::mlir::OperationState &
result, Value input, Value init,
2301 ArrayRef<int64_t> dimensions,
2302 ArrayRef<NamedAttribute> attributes) {
2307ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2309 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2321void BroadcastOp::getAsmResultNames(
2323 if (!getResults().empty())
2324 setNameFn(getResults().front(),
"broadcasted");
2327void BroadcastOp::print(OpAsmPrinter &p) {
2333LogicalResult BroadcastOp::verify() {
2334 ArrayRef<int64_t> dimensionsRef = getDimensions();
2336 auto inputType = getInput().getType();
2337 auto initType = getInit().getType();
2339 int64_t inputRank = inputType.getRank();
2340 int64_t initRank = initType.getRank();
2342 auto inputShape = inputType.getShape();
2343 auto initShape = initType.getShape();
2345 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2346 return emitOpError() <<
"input rank plus added dimensions does not "
2347 "match init rank. input rank: "
2349 <<
", dimensions size: " << dimensionsRef.size()
2350 <<
", init rank: " << initRank;
2352 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2353 if (dim < 0 || dim >= initRank)
2355 <<
" is out of range. expected range: [0, "
2356 << initRank - 1 <<
"], got: " << dim;
2360 SmallVector<int64_t> dimMap;
2361 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2362 if (!llvm::is_contained(dimensionsRef, dim))
2363 dimMap.push_back(dim);
2366 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2369 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2370 return emitOpError() <<
"input dim " << inputDimIdx
2371 <<
" should match init dim " << initDimIdx
2372 <<
". input: " << inputShape[inputDimIdx]
2373 <<
", init: " << initShape[initDimIdx];
// Returns the iterator types of the broadcast: one `parallel` entry per
// dimension, where the dimension count is the rank of the init operand's type.
2379SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2380 int64_t rank = getInit().getType().getRank();
2381 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2384ArrayAttr BroadcastOp::getIndexingMaps() {
2386 int64_t rank = getInit().getType().getRank();
2392void BroadcastOp::getEffects(
2393 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2408 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2409 if (!defBroadcastOp)
2414 Value init = broadcastOp.getInit();
2418 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2419 if (!llvm::is_contained(dimensions, dim))
2420 dimMap.push_back(dim);
2422 for (
auto dim : defDimensions)
2423 foldedDims.push_back(dimMap[dim]);
2425 llvm::sort(foldedDims);
2427 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2439 if (!broadcastOp.hasPureTensorSemantics())
2445 if (!splatValue.has_value())
2449 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2450 if (!resultType.hasStaticShape())
2452 "result type has dynamic shape");
// Adds the canonicalization patterns of this op to `results`:
// EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts and
// FoldBroadcastSplatConstant, each constructed with `context`.
2461void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2462 MLIRContext *context) {
2463 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2464 FoldBroadcastSplatConstant>(context);
2471void linalg::YieldOp::print(OpAsmPrinter &p) {
2472 if (getNumOperands() > 0)
2473 p <<
' ' << getOperands();
2475 if (getNumOperands() > 0)
2476 p <<
" : " << getOperandTypes();
2479ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2480 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2481 SmallVector<Type, 2> types;
2491static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2492 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2493 return op.emitOpError(
"expected number of yield values (")
2494 << op.getNumOperands()
2495 <<
") to match the number of inits / outs operands of the enclosing "
2496 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2498 for (
OpOperand &opOperand : op->getOpOperands()) {
2500 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2502 if (isa<MemRefType, RankedTensorType>(elementType))
2504 if (opOperand.get().getType() != elementType)
2505 return op.emitOpError(
"type of yield operand ")
2506 << (opOperand.getOperandNumber() + 1) <<
" ("
2507 << opOperand.get().getType() <<
") doesn't match "
2508 <<
"the element type of the enclosing linalg.generic op ("
2509 << elementType <<
")";
2514LogicalResult linalg::YieldOp::verify() {
2515 auto *parentOp = (*this)->getParentOp();
2516 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2517 return emitOpError(
"expected single non-empty parent region");
2519 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2522 return emitOpError(
"expected parent op with LinalgOp interface");
2529LogicalResult IndexOp::verify() {
2530 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2532 return emitOpError(
"expected parent op with LinalgOp interface");
2533 if (linalgOp.getNumLoops() <= getDim())
2535 << getDim() <<
") to be lower than the number of loops ("
2536 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2540OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2541 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2546 return OpFoldResult{};
2549 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2550 uint64_t dim = getDim();
2551 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2552 if (loopBounds[dim] == 1)
2553 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2555 return OpFoldResult{};
2560#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2562#define GET_OP_CLASSES
2563#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2565#define GET_OP_CLASSES
2566#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2567#define GET_OP_CLASSES
2568#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2585 for (
unsigned i = 0; i < num; ++i)
2592 auto rangeA = llvm::make_range(a.begin(), a.end());
2593 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2594 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2595 return llvm::to_vector<4>(concatRanges);
2599 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2601 for (
auto size :
memref.getShape())
2608 if (
auto as =
memref.getMemorySpace()) {
2609 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2610 ss <<
"as" << attr.getInt();
2616 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2619 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2632 assert(isa<LinalgOp>(op));
2634 std::string fun =
"";
2636 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2637 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2638 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2639 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2643 llvm::replace(name,
'.',
'_');
2644 llvm::raw_string_ostream ss(name);
2648 return std::string();
2663 LogicalResult matchAndRewrite(LinalgOp op,
2665 for (
OpOperand &opOperand : op->getOpOperands()) {
2669 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2672 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2683struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2684 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2686 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2687 PatternRewriter &rewriter)
const override {
2691 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2698 if (castOp->getBlock() != linalgOp->getBlock())
2701 OpBuilder::InsertionGuard guard(rewriter);
2704 Location loc = linalgOp.getLoc();
2705 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2708 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2714 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2716 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2717 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2718 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2719 linalgOp.getDpsInits().end());
2720 outputOperands[resultNumber] = newOperand;
2721 newOperands.append(outputOperands.begin(), outputOperands.end());
2723 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2724 linalgOp->result_type_end());
2725 resultTypes[resultNumber] = resultType;
2726 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2729 Value castBack = tensor::CastOp::create(
2733 results[resultNumber] = castBack;
2742static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2743 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2744 for (OpOperand &opOperand : operands) {
2745 if (linalgOp.isScalar(&opOperand))
2747 Value src = opOperand.get();
2748 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2749 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2755 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2757 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2758 Value castSource = castOp.getSource();
2759 auto castSourceType =
2760 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2761 if (castSourceType && castSourceType.hasStaticShape())
2762 sourceShape = castSourceType.getShape();
2768 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2769 if (sourceType.isDynamicDim(i))
2771 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2772 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2782static void createNewOperandWithStaticSizes(
2783 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2784 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2785 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2786 bool &changeNeeded) {
2787 Value src = opOperand->
get();
2788 newOperands.push_back(src);
2789 if (linalgOp.isScalar(opOperand))
2791 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2792 Type resultType = sourceType;
2793 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2794 resultTypes.push_back(resultType);
2797 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2798 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2799 SmallVector<int64_t> newShape;
2802 bool newOperandNeeded =
false;
2803 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2804 int64_t dimShape = sourceShape[i];
2805 AffineExpr dimExpr = sourceMap.
getResult(i);
2806 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2807 newShape.push_back(dimShape);
2813 newShape.push_back(affineExprToSize[dimExpr]);
2814 newOperandNeeded =
true;
2816 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2817 sourceType.getEncoding());
2818 if (newOperandNeeded) {
2819 changeNeeded =
true;
2822 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2824 newOperands[index] = newOperand;
2826 if (linalgOp.isDpsInit(opOperand))
2827 resultTypes.push_back(resultType);
2833struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2834 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2836 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2837 PatternRewriter &rewriter)
const override {
2838 if (!linalgOp.hasPureTensorSemantics())
2842 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2843 return !map.isProjectedPermutation();
2848 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2849 Location loc = linalgOp.getLoc();
2853 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2855 SmallVector<Value> newOperands;
2856 SmallVector<Type> resultTypes;
2860 bool changeNeeded =
false;
2861 newOperands.reserve(linalgOp->getNumOperands());
2862 resultTypes.reserve(linalgOp.getNumDpsInits());
2865 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2866 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2867 affineExprToSize, linalgOp, newOperands,
2868 resultTypes, changeNeeded);
2877 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2878 SmallVector<Value> replacements;
2880 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2881 Value newResult = std::get<1>(it);
2882 Value oldResult = std::get<0>(it);
2883 Type newType = newResult.
getType();
2884 Type oldType = oldResult.
getType();
2885 replacements.push_back(
2886 (newType != oldType)
2887 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2890 rewriter.
replaceOp(linalgOp, replacements);
2904LogicalResult SoftmaxOp::verify() {
2905 ShapedType inputType = getInputOperandType();
2906 ShapedType outputType = getOutputOperandType();
2908 ArrayRef<int64_t> inputShape = inputType.getShape();
2909 ArrayRef<int64_t> outputShape = outputType.getShape();
2913 int64_t inputRank = getInputOperandRank();
2914 int64_t dimension = getDimension();
2915 if ((dimension < 0) || (dimension >= inputRank))
2916 return emitOpError(
"incorrect dimension specified");
2921SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2922 int64_t operandRank = getInputOperandRank();
2923 SmallVector<Range> loopBounds(operandRank);
2924 Location loc = getLoc();
2927 Value source = getInput();
2928 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2929 loopBounds[dim].offset = zero;
2930 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2931 loopBounds[dim].stride = one;
// Returns one iterator type per input operand dimension
// (getInputOperandRank() entries): all `parallel`, except the entry at
// getDimension(), which is `reduction`.
2936SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2937 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2938 utils::IteratorType::parallel);
// Overwrite the entry for the op's `dimension` with a reduction iterator.
2939 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2940 return iteratorTypes;
2943FailureOr<TilingResult>
2944SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2945 ArrayRef<OpFoldResult> offsets,
2946 ArrayRef<OpFoldResult> sizes) {
2947 int64_t rank = getInputOperandRank();
2949 SmallVector<OpFoldResult> strides(rank, oneAttr);
2950 SmallVector<Value> tiledOperands;
2951 Operation *inputSlice =
2952 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2954 return emitOpError(
"failed to compute input slice");
2956 tiledOperands.emplace_back(inputSlice->
getResult(0));
2957 Operation *outputSlice =
2958 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2960 return emitOpError(
"failed to compute output slice");
2962 tiledOperands.emplace_back(outputSlice->
getResult(0));
2964 SmallVector<Type, 4> resultTypes;
2965 if (hasPureTensorSemantics())
2966 resultTypes.push_back(tiledOperands[1].
getType());
2967 Operation *tiledOp =
2968 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2970 return TilingResult{
2973 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2976LogicalResult SoftmaxOp::getResultTilePosition(
2977 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2978 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2979 SmallVector<OpFoldResult> &resultSizes) {
2980 if (resultNumber == 0) {
2981 resultOffsets.assign(offsets.begin(), offsets.end());
2982 resultSizes.assign(sizes.begin(), sizes.end());
2989LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2994SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2996 SmallVector<OpFoldResult> shapes;
2997 Location loc = getOperation()->getLoc();
2998 IRRewriter rewriter(
b);
2999 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
3000 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
3001 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
3002 if (!outputShapedType.isDynamicDim(dim)) {
3004 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
3011 reifiedReturnShapes.emplace_back(std::move(shapes));
3015void SoftmaxOp::getEffects(
3016 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3018 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
3019 if (!llvm::isa<MemRefType>(operand.
getType()))
3022 &getOperation()->getOpOperand(index), 0,
3027 for (OpOperand &operand : getDpsInitsMutable()) {
3028 if (!llvm::isa<MemRefType>(operand.get().
getType()))
3059static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3061 int64_t dim,
bool allParallel =
false) {
3063 utils::IteratorType::parallel);
3065 iteratorTypes[dim] = utils::IteratorType::reduction;
3069 for (
int i = 0; i < inputRank; i++) {
3076 return std::make_tuple(iteratorTypes, indexingMaps);
3081template <
typename T>
3084 auto inputType = cast<ShapedType>(input.
getType());
3086 int64_t inputRank = inputShape.size();
3087 auto [iteratorTypes, indexingMaps] =
3089 assert(indexingMaps.size() == 2 &&
3090 "We should have two maps: 1 for the input, 1 for the output");
3091 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3093 auto genericOp = linalg::GenericOp::create(
3094 builder, loc, output.
getType(), input, output, indexingMaps,
3096 Value result = T::create(b, loc, args[0], args[1]);
3097 linalg::YieldOp::create(b, loc, result);
3099 return genericOp.getResult(0);
3107 auto inputType = cast<ShapedType>(input.
getType());
3109 int64_t inputRank = inputShape.size();
3111 builder, inputRank, dim,
true);
3112 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3113 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3115 indexingMaps.push_back(indexingMaps[0]);
3116 auto genericOp = linalg::GenericOp::create(
3118 indexingMaps, iteratorTypes,
3120 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3121 Value result = math::ExpOp::create(b, loc, diff);
3122 linalg::YieldOp::create(b, loc, result);
3124 return genericOp.getResult(0);
3134 auto inputType = cast<ShapedType>(numerator.
getType());
3136 int64_t inputRank = inputShape.size();
3138 builder, inputRank, dim,
true);
3139 assert(indexingMaps.size() == 2 &&
3140 "We should have one map for each input (2)");
3141 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3143 indexingMaps.push_back(indexingMaps[0]);
3144 auto genericOp = linalg::GenericOp::create(
3146 output, indexingMaps, iteratorTypes,
3148 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3149 linalg::YieldOp::create(b, loc, result);
3151 return genericOp.getResult(0);
3173FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3174 OpBuilder::InsertionGuard guard(
b);
3175 b.setInsertionPoint(*
this);
3176 Location loc = getLoc();
3177 Value input = getInput();
3178 ShapedType inputType = getInputOperandType();
3179 Type elementType = inputType.getElementType();
3180 int64_t reductionDim = getDimension();
3182 Value output = getOutput();
3183 dims.erase(dims.begin() + reductionDim);
3185 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3187 elementType,
b, loc,
3189 Value neutralForMaxFInit =
3190 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3202 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3208 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3209 return SmallVector<Value>{
result};
3216LogicalResult WinogradFilterTransformOp::verify() {
3217 auto filterType = cast<ShapedType>(getFilter().
getType());
3218 ArrayRef<int64_t> filterShape = filterType.getShape();
3219 int64_t filterH = filterShape[getFilterHDim()];
3220 int64_t filterW = filterShape[getFilterWDim()];
3221 WinogradConv2DFmr fmr = getFmr();
3225 if (filterH != r && filterH != 1)
3226 return emitOpError(
"expect filter height either equals to r or 1");
3227 if (filterW != r && filterW != 1)
3228 return emitOpError(
"expect filter width either equals to r or 1");
3229 if (filterH == 1 && filterW == 1)
3230 return emitOpError(
"expect either filter height or width equals to r");
3232 SmallVector<int64_t> expectedOutputShape;
3233 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3234 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3235 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3236 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3238 auto outputType = cast<ShapedType>(getOutput().
getType());
3239 ArrayRef<int64_t> outputShape = outputType.getShape();
3241 return emitOpError(
"the output shape is not expected");
3247WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3248 Location loc = getLoc();
3251 Value filter = getFilter();
3252 int64_t filterRank = getFilterOperandRank();
3253 SmallVector<Range> loopBounds(filterRank);
3254 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3255 loopBounds[dim].offset = zeroAttr;
3256 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3257 loopBounds[dim].stride = oneAttr;
// Returns the loop iterator types: one `parallel` entry per dimension of the
// filter operand (getFilterOperandRank() entries).
3262SmallVector<utils::IteratorType>
3263WinogradFilterTransformOp::getLoopIteratorTypes() {
3264 int64_t filterRank = getFilterOperandRank();
3265 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3266 utils::IteratorType::parallel);
3267 return iteratorTypes;
3270LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3271 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3272 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3273 SmallVector<OpFoldResult> &resultSizes) {
3275 ShapedType filterType = getFilterOperandType();
3276 ArrayRef<int64_t> filterShape = filterType.getShape();
3277 int64_t filterH = filterShape[getFilterHDim()];
3278 int64_t filterW = filterShape[getFilterWDim()];
3279 WinogradConv2DFmr fmr = getFmr();
3282 int64_t alpha = m + r - 1;
3283 int64_t alphaH = filterH != 1 ? alpha : 1;
3284 int64_t alphaW = filterW != 1 ? alpha : 1;
3288 resultOffsets.append(
3289 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3291 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3302FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3303 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3304 ArrayRef<OpFoldResult> sizes) {
3307 ShapedType filterType = getFilterOperandType();
3308 ArrayRef<int64_t> filterShape = filterType.getShape();
3309 int64_t filterH = filterShape[getFilterHDim()];
3310 int64_t filterW = filterShape[getFilterWDim()];
3313 SmallVector<Value> tiledOperands;
3314 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3316 sliceOffsets.append(
3317 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3318 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3319 sizes[getFilterCDim()]});
3320 int64_t filterRank = getFilterOperandRank();
3321 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3322 Location loc = getLoc();
3323 auto filterSlice = tensor::ExtractSliceOp::create(
3324 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3325 tiledOperands.emplace_back(filterSlice);
3327 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3332 int64_t outputRank = getOutputOperandRank();
3333 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3334 auto outputSlice = tensor::ExtractSliceOp::create(
3335 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3336 tiledOperands.emplace_back(outputSlice);
3338 SmallVector<Type> resultTypes;
3339 resultTypes.push_back(tiledOperands[1].
getType());
3340 Operation *tiledOp =
3341 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3343 return TilingResult{
3346 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3353LogicalResult WinogradInputTransformOp::verify() {
3354 auto inputType = cast<ShapedType>(getInput().
getType());
3355 ArrayRef<int64_t> inputShape = inputType.getShape();
3356 int64_t inputH = inputShape[getInputHDim()];
3357 int64_t inputW = inputShape[getInputWDim()];
3358 WinogradConv2DFmr fmr = getFmr();
3361 int64_t tileSize = m + r - 1;
3363 auto outputType = cast<ShapedType>(getOutput().
getType());
3364 ArrayRef<int64_t> outputShape = outputType.getShape();
3365 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3366 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3368 SmallVector<int64_t> expectedOutputShape(6, inputH);
3369 if (ShapedType::isDynamic(inputH)) {
3370 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3371 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3373 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3374 expectedOutputShape[getOutputTileHDim()] =
3375 leftTransform ? (inputH - (r - 1)) / m : inputH;
3377 if (ShapedType::isDynamic(inputW)) {
3378 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3379 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3381 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3382 expectedOutputShape[getOutputTileWDim()] =
3383 rightTransform ? (inputW - (r - 1)) / m : inputW;
3385 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3386 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3389 return emitOpError(
"the output shape is not expected");
// Builds one loop range per output dimension: offset `zeroAttr`, size taken
// from the matching dimension of the output operand, stride `oneAttr`.
3395WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3396 Location loc = getLoc();
3399 Value output = getOutput();
3400 int64_t outputRank = getOutputOperandRank();
3401 SmallVector<Range> loopBounds(outputRank);
3402 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3403 loopBounds[dim].offset = zeroAttr;
// The size comes from getDimValue on the output operand for this dimension.
3405 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3406 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per output dimension; all of them are parallel.
3411SmallVector<utils::IteratorType>
3412WinogradInputTransformOp::getLoopIteratorTypes() {
3413 int64_t outputRank = getOutputOperandRank();
3414 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3415 utils::IteratorType::parallel);
3416 return iteratorTypes;
// Fills `resultOffsets` / `resultSizes` for the result tile that corresponds to
// the given iteration-space `offsets` / `sizes`. The first two appended entries
// use `zeroAttr` offsets with `alphaHAttr` / `alphaWAttr` sizes; the remaining
// four forward the tile-H, tile-W, N and C entries of the incoming tile.
3419LogicalResult WinogradInputTransformOp::getResultTilePosition(
3420 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3421 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3422 SmallVector<OpFoldResult> &resultSizes) {
3424 ShapedType outputType = getOutputOperandType();
3425 ArrayRef<int64_t> outputShape = outputType.getShape();
3426 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3427 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3429 WinogradConv2DFmr fmr = getFmr();
// alpha = m + r - 1 is used along an axis unless the output alpha dimension
// on that axis is 1, in which case 1 is used.
3432 int64_t alpha = m + r - 1;
3433 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3434 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3439 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3440 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3441 offsets[getOutputCDim()]});
3442 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3443 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3444 sizes[getOutputCDim()]});
// Produces the tiled form of this op: extracts a slice of the input and a
// slice of the output, clones the op onto those slices, and returns the
// result together with the two slice ops.
3455FailureOr<TilingResult>
3456WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3457 ArrayRef<OpFoldResult> offsets,
3458 ArrayRef<OpFoldResult> sizes) {
3460 WinogradConv2DFmr fmr = getFmr();
3464 ShapedType outputType = getOutputOperandType();
3465 ArrayRef<int64_t> outputShape = outputType.getShape();
3466 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3467 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3469 Location loc = getLoc();
3471 auto identityAffineMap =
3473 auto offsetAffineMap =
3476 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3477 offsets[getOutputTileHDim()]);
3479 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3480 offsets[getOutputTileWDim()]);
3484 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3486 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3488 SmallVector<Value> tiledOperands;
3489 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
// Input slice entries are ordered (N, H, W, C); the H/W offsets are the mapped
// tile offsets.
3491 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3492 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3493 sliceOffsets.append(
3494 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
// When alpha is 1 along an axis, the slice size along that axis is `oneAttr`
// rather than the mapped size.
3495 OpFoldResult sizeH =
3496 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3497 OpFoldResult sizeW =
3498 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3500 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3501 int64_t inputRank = getInputOperandRank();
3502 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3503 auto inputSlice = tensor::ExtractSliceOp::create(
3504 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3505 tiledOperands.emplace_back(inputSlice);
// Output slice: uses resultOffsets / resultSizes with strides of `oneAttr`.
3507 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3512 int64_t outputRank = getOutputOperandRank();
3513 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3514 auto outputSlice = tensor::ExtractSliceOp::create(
3515 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3516 tiledOperands.emplace_back(outputSlice);
// The tiled op's single result type is the type of the output slice.
3518 SmallVector<Type> resultTypes;
3519 resultTypes.push_back(tiledOperands[1].
getType());
3520 Operation *tiledOp =
3521 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3523 return TilingResult{
3526 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
// Verifier for the Winograd output transform. Checks the alpha dimensions of
// the value operand against the tile size, builds the expected output shape,
// and reports "the output shape is not expected" on failure.
3533LogicalResult WinogradOutputTransformOp::verify() {
3534 auto valueType = cast<ShapedType>(getValue().
getType());
3535 ArrayRef<int64_t> valueShape = valueType.getShape();
3536 int64_t valueH = valueShape[getValueAlphaHDim()];
3537 int64_t valueW = valueShape[getValueAlphaWDim()];
3538 int64_t valueTileH = valueShape[getValueTileHDim()];
3539 int64_t valueTileW = valueShape[getValueTileWDim()];
3540 WinogradConv2DFmr fmr = getFmr();
// A transform along an axis is considered present when the value's alpha
// dimension on that axis differs from 1.
3543 bool leftTransform = valueH != 1;
3544 bool rightTransform = valueW != 1;
3546 int64_t outputRank = getOutputOperandRank();
3547 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
// If either the alpha or the tile extent along H is dynamic, the expected
// output height is dynamic.
3548 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3549 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
// Static case: alpha-H must be m + r - 1 when transforming (else 1), and the
// output height is (m or 1) times the tile count.
3551 if (valueH != (leftTransform ? m + r - 1 : 1))
3552 return emitOpError(
"expect input height equals to input tile size");
3553 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
// Same rule applied to the width axis.
3555 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3556 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3558 if (valueW != (rightTransform ? m + r - 1 : 1))
3559 return emitOpError(
"expect input width equals to input tile size");
3560 expectedOutputShape[getOutputWDim()] =
3561 (rightTransform ? m : 1) * valueTileW;
// N and F extents are copied from the value operand unchanged.
3563 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3564 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3566 auto outputType = cast<ShapedType>(getOutput().
getType());
3567 ArrayRef<int64_t> outputShape = outputType.getShape();
3569 return emitOpError(
"the output shape is not expected");
// Builds one loop range per dimension of the value operand: offset `zeroAttr`,
// size taken from the matching dimension of `value`, stride `oneAttr`.
3575WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3576 Location loc = getLoc();
3579 Value value = getValue();
3580 int64_t valueRank = getValueOperandRank();
3581 SmallVector<Range> loopBounds(valueRank);
3582 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3583 loopBounds[dim].offset = zeroAttr;
// The size comes from getDimValue on the value operand for this dimension.
3585 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3586 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per dimension of the value operand; all of them
// are parallel.
3591SmallVector<utils::IteratorType>
3592WinogradOutputTransformOp::getLoopIteratorTypes() {
3593 int64_t valueRank = getValueOperandRank();
3594 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3595 utils::IteratorType::parallel);
3596 return iteratorTypes;
// Fills `resultOffsets` / `resultSizes` for the result tile that corresponds to
// the given iteration-space `offsets` / `sizes`. The appended entries are
// ordered (N, H, W, F); the H/W entries are the mapped tile offsets and sizes.
3599LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3600 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3601 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3602 SmallVector<OpFoldResult> &resultSizes) {
3603 WinogradConv2DFmr fmr = getFmr();
3607 Location loc = getLoc();
3609 auto identityAffineMap =
3614 ShapedType valueType = getValueOperandType();
3615 ArrayRef<int64_t> valueShape = valueType.getShape();
// Read the alpha extents through the dimension accessors, as verify() and
// getTiledImplementation() of this op do, instead of hard-coded indices.
3616 int64_t valueH = valueShape[getValueAlphaHDim()];
3617 int64_t valueW = valueShape[getValueAlphaWDim()];
// When the alpha extent along an axis is 1, the tile offset is mapped through
// the identity map instead of `affineMap`.
3619 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3620 offsets[getValueTileHDim()]);
3622 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3623 offsets[getValueTileWDim()]);
3625 builder, loc, affineMap, sizes[getValueTileHDim()]);
3627 builder, loc, affineMap, sizes[getValueTileWDim()]);
3630 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3631 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
// With an alpha extent of 1 the result size along that axis is `oneAttr`.
3632 OpFoldResult sizeH =
3633 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3634 OpFoldResult sizeW =
3635 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3637 resultOffsets.append(
3638 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3640 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
// Produces the tiled form of this op: extracts a slice of the value operand
// and a slice of the output, clones the op onto those slices, and returns the
// result together with the two slice ops.
3650FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3651 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3652 ArrayRef<OpFoldResult> sizes) {
3655 Location loc = getLoc();
3656 SmallVector<Value> tiledOperands;
3657 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3659 ShapedType valueType = getValueOperandType();
3660 ArrayRef<int64_t> valueShape = valueType.getShape();
3661 int64_t alphaH = valueShape[getValueAlphaHDim()];
3662 int64_t alphaW = valueShape[getValueAlphaWDim()];
// Value slice: the two leading entries use `zeroAttr` offsets with
// `alphaHAttr` / `alphaWAttr` sizes; the rest forward the incoming tile.
3666 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3667 offsets[getValueTileWDim()], offsets[getValueNDim()],
3668 offsets[getValueFDim()]});
3669 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3670 sizes[getValueTileWDim()], sizes[getValueNDim()],
3671 sizes[getValueFDim()]});
3672 int64_t valueRank = getValueOperandRank();
3673 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3674 auto valueSlice = tensor::ExtractSliceOp::create(
3675 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3676 tiledOperands.emplace_back(valueSlice);
// Output slice: uses resultOffsets / resultSizes with strides of `oneAttr`.
3678 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3683 int64_t outputRank = getOutputOperandRank();
3684 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3685 auto outputSlice = tensor::ExtractSliceOp::create(
3686 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3687 tiledOperands.emplace_back(outputSlice);
// The tiled op's single result type is the type of the output slice.
3689 SmallVector<Type> resultTypes;
3690 resultTypes.push_back(tiledOperands[1].
getType());
3691 Operation *tiledOp =
3692 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3694 return TilingResult{
3697 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3711 llvm::set_union(explicitSet, defaultSet);
3712 return explicitSet == defaultSet;
3732 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3734 auto opIndexingMap = opIndexingMaps[opIndex];
3735 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3738 return matmulOp->emitOpError()
3739 <<
"Unexpected dim expression in map result.";
3742 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3743 return matmulOp->emitOpError()
3744 <<
"Invalid broadcast requested, should be (d2).";
3753template <
typename OpTy>
3756 AffineMap defaultIndexingMap,
bool isLHS) {
3757 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3758 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3759 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3762 return batchVariantMatmulOp->emitOpError()
3763 <<
"Unexpected result dim expression (outside the set of default "
3768 return batchVariantMatmulOp->emitOpError()
3769 <<
"no. of result dim expressions exceeds 3.";
3771 auto hasValidBatchDim = [](
AffineMap map) {
3778 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3779 return batchVariantMatmulOp->emitOpError()
3780 <<
"Invalid broadcast requested.";
3781 }
else if (!hasValidBatchDim(opIndexingMap)) {
3782 return batchVariantMatmulOp->emitOpError()
3783 <<
"Invalid batch dimension expression.";
3791template <
typename OpTy>
3794 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3795 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3796 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3797 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3800 return batchVariantMatmulOp->emitOpError()
3801 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3804 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3806 return batchVariantMatmulOp->emitOpError()
3807 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3811 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3812 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3813 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3814 outputMap.getResult(1).isFunctionOfDim(1) &&
3815 outputMap.getResult(2).isFunctionOfDim(2)
3816 : outputMap.getResult(0).isFunctionOfDim(1) &&
3817 outputMap.getResult(1).isFunctionOfDim(2);
3820 if (!areValidOutputResultDim(opIndexingMap)) {
3821 return batchVariantMatmulOp->emitOpError()
3822 <<
"Invalid output map result dimension.";
3831template <
typename OpTy>
3836 batchVariantMatmulOp.getIndexingMapsArray();
3838 batchVariantMatmulOp.getDefaultIndexingMaps(
3839 batchVariantMatmulOp->getContext());
3841 if (opIndexingMaps.size() != 3)
3842 return batchVariantMatmulOp->emitOpError()
3843 <<
"Indexing_map attribute must have 3 affine maps.";
3845 auto opIndexingMap = opIndexingMaps[opIndex];
3846 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3854 defaultIndexingMap, opIndex == 0)))
3864 if (m == 2 && r == 3)
3865 return WinogradConv2DFmr::F_2_3;
3866 if (m == 4 && r == 3)
3867 return WinogradConv2DFmr::F_4_3;
3868 if (m == 2 && r == 5)
3869 return WinogradConv2DFmr::F_2_5;
3870 return std::nullopt;
3875 case WinogradConv2DFmr::F_2_3:
3877 case WinogradConv2DFmr::F_4_3:
3879 case WinogradConv2DFmr::F_2_5:
// Reached only for a WinogradConv2DFmr value that has no case above.
3882 llvm_unreachable(
"Unknown WinogradConv2DFmr");
3889static FailureOr<SmallVector<SmallVector<int64_t>>>
3892 for (
auto map : maps) {
3893 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3897 for (
auto result : attr.getAffineMap().getResults()) {
3898 auto dim = dyn_cast<AffineDimExpr>(
result);
3901 pos.push_back(dim.getPosition());
3903 positions.push_back(pos);
3916 return indexingMaps;
// Checks whether `attr` describes the default matmul indexing maps: an array
// of three maps whose result dim positions are {0, 2}, {2, 1} and {0, 1}.
3919bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3920 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3923 if (maps.size() != 3)
// `positions` holds one position list per map, in LHS, RHS, output order
// (presumably; confirm against where `positions` is computed).
3928 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3929 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3930 (*positions)[2] == SmallVector<int64_t>{0, 1};
// Returns the three iterator types of a matmul, in order: parallel, parallel,
// reduction.
3933SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3934 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3935 utils::IteratorType::parallel,
3936 utils::IteratorType::reduction};
// Returns the number of region arguments for this op, which is always 3.
3939unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3941std::string MatmulOp::getLibraryCallName() {
// Unconditionally reports that this op has dynamic indexing maps.
3945bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
// Returns true when the op's current indexing maps differ from `defaultMaps`.
3949bool MatmulOp::hasUserDefinedMaps() {
3950 SmallVector<AffineMap, 3> defaultMaps =
3952 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3953 return defaultMaps != explicitMaps;
// Builds the body of the matmul region in `block`: a BinaryFn::mul of value1
// and value2, followed by a BinaryFn::add involving block argument 2, whose
// result is yielded. Expects 3 block arguments.
3958void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3959 ArrayRef<NamedAttribute> attrs,
3962 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3967 "MatmulOp regionBuilder expects 3 args");
3968 RegionBuilderHelper helper(
b, block);
3969 SmallVector<Value> yields;
// The cast kind defaults to cast_signed; a "cast" attribute holding a
// TypeFnAttr is looked up in `attrs` (presumably to override it).
3971 TypeFn castVal = TypeFn::cast_signed;
3972 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3973 return attr.
getName() ==
"cast";
3975 if (castIter != attrs.end()) {
3976 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3984 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
// Any null intermediate value is treated as a build failure.
3985 if (!value1 || !value2 || !value3)
3987 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3991 yields.push_back(value4);
3992 helper.yieldOutputs(yields);
// Validates a broadcast map for the LHS/RHS operand. The map must have exactly
// one result expression (asserted); that expression is then inspected.
4002bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
4003 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
4004 AffineExpr expr = bcastMap.
getResult(0);
4014 ArrayAttr arrayAttr;
4018 if (llvm::any_of(arrayAttr,
4019 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
4021 <<
"element of indexing_maps array is not an affine_map";
4028 if (failed(indexingMapsAttr))
4031 if (*indexingMapsAttr ==
nullptr) {
4032 auto indexingMapAttrs = llvm::map_to_vector(
4033 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4038 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4040 MatmulOp::getRegionBuilder());
// Custom printer. The `indexing_maps` clause is printed only when the op's
// maps differ from the default maps; the attributes listed in `elidedAttrs`
// are named separately.
4043void MatmulOp::print(OpAsmPrinter &p) {
4044 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4045 MatmulOp::getDefaultIndexingMaps(
getContext()),
4046 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4047 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4048 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4050 std::array<StringRef, 3> elidedAttrs = {
4051 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
// Verifier. Branches on whether the op has user-defined maps and iterates over
// operand indices 0 and 1.
4057LogicalResult MatmulOp::verify() {
4059 if (!hasUserDefinedMaps())
4062 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
4069LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
// Reports memory effects; whether the op has pure tensor semantics is checked
// first.
4073void MatmulOp::getEffects(
4074 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4076 if (hasPureTensorSemantics())
// Returns the default indexing maps in the order LHS, RHS, output, expressed
// over the dims d0, d1, d2.
4085SmallVector<AffineMap>
4086MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4087 AffineExpr d0, d1, d2;
4093 return {mapLHS, mapRHS, mapOut};
4097 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4100 if (maps.size() != 3)
4103 if (failed(positions))
4115 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4123 build(builder, state, inputs, outputs, attributes);
4124 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4125 assert(res &&
"builder didn't return the right type");
4135 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4144 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4145 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4146 assert(res &&
"builder didn't return the right type");
4156 result.addAttribute(
"cast", cast);
4158 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4167 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4168 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4169 assert(res &&
"builder didn't return the right type");
4174 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4176 op->
getAttr(
"indexing_maps"));
// Returns the default indexing maps in the order LHS, RHS, output.
4180MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4187 return {mapLHS, mapRHS, mapOut};
4191 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4194 if (maps.size() != 3)
4197 if (failed(positions))
4209 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4217 build(builder, state, inputs, outputs, attributes);
4218 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4219 assert(res &&
"builder didn't return the right type");
4229 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4238 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4239 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4240 assert(res &&
"builder didn't return the right type");
4250 result.addAttribute(
"cast", cast);
4252 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4261 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4262 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4263 assert(res &&
"builder didn't return the right type");
4268 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4270 op->
getAttr(
"indexing_maps"));
// Returns the default indexing maps in the order LHS, RHS, output.
4274BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4281 return {mapLHS, mapRHS, mapOut};
4285 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4288 if (maps.size() != 3)
4291 if (failed(positions))
4302 BatchMatmulOp::getRegionBuilder(),
4303 getDefaultIndexingMaps(builder));
4311 build(builder, state, inputs, outputs, attributes);
4312 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4313 assert(res &&
"builder didn't return the right type");
4322 BatchMatmulOp::getRegionBuilder(),
4323 getDefaultIndexingMaps(builder));
4332 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4333 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4334 assert(res &&
"builder didn't return the right type");
4342 result.addAttribute(
"cast", cast);
4344 BatchMatmulOp::getRegionBuilder(),
4345 getDefaultIndexingMaps(builder));
4354 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4355 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4356 assert(res &&
"builder didn't return the right type");
4361 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4363 op->
getAttr(
"indexing_maps"));
// Returns the default indexing maps in the order LHS, RHS, output.
4367BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4374 return {mapLHS, mapRHS, mapOut};
4378 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4381 if (maps.size() != 3)
4384 if (failed(positions))
4395 BatchMatmulOp::getRegionBuilder(),
4396 getDefaultIndexingMaps(builder));
4404 build(builder, state, inputs, outputs, attributes);
4405 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4406 assert(res &&
"builder didn't return the right type");
4415 BatchMatmulOp::getRegionBuilder(),
4416 getDefaultIndexingMaps(builder));
4425 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4426 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4427 assert(res &&
"builder didn't return the right type");
4435 result.addAttribute(
"cast", cast);
4437 BatchMatmulOp::getRegionBuilder(),
4438 getDefaultIndexingMaps(builder));
4447 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4448 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4449 assert(res &&
"builder didn't return the right type");
4454 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4456 op->
getAttr(
"indexing_maps"));
4464 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4475 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4476 assert(dimExpr &&
"affine_map is a projected permutation");
4477 dimsInOutput[dimExpr.getPosition()] =
true;
4481 for (
auto dimOccursInOutput : dimsInOutput)
4482 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4483 : utils::IteratorType::reduction);
4485 return iteratorTypes;
// Returns the number of region arguments for this op, which is always 3.
4488unsigned ContractOp::getNumRegionArgs() {
return 3; }
// Builds the body of the contraction region in `block`: block arguments 0 and
// 1 are converted to `outType`, multiplied, and a result is yielded. Expects
// 3 block arguments.
4491void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4492 ArrayRef<NamedAttribute> attrs,
4495 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4500 "ContractOp regionBuilder expects 3 args");
4501 RegionBuilderHelper helper(
b, block);
// The cast kind defaults to cast_signed; a "cast" attribute holding a
// TypeFnAttr is looked up in `attrs` (presumably to override it).
4503 TypeFn castSignedness = TypeFn::cast_signed;
4504 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4505 return attr.
getName() ==
"cast";
4507 if (castIter != attrs.end()) {
4508 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Both inputs are brought to the output type before the multiplication.
4514 Value lhsAtOutType =
4515 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4516 Value rhsAtOutType =
4517 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4518 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4520 if (!productAtOutType)
4526 helper.yieldOutputs({
result});
// Custom parser. The `indexing_maps` attribute is mandatory here: a parse
// failure or a null attribute is tied to the "expected 'indexing_maps'
// attribute" message; otherwise the attribute is added to `result`.
4529ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4531 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4533 "expected 'indexing_maps' attribute");
4534 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
// Custom printer. The `indexing_maps` clause is always printed; the attribute
// names "indexing_maps" and "operandSegmentSizes" are passed on with the
// op's inputs and outputs.
4540void ContractOp::print(OpAsmPrinter &p) {
4541 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4543 p, getOperation(), getInputs(), getOutputs(),
4544 {
"indexing_maps",
"operandSegmentSizes"});
// Verifier. Validates each indexing map against its operand type, counts how
// often every iteration-space dimension appears in the input maps and in the
// output map, and requires at least one contracting dimension.
4547LogicalResult ContractOp::verify() {
// -1 means the iteration-space rank has not been fixed by a map yet.
4548 int iterationSpaceDims = -1;
// Per-dimension occurrence counts across the input maps / the output map.
4553 SmallVector<size_t> inOccurrences;
4554 SmallVector<size_t> outOccurrences;
// Validates one (map, operand type) pair and accumulates the occurrence
// counts for the dimensions its results refer to.
4557 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4558 bool isInput) -> LogicalResult {
4561 return emitError(
"provided affine_map is not a projected permutation");
4564 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4566 return emitError(
"ranks of shaped operand and results of corresponding "
4567 "affine_map differ");
4569 return emitError(
"affine_map specifies shaped access while operand has "
// The first map sizes the counters; every later map must have the same number
// of dims.
4574 if (iterationSpaceDims == -1) {
4576 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4577 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4578 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4579 return emitError(
"iteration spaces of provided affine_maps differ");
4583 for (AffineExpr affineExpr : affineMap.
getResults()) {
4584 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4586 llvm_unreachable(
"affine_map is a projected permutation");
4589 inOccurrences[affineDimExpr.getPosition()] += 1;
4591 outOccurrences[affineDimExpr.getPosition()] += 1;
// Maps and operand types are visited pairwise, with isInput = true, true,
// false.
4597 for (
auto &&[affineMap, operandType, isInput] :
4598 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4599 SmallVector<bool>{
true,
true,
false})) {
4600 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4604 bool hasContractingDim =
false;
4605 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4606 size_t inOccCount = inOccurrences[dimIndex];
4607 size_t outOccCount = outOccurrences[dimIndex];
// A dim counted twice on the input side and never on the output side is a
// contracting dim.
4610 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4612 if (inOccCount == 0 && outOccCount == 0)
4613 return emitError() <<
"iteration space dim at index " << dimIndex
4614 <<
" not used to access any operand";
// A dim counted once on the input side must be counted exactly once on the
// output side.
4625 if (inOccCount == 1 && outOccCount != 1)
4627 <<
"iteration space dim at index " << dimIndex
4628 <<
" is neither a contracting dim nor of parallel iteration type";
4631 if (!hasContractingDim)
4632 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4637LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
// Reports memory effects; whether the op has pure tensor semantics is checked
// first.
4641void ContractOp::getEffects(
4642 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4644 if (hasPureTensorSemantics())
// Returns the three default 4-dim indexing maps, pushed in this order:
// (d0, d1, d3), (d0, d3, d2), (d0, d1, d2).
4656SmallVector<AffineMap>
4657BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4658 AffineExpr d0, d1, d2, d3;
4659 SmallVector<AffineMap> indexingMaps;
4661 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4662 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4663 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4664 return indexingMaps;
// Checks whether `attr` describes the default batch-matmul indexing maps: an
// array of three maps whose result dim positions are {0, 1, 3}, {0, 3, 2} and
// {0, 1, 2}.
4667bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4668 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4671 if (maps.size() != 3)
4676 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4677 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4678 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
// Returns the four iterator types of a batch matmul, in order: parallel,
// parallel, parallel, reduction.
4681SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4682 return SmallVector<utils::IteratorType>{
4683 utils::IteratorType::parallel, utils::IteratorType::parallel,
4684 utils::IteratorType::parallel, utils::IteratorType::reduction};
// Returns the number of region arguments for this op, which is always 3.
4687unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4689std::string BatchMatmulOp::getLibraryCallName() {
// Returns true when the op's current indexing maps differ from `defaultMaps`.
4695bool BatchMatmulOp::hasUserDefinedMaps() {
4696 SmallVector<AffineMap, 3> defaultMaps =
4698 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4699 return defaultMaps != explicitMaps;
// Validates a broadcast map for the LHS (`isLHS` true) or RHS operand by
// checking which iteration dims its result expressions depend on.
4709bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4711 "Expected less than 3 result dim expr.");
4712 bool isValid =
false;
// Iteration-dim positions: batchPos = 0, mPos = 1, nPos = 2, kPos = 3.
4713 enum Indices { batchPos, mPos, nPos, kPos };
4715 AffineExpr expr = bcastMap.
getResult(0);
4718 AffineExpr expr0 = bcastMap.
getResult(0);
4719 AffineExpr expr1 = bcastMap.
getResult(1);
// Two-result alternative shown here: (batch, k) or (k, n).
4724 : ((expr0.isFunctionOfDim(batchPos) &&
4725 expr1.isFunctionOfDim(kPos)) ||
4726 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
// Builds the body of the batch-matmul region in `block`: block arguments 0 and
// 1 are converted to `toType`, multiplied, combined with block argument 2 via
// BinaryFn::add, and the sum is yielded. Expects 3 block arguments.
4731void BatchMatmulOp::regionBuilder(
4732 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4735 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4740 "BatchMatmulOp regionBuilder expects 3 args");
4741 RegionBuilderHelper helper(
b, block);
4742 SmallVector<Value> yields;
// The cast kind defaults to cast_signed; a "cast" attribute holding a
// TypeFnAttr is looked up in `attrs` (presumably to override it).
4744 TypeFn castVal = TypeFn::cast_signed;
4745 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4746 return attr.
getName() ==
"cast";
4748 if (castIter != attrs.end()) {
4749 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4754 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4755 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4757 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
// Any null intermediate value is treated as a build failure.
4758 if (!castValA || !castValB || !mulVal)
4760 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4764 yields.push_back(addVal);
4765 helper.yieldOutputs(yields);
// Custom parser. Collects explicitly written affine-map attributes; when none
// were given, the default indexing maps are used. The maps are attached as
// "indexing_maps" and parsing continues in parseNamedStructuredOp.
4768ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4769 SmallVector<Attribute, 3> indexingMapsAttr;
// Every collected element must be an AffineMapAttr.
4781 if (!isa<AffineMapAttr>(mapAttr)) {
4783 "expected affine map attribute");
4785 indexingMapsAttr.push_back(mapAttr);
// No explicit maps: fall back to the op's default indexing maps.
4795 if (indexingMapsAttr.empty()) {
4796 indexingMapsAttr = llvm::map_to_vector(
4797 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4798 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4800 result.addAttribute(
"indexing_maps",
4803 return ::parseNamedStructuredOp(parser,
result,
4804 BatchMatmulOp::getNumRegionArgs(),
4805 BatchMatmulOp::getRegionBuilder());
// Custom printer. The `indexing_maps` clause is printed only when the op's
// maps differ from the default maps; the attributes listed in `elidedAttrs`
// are named separately.
4808void BatchMatmulOp::print(OpAsmPrinter &p) {
4809 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4810 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4811 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4812 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4813 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4815 std::array<StringRef, 3> elidedAttrs = {
4816 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
// Verifier. Branches on whether the op has user-defined maps and iterates over
// operand indices 0 through 2.
4822LogicalResult BatchMatmulOp::verify() {
4825 if (!hasUserDefinedMaps())
4828 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4835LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4836 SmallVectorImpl<OpFoldResult> &) {
// Reports memory effects; whether the op has pure tensor semantics is checked
// first.
4840void BatchMatmulOp::getEffects(
4841 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4843 if (hasPureTensorSemantics())
// Bundles an elementwise arity group with a function kind; `ternaryFn` is one
// of the kind fields.
4857struct ArityGroupAndKind {
4859 ElementwiseArityGroup arityGroup;
4865 TernaryFn ternaryFn;
// Converts an ElementwiseArityGroup enumerator to its unsigned integer value.
4869unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4870 return static_cast<unsigned>(arityGroup);
4875 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4876 constexpr int lastBinary =
4877 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4878 constexpr int lastTernary =
4879 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4881 int val =
static_cast<int>(kind);
4882 ArityGroupAndKind
result;
4884 if (val < lastUnary) {
4885 result.arityGroup = ElementwiseArityGroup::Unary;
4886 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4889 if (val < lastBinary) {
4890 result.arityGroup = ElementwiseArityGroup::Binary;
4891 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4894 if (val >= lastTernary) {
4895 llvm_unreachable(
"unhandled ElementwiseFn");
4897 result.arityGroup = ElementwiseArityGroup::Ternary;
4898 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4903 auto rank = getResultRank();
4908ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
// Custom parser. Reads the mandatory `kind` attribute, collects any explicit
// affine-map attributes, and falls back to default indexing maps derived from
// the rank of the last operand's type when none were given.
4914ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4917 mlir::linalg::ElementwiseKind elemwiseKindVal;
// The parsed attribute must be an ElementwiseKindAttr.
4922 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4923 if (!elemwiseKindAttr)
4925 "expected ElementwiseKind attribute");
4926 elemwiseKindVal = elemwiseKindAttr.getValue();
4929 "expected operation 'kind' attribute");
4932 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4935 SmallVector<Attribute, 3> indexingMapsAttr;
// Every collected element must be an AffineMapAttr.
4945 if (!isa<AffineMapAttr>(mapAttr))
4947 "expected affine map attribute");
4948 indexingMapsAttr.push_back(mapAttr);
// An expression of the form "arity group + 1"; it presumably feeds
// numRegionArgs, which is used below (confirm against the full statement).
4959 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4961 ElementwiseOp::getRegionBuilder())) {
4963 "unable to parse elemwise op");
// No explicit maps: the last operand's type must be shaped, and its rank is
// passed as the dimension count for the default maps.
4967 if (indexingMapsAttr.empty()) {
4970 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4971 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4974 "return type needs to be shaped type");
4975 auto numDims = shapedType.getRank();
4976 indexingMapsAttr = llvm::map_to_vector(
4977 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4979 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4982 result.addAttribute(
"indexing_maps",
// Custom printer. The `indexing_maps` clause is printed only when the op's
// maps differ from the default maps built for `arity + 1` maps over
// `numDims` dimensions.
4987void ElementwiseOp::print(OpAsmPrinter &p) {
4990 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
// The dimension count of the default maps is the result rank.
4994 unsigned numDims = getResultRank();
4996 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4997 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4999 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
5001 if (!llvm::equal(getIndexingMaps(), indexingMaps))
5002 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// Builds the body of the elementwise region in `block`. The elementwise kind
// is read from the "kind" attribute in `attrs`, the block-argument count is
// compared against the arity group + 1, and the computed result is yielded.
5010void ElementwiseOp::regionBuilder(
5011 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
5013 ElementwiseKind elemwiseKind;
// Scan the attributes for "kind"; it must hold an ElementwiseKindAttr
// (asserted).
5014 for (
auto attr : attrs) {
5015 if (attr.getName() ==
b.getStringAttr(
"kind")) {
5016 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
5017 assert(kindAttr &&
"op kind attribute incorrectly set");
5018 elemwiseKind = kindAttr.getValue();
5024 auto arityGroup = groupAndKind.arityGroup;
5025 auto kind = groupAndKind.kind;
// The expected argument count is the arity group's integer value plus one.
5027 getArityGroupAsUInt(arityGroup) + 1 ) {
5028 emitError() <<
"Elementwise regionBuilder expects "
5029 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
5034 getArityGroupAsUInt(arityGroup) + 1
5035 &&
"Elementwise regionBuilder number of block args mismatch");
5037 RegionBuilderHelper helper(
b, block);
5038 SmallVector<Value> yields;
// Dispatch on the arity group: unary, binary or ternary; any other group
// trips the assertion below.
5041 if (arityGroup == ElementwiseArityGroup::Unary) {
5044 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
5048 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
5053 assert(
false &&
"found unhandled category in elemwise");
5056 yields.push_back(
result);
5057 helper.yieldOutputs(yields);
5060LogicalResult ElementwiseOp::fold(FoldAdaptor,
5061 SmallVectorImpl<OpFoldResult> &) {
5065void ElementwiseOp::getEffects(
5066 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5068 if (hasPureTensorSemantics())
5081template <
typename OpTy,
typename>
5084 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5085 ? packOrUnPack.getDestType()
5086 : packOrUnPack.getSourceType();
5087 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5088 ? packOrUnPack.getSourceType()
5089 : packOrUnPack.getDestType();
5091 packedType.getShape().take_front(unpackedType.getRank()));
5092 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5113 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5115 .take_back(mixedTiles.size()),
5117 int64_t dimSize = std::get<0>(it);
5118 if (dimSize == ShapedType::kDynamic) {
5119 newMixedTileSizes.push_back(std::get<1>(it));
5122 newMixedTileSizes.push_back(rewriter.
getIndexAttr(dimSize));
5125 return newMixedTileSizes;
5128template <
typename OpTy>
5132 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5133 "applies to only pack or unpack operations");
5134 int64_t destRank = op.getDestRank();
5136 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5137 reifiedReturnShapes[0][dim] =
5142template <
typename OpTy>
5144 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5145 "applies to only pack or unpack operations");
5149 assert(tiles.size() == dimsToTile.size() &&
5150 "tiles must match indices of dimension to block");
5152 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5153 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5154 return dimAndTileMapping;
5157template <
typename OpTy>
5159 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5160 "applies to only pack or unpack operations");
5163 unsigned dynamicValIndex = 0;
5164 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5165 if (ShapedType::isStatic(staticTile))
5168 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5170 return mixedInnerTiles;
5173template <
typename OpTy>
5175 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5176 "applies to only pack or unpack operations");
5189 size_t dimsPosSize = dimsPos.size();
5190 if (dimsPosSize > rank)
5193 if (dimsPosSize != uniqued.size())
5195 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5196 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
/// Verifier logic shared by PackOp and UnPackOp (`OpTy` is restricted to those
/// two by the static_assert below).
///
/// The visible checks cover: ranked source and destination, no mixing of
/// tensor and buffer semantics, result count matching the semantics, non-zero
/// tile factors, validity of inner_dims_pos / outer_dims_perm, the number of
/// tiling factors, the packed rank, agreement between the inner tile sizes and
/// the trailing dimensions of the packed type, and the expected packed type.
5200template <
typename OpTy>
5202 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5203 "applies to only pack or unpack operations");
5204 Operation *op = packOrUnPack.getOperation();
// Both operand types must be ranked.
5214 if (!packOrUnPack.getSourceType().hasRank() ||
5215 !packOrUnPack.getDestType().hasRank())
5216 return op->
emitError(
"expected both source and destination to have rank");
// The op must have either pure buffer or pure tensor semantics.
5219 if (!packOrUnPack.hasPureBufferSemantics() &&
5220 !packOrUnPack.hasPureTensorSemantics())
5221 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
// Tensor semantics requires exactly one result, buffer semantics none.
5222 const unsigned numResults = packOrUnPack.getNumResults();
5223 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5224 return op->
emitError(
"expected 1 result, got ") << numResults;
5225 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5226 return op->
emitError(
"expected 0 results, got ") << numResults;
5230 if (hasZeros(mixedTiles))
5231 return op->
emitError(
"invalid zero tile factor");
// The "unpacked" side is the source of a pack and the destination of an
// unpack.
5234 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5235 ? packOrUnPack.getSourceType()
5236 : packOrUnPack.getDestType();
5237 size_t unpackedRank = unpackedType.getRank();
5241 return op->
emitError(
"invalid inner_dims_pos vector");
5243 return op->
emitError(
"invalid outer_dims_perm vector");
5244 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5245 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
// There can be at most one tiling factor per unpacked dimension, and exactly
// one per entry of inner_dims_pos.
5249 if (mixedTiles.size() > unpackedRank) {
5250 return op->
emitError(
"tiling factors must be less than or equal to the "
5251 "input rank for pack or output rank for unpack");
5253 if (mixedTiles.size() != innerDimsPos.size()) {
5255 "tiling factors must equal the number of dimensions to tile");
// The "packed" side is the destination of a pack and the source of an unpack;
// its rank must equal the unpacked rank plus the number of tiling factors.
5258 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5259 ? packOrUnPack.getDestType()
5260 : packOrUnPack.getSourceType();
5261 size_t packedRank = packedType.getRank();
5263 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5264 if (expectedPackedRank != packedRank) {
5266 "packed rank != (unpacked rank + num tiling factors), got ")
5267 << packedRank <<
" != " << expectedPackedRank;
5274 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5275 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
// Compare each trailing (tile) dimension of the packed type with the
// corresponding entry of `mixedTiles`.
5276 for (
auto it : llvm::enumerate(llvm::zip(
5277 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5278 int64_t dimSize = std::get<0>(it.value());
5280 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
// NOTE(review): the result of dyn_cast_or_null is dereferenced on the next
// line without a visible null check -- confirm that a non-IntegerAttr tile
// attribute cannot reach this point.
5281 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5282 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5283 if (dimSize != staticTileSize)
5285 "mismatch in inner tile sizes specified and shaped of "
5286 "tiled dimension in the packed type at index ")
5287 << it.index() <<
": got " << dimSize <<
" != " << staticTileSize;
5288 }
else if (!ShapedType::isDynamic(dimSize)) {
5289 return op->
emitError(
"mismatch in inner tile sizes specified at index ")
5290 << it.index() <<
": got static shape " << dimSize
5291 <<
" but dynamic tile size";
// Build the expected and actual packed types with the same element type so
// they can be reported in the diagnostic below; tensor or memref types are
// used depending on the op's semantics.
5296 auto elementType = unpackedType.getElementType();
5297 Type expectedType, actualType;
5298 if (packOrUnPack.hasPureTensorSemantics()) {
5299 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5300 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5302 expectedType = MemRefType::get(expectedPackedShape, elementType);
5303 actualType = MemRefType::get(packedType.getShape(), elementType);
5306 << expectedType <<
" for the packed domain value, got "
5319struct PackOrUnPackTransposeResult {
5326template <
typename OpTy>
5327static PackOrUnPackTransposeResult
5331 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5332 "applies to only pack or unpack operations");
5333 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5334 "some permutation must be non-empty");
5335 PackOrUnPackTransposeResult metadata;
5336 metadata.innerDimsPos =
5338 metadata.innerTiles =
5340 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5341 ? packOrUnPackOp.getSourceRank()
5342 : packOrUnPackOp.getDestRank();
5343 metadata.outerDimsPerm =
5344 packOrUnPackOp.getOuterDimsPerm().empty()
5345 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5347 if (!innerPermutation.empty()) {
5348 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5350 "invalid inner permutation");
5354 if (!outerPermutation.empty()) {
5355 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5357 "invalid outer permutation");
5368 if (!getResults().empty())
5369 setNameFn(getResult(),
"pack");
5379 Type sourceType, destType, resultType;
5396 SmallVector<int64_t> outerDimsPermVec;
5399 if (parser.parseInteger(value))
5401 outerDimsPermVec.push_back(value);
5411 SmallVector<int64_t> innerDimsPosVec;
5414 if (parser.parseInteger(value))
5416 innerDimsPosVec.push_back(value);
5428 for (
auto val : staticTilesAttr.
asArrayRef())
5429 staticTiles.push_back(val);
5446 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5449 "pack/unpack requires '->' and destination type");
5453 resultType = destType;
5459 if (!paddingValue.empty() &&
5464 if (!dynamicTiles.empty() &&
5469 result.addAttribute(
"static_inner_tiles",
5471 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5473 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5475 SmallVector<int32_t> segmentSizes = {
5476 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5477 static_cast<int32_t
>(dynamicTiles.size())};
5478 result.addAttribute(
"operandSegmentSizes",
5482 result.addTypes(resultType);
/// Custom assembly printer for PackOp.
///
/// Emits, in order: the source operand, the optional padding_value clause, the
/// optional outer_dims_perm list, the inner_dims_pos list, the inner_tiles,
/// the destination operand ("into"), and finally the source and destination
/// types separated by "->".
5487void PackOp::print(OpAsmPrinter &p) {
5488 p <<
" " << getSource();
// The padding value is printed together with its type, and only when present.
5490 if (getPaddingValue()) {
5491 p <<
" padding_value(" << getPaddingValue() <<
" : "
5492 << getPaddingValue().getType() <<
")";
// An empty outer_dims_perm is omitted from the textual form.
5495 if (!getOuterDimsPerm().empty()) {
5496 p <<
" outer_dims_perm = [";
5497 llvm::interleaveComma(getOuterDimsPerm(), p);
5501 p <<
" inner_dims_pos = [";
5502 llvm::interleaveComma(getInnerDimsPos(), p);
5505 p <<
" inner_tiles = ";
5508 p <<
" into " << getDest();
5511 {
"static_inner_tiles",
"inner_dims_pos",
5512 "outer_dims_perm",
"operandSegmentSizes"});
5514 p <<
" : " << getSource().getType();
5515 p <<
" -> " << getDest().getType();
/// Builder taking the inner tile sizes as mixed static/dynamic OpFoldResults.
///
/// Requires one tile size per entry of `innerDimsPos` (asserted), then
/// forwards to another `build` overload using the destination type as the
/// result type.
5518void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5519 Value dest, ArrayRef<int64_t> innerDimsPos,
5520 ArrayRef<OpFoldResult> innerTiles,
5521 std::optional<Value> paddingValue,
5522 ArrayRef<int64_t> outerDimsPerm) {
5523 assert(innerDimsPos.size() == innerTiles.size() &&
5524 "number of tile sizes specified must match the specified number of "
5525 "original dimensions to be tiled");
// Holders for the static and dynamic parts of `innerTiles`; the code that
// fills them is not visible in this extract.
5526 SmallVector<int64_t> staticTileSizes;
5527 SmallVector<Value> dynamicTileSizes;
// An absent padding value and an empty outer_dims_perm are both forwarded as
// null.
5529 build(builder, state, dest.
getType(), source, dest,
5530 paddingValue ? *paddingValue :
nullptr,
5531 outerDimsPerm.empty() ?
nullptr
5538PackOp::reifyResultShapes(OpBuilder &builder,
5547SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5551SmallVector<int64_t> PackOp::getStaticTiles() {
/// Returns the leading entries of the destination shape, one per dimension of
/// the source type (i.e. the first `inputRank` destination dimensions).
5555ArrayRef<int64_t> PackOp::getAllOuterDims() {
5556 ShapedType inputType = getSourceType();
5557 int64_t inputRank = inputType.getRank();
5558 return getDestType().getShape().take_front(inputRank);
/// Collects, for every position listed in inner_dims_pos, the matching entry
/// of the outer dimensions returned by getAllOuterDims().
///
/// When outer_dims_perm is non-empty an additional step runs first; it is not
/// visible in this extract (presumably the permutation is inverted and applied
/// to `outerDims` -- confirm).
5561SmallVector<int64_t> PackOp::getTiledOuterDims() {
5562 auto innerDimsPos = getInnerDimsPos();
5563 SmallVector<int64_t> outerDims(getAllOuterDims());
5564 SmallVector<int64_t> res;
5567 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5569 if (!outerDimPermInv.empty())
// Pick the outer dimension for each tiled position, in inner_dims_pos order.
5573 for (
auto index : innerDimsPos)
5574 res.push_back(outerDims[index]);
/// Divisibility check over the tiled dimensions of a pack.
///
/// For each (position, tile size) pair the visible code tests whether the
/// input extent is evenly divisible: by the constant tile size when there is
/// one, otherwise by the static outer output extent. The return statements are
/// not visible in this extract; presumably a non-divisible dimension means a
/// padding value is required -- confirm.
5579bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5580 ArrayRef<int64_t> innerDimsPos,
5581 ArrayRef<int64_t> outputShape,
5582 ArrayRef<int64_t> outerDimsPerm,
5583 ArrayRef<OpFoldResult> innerTiles) {
// Start from the leading `inputShape.size()` entries of the output shape.
5584 SmallVector<int64_t> outputTileSizes(
5585 outputShape.take_front(inputShape.size()));
5586 if (!outerDimsPerm.empty()) {
5587 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5588 "expected output and outer_dims_perm to have same size");
5592 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
// Dynamic input extents get separate handling (not visible here).
5593 if (ShapedType::isDynamic(inputShape[pos]))
// Non-constant tile: only a static outer output extent can be checked.
5596 if (!constantTile) {
5597 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5598 (inputShape[pos] % outputTileSizes[pos] != 0))
5601 assert(*constantTile != 0 &&
"static tile size can't be zero");
5602 if (inputShape[pos] % (*constantTile) != 0) {
/// Variant of the padding check that also special-cases dynamic outer output
/// extents.
///
/// Unlike requirePaddingValue, the first test in the loop covers a dynamic
/// input extent *or* a dynamic outer output extent at the tiled position. The
/// return statements are not visible in this extract, so the exact result for
/// each case should be confirmed against the full source.
5610bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5611 ArrayRef<int64_t> innerDimsPos,
5612 ArrayRef<int64_t> outputShape,
5613 ArrayRef<int64_t> outerDimsPerm,
5614 ArrayRef<OpFoldResult> innerTiles) {
// Start from the leading `inputShape.size()` entries of the output shape.
5615 SmallVector<int64_t> outputTileSizes(
5616 outputShape.take_front(inputShape.size()));
5617 if (!outerDimsPerm.empty()) {
5618 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5619 "expected output and outer_dims_perm to have same size");
5623 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5624 if (ShapedType::isDynamic(inputShape[pos]) ||
5625 ShapedType::isDynamic(outputTileSizes[pos]))
// With a constant tile size, test whether it divides the input extent evenly.
5630 assert(*constantTile != 0 &&
"static tile size can't be zero");
5631 if (inputShape[pos] % (*constantTile) != 0)
/// Verifier for PackOp.
///
/// Visible checks: a diagnostic that reports the source element type against
/// the padding value's type, and a rejection of ops that have no padding value
/// although requirePaddingValue() reports that one is needed.
5637LogicalResult PackOp::verify() {
5644 auto paddingValue = getPaddingValue();
5648 << getSourceType().getElementType()
5649 <<
" but got: " << paddingValue.getType();
// Without a padding value only full tiles are supported.
5652 if (!paddingValue &&
5653 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5654 getDestType().
getShape(), getOuterDimsPerm(),
5657 "invalid tile factor or output size provided. Only full tiles are "
5658 "supported when padding_value is not set");
5665static SmallVector<int64_t>
5668 for (
auto o : ofrs) {
5670 if (llvm::dyn_cast_if_present<Value>(o))
5671 result.push_back(ShapedType::kDynamic);
5683 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5684 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5686 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5687 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5690 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5691 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5695 if (!outerDimsPerm.empty())
5699 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
/// Computes the mixed static/dynamic shape of a pack result.
///
/// Starts from `sourceDims`, applies the ceil-division expression to each
/// tiled dimension together with its tile size, handles a non-empty
/// `outerDimsPerm` (step not visible here), and appends the inner tile sizes
/// as trailing dimensions. A statically inferred shape (`resultTypeShape`) is
/// then consulted per dimension.
5703SmallVector<OpFoldResult> PackOp::getResultShape(
5704 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5705 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5706 ArrayRef<int64_t> outerDimsPerm) {
5707 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5711 AffineExpr ceilDivExpr = s0.
ceilDiv(s1);
// Each tiled dimension is combined with its tile size through `ceilDivExpr`.
5712 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5714 builder, loc, ceilDivExpr,
5715 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5717 if (!outerDimsPerm.empty())
// The tile sizes become the trailing dimensions of the result.
5719 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5721 SmallVector<int64_t> resultTypeShape =
5724 innerDimsPos, outerDimsPerm);
// Walk every result dimension and look at the statically inferred extent; the
// action taken for static entries is not visible in this extract.
5730 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5731 if (ShapedType::isStatic(resultTypeShape[i]))
/// Returns the ranked tensor type of a packed value: the shape comes from
/// inferPackedShape() applied to the source shape, tile sizes, inner_dims_pos
/// and outer_dims_perm; the element type is that of `sourceType`.
5740RankedTensorType PackOp::inferPackedTensorType(
5741 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5742 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5743 SmallVector<int64_t> resultShape = inferPackedShape(
5744 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5745 return RankedTensorType::get(resultShape, sourceType.getElementType());
/// MemRef counterpart of inferPackedTensorType: the shape comes from
/// inferPackedShape() and the element type from `sourceType`. The result is
/// built from shape and element type only.
5748MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5749 ArrayRef<int64_t> innerTileSizes,
5750 ArrayRef<int64_t> innerDimsPos,
5751 ArrayRef<int64_t> outerDimsPerm) {
5752 SmallVector<int64_t> resultShape = inferPackedShape(
5753 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5754 return MemRefType::get(resultShape, sourceType.getElementType());
/// Creates a tensor.empty op suitable as the destination of a pack of
/// `source`.
///
/// The sizes start as the source dimensions (tensor.dim for dynamic ones,
/// index attributes for static ones); each tiled dimension is replaced by the
/// `ceilDiv` of its size and tile size; a non-empty `outerDimsPerm` is handled
/// (step not visible here); and the tile sizes are appended as trailing
/// dimensions. `source` is cast to RankedTensorType, so it must be a ranked
/// tensor.
5757Value PackOp::createDestinationTensor(OpBuilder &
b, Location loc, Value source,
5758 ArrayRef<OpFoldResult> innerTileSizes,
5759 ArrayRef<int64_t> innerDimsPos,
5760 ArrayRef<int64_t> outerDimsPerm) {
5761 AffineExpr dim0, dim1;
5763 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
// One entry per source dimension: a tensor.dim result when dynamic, an index
// attribute otherwise.
5768 SmallVector<OpFoldResult> mixedSizes;
5769 for (
auto [index, value] : llvm::enumerate(
5770 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5771 if (ShapedType::isDynamic(value))
5772 mixedSizes.push_back(
5773 tensor::DimOp::create(
b, loc, source, index).getResult());
5775 mixedSizes.push_back(
b.getIndexAttr(value));
// Shrink every tiled dimension by its tile size.
5777 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5778 int64_t dimPos = std::get<0>(it);
5779 OpFoldResult tileSize = std::get<1>(it);
5780 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5782 if (!outerDimsPerm.empty())
// The tile sizes form the trailing dimensions of the destination.
5785 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5786 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5787 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
/// Creates a new PackOp on the same source and padding value, using the inner
/// dims positions, inner tiles and outer dims permutation held in `metadata`
/// (computed from this op and the two permutations; the computing call is only
/// partly visible here). A fresh destination tensor is created for the new
/// layout.
5790PackOp PackOp::createTransposedClone(OpBuilder &
b, Location loc,
5791 ArrayRef<int64_t> innerPermutation,
5792 ArrayRef<int64_t> outerPermutation) {
5794 *
this, innerPermutation, outerPermutation);
// The destination is rebuilt from the source and the permuted metadata.
5795 Value transposedDest =
5796 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5797 metadata.innerDimsPos, metadata.outerDimsPerm);
5798 return PackOp::create(
b, loc, getSource(), transposedDest,
5799 metadata.innerDimsPos, metadata.innerTiles,
5800 getPaddingValue(), metadata.outerDimsPerm);
5803template <
typename OpTy>
5808 if (op.hasPureTensorSemantics())
5811 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5812 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5815 if (&opOperand == &op.getSourceMutable()) {
5819 }
else if (&opOperand == &op.getDestMutable()) {
5830void PackOp::getEffects(
5836void UnPackOp::getEffects(
5843template <
typename OpTy>
5845 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5846 "applies to only pack or unpack operations");
5847 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5849 : op.getSourceType();
5851 for (
auto [dimDest,
tile] : llvm::zip(
5852 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5854 if (!constTileSize || ShapedType::isDynamic(dimDest))
5861 if (!hasPureTensorSemantics())
5863 if (getPaddingValue())
5878 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5880 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5892 auto packTiles = packOp.getMixedTiles();
5893 auto unPackTiles = unPackOp.getMixedTiles();
5894 if (packTiles.size() != unPackTiles.size())
5896 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5905 auto srcType = op.getSourceType();
5906 auto innerDimsPos = op.getInnerDimsPos();
5907 auto innerTiles = op.getStaticInnerTiles();
5908 if (ShapedType::isDynamicShape(innerTiles))
5910 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5911 if (srcType.isDynamicDim(pos) && tileSize != 1)
5914 return !PackOp::requirePaddingValue(
5915 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5916 op.getOuterDimsPerm(), op.getMixedTiles());
5923 bool changeNeeded =
false;
5924 srcShape.assign(packOp.getSourceType().getShape().begin(),
5925 packOp.getSourceType().getShape().end());
5926 destShape.assign(packOp.getDestType().getShape().begin(),
5927 packOp.getDestType().getShape().end());
5928 llvm::SmallSetVector<int64_t, 4> innerDims;
5929 innerDims.insert_range(packOp.getInnerDimsPos());
5931 if (!packOp.getOuterDimsPerm().empty())
5933 int srcRank = packOp.getSourceRank();
5934 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5935 if (innerDims.contains(i))
5939 if (!inverseOuterDimsPerm.empty())
5940 destPos = inverseOuterDimsPerm[srcPos];
5941 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5942 ShapedType::isDynamic(destShape[destPos])) {
5945 int64_t size = srcShape[srcPos];
5946 if (ShapedType::isDynamic(size))
5947 size = destShape[destPos];
5948 srcShape[srcPos] = size;
5949 destShape[destPos] = size;
5950 changeNeeded =
true;
5952 return changeNeeded;
/// Canonicalization patterns for PackOp (tensor semantics only, per the first
/// check).
///
/// Visible rewrites:
///  * pack(unpack(x)) is replaced by x when the unpack source type equals the
///    pack destination type, there is no padding value, and further conditions
///    (not visible here) hold;
///  * the padding value operand is cleared under a condition not visible here;
///  * source/destination are cast to the shapes in `srcShape`/`destShape`, the
///    op is updated in place, and the result is cast back to the original
///    result type when the destination type changed.
5955LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5957 if (!packOp.hasPureTensorSemantics())
// Fold a pack whose source is produced by an unpack.
5961 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5962 if (unPackOp.getSourceType() == packOp.getDestType() &&
5963 !packOp.getPaddingValue() &&
5966 rewriter.
replaceOp(packOp, unPackOp.getSource());
// Drops the padding value operand; the guarding condition is not visible in
// this extract.
5974 packOp.getPaddingValueMutable().clear();
5980 SmallVector<int64_t> srcShape, destShape;
5982 Location loc = packOp.getLoc();
5983 Value source = packOp.getSource();
// Insert a cast on the source only if its shape actually changes.
5984 if (srcShape != packOp.getSourceType().getShape()) {
5985 auto newSrcType = packOp.getSourceType().clone(srcShape);
5987 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5989 Value dest = packOp.getDest();
5990 ShapedType originalResultType = packOp.getDestType();
5991 bool needUpdateDestType = (destShape != originalResultType.getShape());
5992 if (needUpdateDestType) {
5993 auto newDestType = packOp.getDestType().clone(destShape);
5995 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
// Update the op in place; the result type follows the (possibly cast)
// destination.
5998 packOp.getSourceMutable().assign(source);
5999 packOp.getDestMutable().assign(dest);
6000 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
// Cast the result back to the original result type when the destination type
// was changed.
6003 if (needUpdateDestType) {
6005 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
6006 packOp.getResult());
6015template <
typename PackOrUnpackOp>
6017 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
6018 std::is_same<PackOrUnpackOp, UnPackOp>::value,
6019 "Function meant for pack/unpack");
6024 int64_t numPackedDims = innerDimsPos.size();
6025 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
6026 if (orderedDims != innerDimsPos) {
6032 int64_t packedRank = packedTensorType.getRank();
6042 return llvm::all_of(
6043 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6044 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
/// Predicate on the packed type, taken here from the op's first result type.
/// The actual test is not visible in this extract.
6047bool PackOp::isLikePad() {
6048 auto packedTensorType =
6049 llvm::cast<ShapedType>((*this)->getResultTypes().front());
/// Folder for PackOp (tensor semantics only, per the first check).
///
/// When reshapeConstantSource() yields a value for the (possibly constant)
/// source attribute, the destination type and the optional padding value, that
/// value is pushed to `results`.
6053::mlir::LogicalResult
6054PackOp::fold(FoldAdaptor adaptor,
6056 if (!hasPureTensorSemantics())
// Stays empty unless the adaptor carries a padding value.
6058 std::optional<Attribute> paddingValue;
6059 if (
auto pad = adaptor.getPaddingValue())
6061 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6062 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6063 cast<TensorType>(getDestType()), paddingValue)) {
6064 results.push_back(reshapedSource);
6090 if (!op.hasPureTensorSemantics())
6111 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6112 op.getInnerDimsPos(), newMixedTileSizes,
6113 op.getPaddingValue(), op.getOuterDimsPerm());
6114 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6117 Value oldResult = op.getResult();
6118 Value newResult = newOp.getResult();
6121 ? tensor::CastOp::create(rewriter, op->getLoc(),
6122 oldResult.
getType(), newResult)
/// Suggests the name "unpack" for the op's result. Nothing is suggested when
/// the op has no results.
6135void UnPackOp::getAsmResultNames(
6137 if (!getResults().empty())
6138 setNameFn(getResult(),
"unpack");
6147 Type sourceType, destType, resultType;
6159 if (parser.parseInteger(value))
6161 outerDimsPermVec.push_back(value);
6171 SmallVector<int64_t> innerDimsPosVec;
6174 if (parser.parseInteger(value))
6176 innerDimsPosVec.push_back(value);
6188 for (
auto val : staticTilesAttr.
asArrayRef())
6189 staticTiles.push_back(val);
6206 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6209 "pack/unpack requires '->' and destination type");
6213 resultType = destType;
6219 if (!dynamicTiles.empty() &&
6224 result.addAttribute(
"static_inner_tiles",
6226 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6228 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6230 SmallVector<int32_t> segmentSizes = {
6231 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6232 result.addAttribute(
"operandSegmentSizes",
6236 result.addTypes(resultType);
/// Custom assembly printer for UnPackOp.
///
/// Emits, in order: the source operand, the optional outer_dims_perm list, the
/// inner_dims_pos list, the inner_tiles, the destination operand ("into"), and
/// finally the source and destination types separated by "->".
6241void UnPackOp::print(OpAsmPrinter &p) {
6242 p <<
" " << getSource();
// An empty outer_dims_perm is omitted from the textual form.
6244 if (!getOuterDimsPerm().empty()) {
6245 p <<
" outer_dims_perm = [";
6246 llvm::interleaveComma(getOuterDimsPerm(), p);
6250 p <<
" inner_dims_pos = [";
6251 llvm::interleaveComma(getInnerDimsPos(), p);
6254 p <<
" inner_tiles = ";
6257 p <<
" into " << getDest();
6260 {
"static_inner_tiles",
"inner_dims_pos",
6261 "outer_dims_perm",
"operandSegmentSizes"});
6263 p <<
" : " << getSource().getType();
6264 p <<
" -> " << getDest().getType();
6268UnPackOp::reifyResultShapes(OpBuilder &builder,
6277SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6281SmallVector<int64_t> UnPackOp::getStaticTiles() {
/// Returns the leading entries of the source (packed) shape, one per dimension
/// of the destination type (i.e. the first `destRank` source dimensions).
6285ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6286 ShapedType destType = getDestType();
6287 int64_t destRank = destType.getRank();
6288 return getSourceType().getShape().take_front(destRank);
/// Collects, for every position listed in inner_dims_pos, the matching entry
/// of the outer dimensions returned by getAllOuterDims().
///
/// When outer_dims_perm is non-empty an additional step runs first; it is not
/// visible in this extract (presumably the permutation is inverted and applied
/// to `outerDims` -- confirm).
6291SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6292 auto innerDimsPos = getInnerDimsPos();
6293 SmallVector<int64_t> outerDims(getAllOuterDims());
6294 SmallVector<int64_t> res;
6297 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6299 if (!outerDimPermInv.empty())
// Pick the outer dimension for each tiled position, in inner_dims_pos order.
6303 for (
auto index : innerDimsPos)
6304 res.push_back(outerDims[index]);
6309LogicalResult UnPackOp::verify() {
6314 if (!hasPureTensorSemantics())
/// Builder taking the inner tile sizes as mixed static/dynamic OpFoldResults.
///
/// Requires one tile size per entry of `innerDimsPos` (asserted), then
/// forwards to another `build` overload using the destination type as the
/// result type.
6323void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6324 Value dest, ArrayRef<int64_t> innerDimsPos,
6325 ArrayRef<OpFoldResult> innerTiles,
6326 ArrayRef<int64_t> outerDimsPerm) {
6327 assert(innerDimsPos.size() == innerTiles.size() &&
6328 "number of tile sizes specified must match the specified number of "
6329 "original dimensions to be tiled");
// Holders for the static and dynamic parts of `innerTiles`; the code that
// fills them is not visible in this extract.
6330 SmallVector<int64_t> staticTileSizes;
6331 SmallVector<Value> dynamicTileSizes;
// An empty outer_dims_perm is forwarded as null.
6333 build(builder, state, dest.
getType(), source, dest,
6334 outerDimsPerm.empty() ?
nullptr
/// Creates a tensor.empty op suitable as the destination of an unpack of
/// `source`.
///
/// The sizes start as the leading (rank - number of tiles) source dimensions
/// (tensor.dim for dynamic ones, index attributes for static ones); a
/// non-empty `outerDimsPerm` is handled (step not visible here); and each
/// tiled dimension is combined with its tile size through `dimMul`. `source`
/// is cast to RankedTensorType, so it must be a ranked tensor.
6340Value UnPackOp::createDestinationTensor(OpBuilder &
b, Location loc,
6342 ArrayRef<OpFoldResult> innerTileSizes,
6343 ArrayRef<int64_t> innerDimsPos,
6344 ArrayRef<int64_t> outerDimsPerm) {
6345 AffineExpr sym0, sym1;
6347 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6351 SmallVector<OpFoldResult> mixedSizes;
6352 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
// Only the outer dimensions are visited; the trailing `innerTileSizes.size()`
// dimensions of the source are excluded from the range.
6354 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6355 if (srcType.isDynamicDim(i))
6356 mixedSizes.push_back(
6357 tensor::DimOp::create(
b, loc, source, i).getResult());
6359 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6361 if (!outerDimsPerm.empty()) {
// Scale every tiled dimension by its tile size.
6366 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6367 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6369 auto elemType = srcType.getElementType();
6370 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
/// Creates a new UnPackOp that reads `transposedSource` and writes to this
/// op's destination, using the inner dims positions, inner tiles and outer
/// dims permutation held in `metadata` (computed from this op and the two
/// permutations; the computing call is only partly visible here).
6373UnPackOp UnPackOp::createTransposedClone(OpBuilder &
b, Location loc,
6374 Value transposedSource,
6375 ArrayRef<int64_t> innerPermutation,
6376 ArrayRef<int64_t> outerPermutation) {
6378 *
this, innerPermutation, outerPermutation);
6379 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6380 metadata.innerDimsPos, metadata.innerTiles,
6381 metadata.outerDimsPerm);
6388 bool changeNeeded =
false;
6389 srcShape.assign(op.getSourceType().getShape().begin(),
6390 op.getSourceType().getShape().end());
6391 destShape.assign(op.getDestType().getShape().begin(),
6392 op.getDestType().getShape().end());
6393 llvm::SmallSetVector<int64_t, 4> innerDims;
6394 innerDims.insert_range(op.getInnerDimsPos());
6396 if (!op.getOuterDimsPerm().empty())
6398 int destRank = op.getDestRank();
6399 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6400 if (innerDims.contains(i))
6404 if (!inverseOuterDimsPerm.empty())
6405 srcPos = inverseOuterDimsPerm[destPos];
6406 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6407 ShapedType::isDynamic(destShape[destPos])) {
6410 int64_t size = srcShape[srcPos];
6411 if (ShapedType::isDynamic(size))
6412 size = destShape[destPos];
6413 srcShape[srcPos] = size;
6414 destShape[destPos] = size;
6415 changeNeeded =
true;
6417 return changeNeeded;
/// Canonicalization patterns for UnPackOp (tensor semantics only, per the
/// first check).
///
/// Visible rewrites:
///  * unpack(pack(x)) is replaced by x, subject to the type and padding checks
///    below (some conditions are not visible here);
///  * a destination produced by a destination-style op is replaced by that
///    op's corresponding init operand;
///  * a single tensor.extract_slice user accepted by canFoldSliceOp() is
///    folded by slicing the destination instead;
///  * source/destination are cast to the shapes in `srcShape`/`destShape` and
///    a new UnPackOp is created on the cast values.
6420LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6423 if (!unPackOp.hasPureTensorSemantics())
// Fold an unpack whose source is produced by a pack.
6427 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6428 if (packOp.getSourceType() != unPackOp.getDestType())
6430 if (packOp.getPaddingValue() ||
6434 rewriter.
replaceOp(unPackOp, packOp.getSource());
// Use the init operand of the destination-style producer as the destination.
6438 if (
auto dstStyleOp =
6439 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6440 auto destValue = cast<OpResult>(unPackOp.getDest());
6441 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6443 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
// Fold a sole extract_slice user into the unpack.
6447 if (unPackOp->hasOneUse()) {
6448 auto extractSliceUser =
6449 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6450 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6451 OpBuilder::InsertionGuard g(rewriter);
// The slice is re-created on the destination with the user's offsets, sizes
// and strides.
6453 auto newDest = tensor::ExtractSliceOp::create(
6454 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6455 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6456 extractSliceUser.getMixedStrides());
6458 unPackOp.setDpsInitOperand(0, newDest);
6459 unPackOp.getResult().setType(newDest.
getType());
6461 rewriter.
replaceOp(extractSliceUser, unPackOp);
6467 SmallVector<int64_t> srcShape, destShape;
6469 Location loc = unPackOp.getLoc();
6470 Value source = unPackOp.getSource();
// Insert casts only where the shape actually changes.
6471 if (srcShape != unPackOp.getSourceType().getShape()) {
6472 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6473 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6474 unPackOp.getSource());
6476 Value dest = unPackOp.getDest();
6477 if (destShape != unPackOp.getDestType().getShape()) {
6478 auto newDestType = unPackOp.getDestType().clone(destShape);
6479 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6480 unPackOp.getDest());
6482 UnPackOp newOp = UnPackOp::create(
6483 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6484 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6486 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
/// Decides whether a tensor.extract_slice user can be folded into this unpack.
///
/// Visible checks: the slice result rank is compared with the destination
/// rank; for every tiled dimension, dynamic extents or tile sizes are tested
/// and the implied padding (outer extent * tile size - sliced extent) is
/// compared with the tile size; for every non-tiled dimension, a static outer
/// extent is compared with the sliced extent. The value returned in each case
/// is not visible in this extract.
6493bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6495 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6500 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6501 SmallVector<int64_t> outerShapeWithoutTranspose =
// Tracks which outer dimensions are tiled so the second loop can skip them.
6503 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(),
false);
6504 for (
auto [pos, tileSize] :
6505 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6506 areOuterDimsTiled[pos] =
true;
6507 if (unpackedTypeAfterFold.isDynamicDim(pos))
6509 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6511 if (ShapedType::isDynamic(tileSize))
// Amount by which the fully unpacked extent exceeds the sliced extent.
6513 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6514 unpackedTypeAfterFold.getDimSize(pos);
6515 if (paddingSize >= tileSize)
// Non-tiled dimensions: compare static outer extents with the sliced extents.
6519 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6520 if (areOuterDimsTiled[pos])
6522 int64_t dim = outerShapeWithoutTranspose[pos];
6523 if (ShapedType::isDynamic(dim))
6525 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6531bool UnPackOp::isLikeUnPad() {
6532 ShapedType packedTensorType = getSourceType();
6536::mlir::LogicalResult
6537UnPackOp::fold(FoldAdaptor adaptor,
6538 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6540 if (!hasPureTensorSemantics())
6543 if (OpFoldResult reshapedSource = reshapeConstantSource(
6544 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6545 cast<TensorType>(getResult().
getType()))) {
6546 results.push_back(reshapedSource);
6572 if (!op.hasPureTensorSemantics())
6581 Value sourceTensor = newOperands[0];
6585 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6591 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6592 newOperands[1], op.getInnerDimsPos(),
6593 newMixedTileSizes, op.getOuterDimsPerm());
6594 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6597 Value oldResult = op.getResult();
6598 Value newResult = newOp.getResult();
6601 ? tensor::CastOp::create(rewriter, op->getLoc(),
6602 oldResult.
getType(), newResult)
6616 utils::IteratorType::reduction, utils::IteratorType::parallel,
6617 utils::IteratorType::parallel, utils::IteratorType::reduction};
6620SmallVector<AffineMap>
6621BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6622 AffineExpr d0, d1, d2, d3;
6623 SmallVector<AffineMap> indexingMaps;
6625 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6626 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6628 return indexingMaps;
6631bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6632 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6635 if (maps.size() != 3)
6640 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6641 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6642 (*positions)[2] == SmallVector<int64_t>{1, 2};
6644unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6646std::string BatchReduceMatmulOp::getLibraryCallName() {
6652bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6653 SmallVector<AffineMap, 3> defaultMaps =
6655 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6656 return defaultMaps != explicitMaps;
6666bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6669 "Expected less than 3 result dim expr.");
6670 bool isValid =
false;
6671 enum Indices { batchPos, mPos, nPos, kPos };
6673 AffineExpr expr = bcastMap.
getResult(0);
6676 AffineExpr expr0 = bcastMap.
getResult(0);
6677 AffineExpr expr1 = bcastMap.
getResult(1);
6682 : ((expr0.isFunctionOfDim(batchPos) &&
6683 expr1.isFunctionOfDim(kPos)) ||
6684 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6689void BatchReduceMatmulOp::regionBuilder(
6690 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
6693 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6698 "BatchReduceMatmulOp regionBuilder expects 3 args");
6699 RegionBuilderHelper helper(
b, block);
6700 SmallVector<Value> yields;
6704 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6706 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6708 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6709 if (!castValA || !castValB || !mulVal)
6712 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6715 yields.push_back(addVal);
6716 helper.yieldOutputs(yields);
6719ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6720 OperationState &
result) {
6721 SmallVector<Attribute, 3> indexingMapsAttr;
6732 if (!isa<AffineMapAttr>(mapAttr)) {
6734 "expected affine map attribute");
6736 indexingMapsAttr.push_back(mapAttr);
6746 if (indexingMapsAttr.empty()) {
6747 indexingMapsAttr = llvm::map_to_vector(
6748 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6749 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6751 result.addAttribute(
"indexing_maps",
6753 return ::parseNamedStructuredOp(parser,
result,
6754 BatchReduceMatmulOp::getNumRegionArgs(),
6755 BatchReduceMatmulOp::getRegionBuilder());
6758void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6759 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6760 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6761 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
6763 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6764 p <<
" indexing_maps = [";
6765 llvm::interleaveComma(getIndexingMaps(), p,
6770 SmallVector<StringRef, 3> elidedAttrs = {
6771 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6777LogicalResult BatchReduceMatmulOp::verify() {
6780 if (!hasUserDefinedMaps())
6783 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6789LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6790 SmallVectorImpl<OpFoldResult> &) {
6793void BatchReduceMatmulOp::getEffects(
6794 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6796 if (hasPureTensorSemantics())
6812void LinalgDialect::getCanonicalizationPatterns(
6821 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name-mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e. hasFoldableTensorCastOperand returns true), calculate the updated operands.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override