39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/STLExtras.h"
41#include "llvm/ADT/SetOperations.h"
42#include "llvm/ADT/SmallVector.h"
43#include "llvm/ADT/StringSet.h"
44#include "llvm/ADT/TypeSwitch.h"
45#include "llvm/Support/FormatVariadic.h"
46#include "llvm/Support/InterleavedRange.h"
47#include "llvm/Support/LogicalResult.h"
48#include "llvm/Support/MathExtras.h"
49#include "llvm/Support/raw_ostream.h"
59 auto type = cast<ShapedType>(v.
getType());
60 if (!type.isDynamicDim(dim))
65 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
66 return tensor::DimOp::create(builder, loc, v, dim);
68 .Case<MemRefType>([&](MemRefType t) ->
Value {
69 return memref::DimOp::create(builder, loc, v, dim);
80 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
81 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
84 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
85 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
97 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
98 return b.createOrFold<memref::DimOp>(loc, source, dim);
99 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
100 return b.createOrFold<tensor::DimOp>(loc, source, dim);
101 llvm_unreachable(
"Expected MemRefType or TensorType");
106 auto shapedType = llvm::cast<ShapedType>(source.
getType());
107 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
109 return b.getIndexAttr(shapedType.getDimSize(dim));
132 for (
auto containers : {inputTypes, outputTypes}) {
133 for (
auto t : containers) {
145 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
161 std::optional<TypeRange> resultTensorTypes,
168 if (!resultTensorTypes)
169 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
170 llvm::IsaPred<RankedTensorType>);
178 "operandSegmentSizes",
179 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
180 static_cast<int32_t>(outputs.size())}));
190 std::optional<TypeRange> resultTensorTypes,
197 indexingMapsAttrVal =
199 return AffineMapAttr::get(map);
201 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
203 attributes, regionBuilder);
207 std::optional<TypeRange> resultTensorTypes,
214 indexingMapsAttrVal =
216 return AffineMapAttr::get(map);
218 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
220 attributes, regionBuilder);
224 std::optional<TypeRange> resultTensorTypes,
231 indexingMapsAttrVal =
233 return AffineMapAttr::get(map);
235 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
237 attributes, regionBuilder);
246 bool addOperandSegmentSizes =
true) {
247 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
276 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
278 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
282 if (addOperandSegmentSizes) {
289 if (
result.propertiesAttr) {
291 attrs.
append(
"operandSegmentSizes",
293 {static_cast<int32_t>(inputsOperands.size()),
294 static_cast<int32_t>(outputsOperands.size())}));
297 result.addAttribute(
"operandSegmentSizes",
299 {static_cast<int32_t>(inputsOperands.size()),
300 static_cast<int32_t>(outputsOperands.size())}));
303 if (!
result.propertiesAttr) {
304 std::optional<RegisteredOperationName> info =
305 result.name.getRegisteredInfo();
307 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
308 return parser.emitError(attrsLoc)
309 <<
"'" << result.name.getStringRef() <<
"' op ";
320 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
321 if (!outputs.empty())
322 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
333 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
336 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
337 "region expects {0} args, got {1}",
338 numRegionArgs, inputTypes.size() + outputTypes.size()));
344 opBuilder, region, inputTypes, outputTypes, attrs,
363 unsigned numRegionArgs,
380 result.addTypes(outputTensorsTypes);
382 std::unique_ptr<Region> region = std::make_unique<Region>();
384 outputTypes,
result.attributes.getAttrs(),
387 result.addRegion(std::move(region));
394 if (resultTypes.empty())
439class RegionBuilderHelper {
441 RegionBuilderHelper(OpBuilder &builder,
Block &block)
442 : builder(builder), block(block) {}
445 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
447 if (!isFloatingPoint(arg)) {
449 emitError() <<
"unsupported non numeric type";
452 llvm_unreachable(
"unsupported non numeric type");
454 OpBuilder::InsertionGuard g(builder);
455 builder.setInsertionPointToEnd(&block);
458 return math::ExpOp::create(builder, arg.
getLoc(), arg);
460 return math::LogOp::create(builder, arg.
getLoc(), arg);
462 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
464 return math::CeilOp::create(builder, arg.
getLoc(), arg);
466 return math::FloorOp::create(builder, arg.
getLoc(), arg);
468 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
469 case UnaryFn::reciprocal: {
470 Attribute oneAttr = builder.getOneAttr(arg.
getType());
471 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
472 ::cast<TypedAttr>(oneAttr));
473 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
476 return math::RoundOp::create(builder, arg.
getLoc(), arg);
478 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
480 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
481 case UnaryFn::square:
482 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
484 return math::TanhOp::create(builder, arg.
getLoc(), arg);
486 return math::ErfOp::create(builder, arg.
getLoc(), arg);
489 emitError() <<
"unsupported unary function";
492 llvm_unreachable(
"unsupported unary function");
499 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
501 bool allComplex = isComplex(arg0) && isComplex(arg1);
502 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
503 bool allInteger = isInteger(arg0) && isInteger(arg1);
506 if (!allComplex && !allFloatingPoint && !allInteger) {
509 <<
"Cannot build binary Linalg operation: expects allComplex, "
510 "allFloatingPoint, or allInteger, got "
514 llvm_unreachable(
"unsupported non numeric type");
516 OpBuilder::InsertionGuard g(builder);
517 builder.setInsertionPointToEnd(&block);
521 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
522 if (allFloatingPoint)
523 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
525 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
526 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
529 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
534 emitError() <<
"unsupported operation: sub with bools";
537 llvm_unreachable(
"unsupported operation: sub with bools");
539 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
542 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
543 if (allFloatingPoint)
544 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
546 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
547 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
550 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
555 emitError() <<
"unsupported operation: div with bools";
558 llvm_unreachable(
"unsupported operation: div with bools");
560 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
561 case BinaryFn::div_unsigned:
562 if (!allInteger || allBool) {
564 emitError() <<
"unsupported operation: unsigned div not on uint";
567 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
569 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
570 case BinaryFn::max_signed:
572 if (allFloatingPoint)
573 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
574 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
575 case BinaryFn::min_signed:
577 if (allFloatingPoint)
578 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
579 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
580 case BinaryFn::max_unsigned:
582 if (!allInteger || allBool) {
584 emitError() <<
"unsupported operation: unsigned max not on uint";
587 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
589 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
590 case BinaryFn::min_unsigned:
592 if (!allInteger || allBool) {
594 emitError() <<
"unsupported operation: unsigned min not on uint";
597 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
599 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
601 assert(allFloatingPoint);
602 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
605 emitError() <<
"unsupported binary function";
608 llvm_unreachable(
"unsupported binary function");
612 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
616 bool tailFloatingPoint =
617 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
618 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
619 OpBuilder::InsertionGuard g(builder);
620 builder.setInsertionPointToEnd(&block);
622 case TernaryFn::select:
623 if (!headBool && !(tailFloatingPoint || tailInteger))
624 llvm_unreachable(
"unsupported non numeric type");
625 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
628 emitError() <<
"unsupported ternary function";
631 llvm_unreachable(
"unsupported ternary function");
635 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
638 case TypeFn::cast_signed:
639 return cast(toType, operand,
false);
640 case TypeFn::cast_unsigned:
641 return cast(toType, operand,
true);
644 emitError() <<
"unsupported type conversion function";
647 llvm_unreachable(
"unsupported type conversion function");
651 OpBuilder::InsertionGuard g(builder);
652 builder.setInsertionPointToEnd(&block);
653 Location loc = builder.getUnknownLoc();
654 YieldOp::create(builder, loc, values);
657 Value constant(
const std::string &value) {
658 OpBuilder::InsertionGuard g(builder);
659 builder.setInsertionPointToEnd(&block);
660 Location loc = builder.getUnknownLoc();
661 Attribute valueAttr =
parseAttribute(value, builder.getContext());
662 return arith::ConstantOp::create(builder, loc,
663 ::cast<TypedAttr>(valueAttr));
666 Value index(int64_t dim) {
667 OpBuilder::InsertionGuard g(builder);
668 builder.setInsertionPointToEnd(&block);
669 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
672 Type getIntegerType(
unsigned width) {
673 return IntegerType::get(builder.getContext(), width);
676 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
677 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
684 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
685 OpBuilder::InsertionGuard g(builder);
686 builder.setInsertionPointToEnd(&block);
687 auto loc = operand.
getLoc();
688 if (isa<UnknownLoc>(loc)) {
698 bool isComplex(Value value) {
699 return llvm::isa<ComplexType>(value.
getType());
701 bool isFloatingPoint(Value value) {
702 return llvm::isa<FloatType>(value.
getType());
704 bool isInteger(Value value) {
705 return llvm::isa<IntegerType>(value.
getType());
721 using OpRewritePattern<CopyOp>::OpRewritePattern;
722 LogicalResult matchAndRewrite(CopyOp copyOp,
723 PatternRewriter &rewriter)
const override {
724 if (copyOp.getInputs() != copyOp.getOutputs())
726 if (copyOp.hasPureBufferSemantics())
729 rewriter.
replaceOp(copyOp, copyOp.getInputs());
739 results.
add<EraseSelfCopy>(context);
752template <
typename TensorReshapeOp>
754 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
755 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
756 PatternRewriter &rewriter)
const override {
757 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
761 Location loc = oldFill.getLoc();
762 TensorReshapeOp newInit;
763 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
765 newInit = TensorReshapeOp::create(
766 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
767 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
768 reshapeOp.getStaticOutputShape());
770 newInit = TensorReshapeOp::create(
771 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
772 reshapeOp.getReassociation());
785 LogicalResult matchAndRewrite(tensor::PadOp padOp,
786 PatternRewriter &rewriter)
const override {
787 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
793 Value padValue = padOp.getConstantPaddingValue();
794 if (!padValue || fillOp.value() != padValue)
800 padOp,
"failed to reify tensor.pad op result shape");
803 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
804 padOp.getResultType().getElementType());
806 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
809 if (
replacement.getType() != padOp.getResultType()) {
810 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
821struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
824 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
825 PatternRewriter &rewriter)
const override {
826 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
830 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
835 Value firstDest = insertOp.getDest();
836 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
837 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
842 bool disjoint =
false;
843 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
846 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
847 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
848 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
852 int64_t prevStart = prevOp.getStaticOffset(i);
853 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
854 prevOp.getStaticStride(i);
855 int64_t nextStart = insertOp.getStaticOffset(i);
856 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
857 insertOp.getStaticStride(i);
858 if (prevEnd < nextStart || nextEnd < prevStart) {
866 firstDest = prevOp.getDest();
877 Value padValue = srcPadOp.getConstantPaddingValue();
878 if (!padValue || dstFillOp.value() != padValue)
881 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
882 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
884 Location loc = insertOp.getLoc();
887 AffineExpr sym0, sym1;
893 SmallVector<OpFoldResult, 4> newOffsets;
894 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
895 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
896 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
899 RankedTensorType srcPadType = srcPadOp.getSourceType();
900 SmallVector<OpFoldResult, 4> newSizes;
901 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
902 if (srcPadType.isDynamicDim(i)) {
904 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
907 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
912 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
913 newSizes, insertOp.getMixedStrides());
919struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
921 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
923 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
924 PatternRewriter &rewriter)
const override {
927 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
932 Value extractedScalar = fillOp.getInputs()[0];
935 rewriter.
replaceOp(extractOp, extractedScalar);
943static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
944 linalg::PackOp packOp) {
945 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
949 if (
auto paddingValue = packOp.getPaddingValue())
953 Value packOpDest = packOp.getDest();
957 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
964 FoldFillWithPack(MLIRContext *context)
965 : OpRewritePattern<linalg::PackOp>(context) {}
967 LogicalResult matchAndRewrite(linalg::PackOp packOp,
968 PatternRewriter &rewriter)
const override {
969 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
972 rewriter.
replaceOp(packOp, fillOp.value().result());
979 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
981 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
982 PatternRewriter &rewriter)
const override {
983 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
986 copyOp.getOutputs());
989 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
991 fillOp.getOutputs());
1000 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1002 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1003 PatternRewriter &rewriter)
const override {
1004 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1006 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1007 transposeOp.getDpsInitOperand(0)->get());
1019 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1020 PatternRewriter &rewriter)
const override {
1021 auto concatOperands = concatOp.getInputs();
1022 if (concatOperands.empty()) {
1026 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1031 OpFoldResult firstFillVal =
1034 SmallVector<Value> allOuts;
1035 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1037 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1038 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1043 OpFoldResult fillVal =
1045 if (fillVal != firstFillVal)
1048 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1051 if (!llvm::all_of(concatOperands.drop_front(),
1052 isDefinedByCompatibleFillOp)) {
1054 concatOp,
"not all operands are defined by a compatible fill op");
1057 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1058 concatOp.getDim(), allOuts);
1060 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1069 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1070 FoldFillWithPack, FoldFillWithPad,
1071 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1072 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1073 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1086 for (
ValueRange container : {inputs, outputs}) {
1087 for (
Value v : container) {
1088 Type t = v.getType();
1089 blockArgTypes.push_back(
1091 blockArgLocs.push_back(v.getLoc());
1097 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1101void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
1103 for (Value v : getRegionInputArgs())
1105 for (Value v : getRegionOutputArgs())
1106 setNameFn(v,
"out");
1109void GenericOp::build(
1110 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1112 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1114 ArrayRef<NamedAttribute> attributes) {
1115 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1116 iteratorTypes, doc, libraryCall);
1117 result.addAttributes(attributes);
1120 inputs, outputs, bodyBuild);
1123void GenericOp::build(
1124 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1126 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1127 StringRef libraryCall,
1129 ArrayRef<NamedAttribute> attributes) {
1130 build(builder,
result, resultTensorTypes, inputs, outputs,
1134 [&](utils::IteratorType iter) -> mlir::Attribute {
1135 return IteratorTypeAttr::get(builder.getContext(), iter);
1138 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1139 bodyBuild, attributes);
1142void GenericOp::build(
1144 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1145 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1146 StringRef libraryCall,
1148 ArrayRef<NamedAttribute> attributes) {
1150 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1153void GenericOp::build(
1155 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1156 ArrayRef<utils::IteratorType> iteratorTypes,
1158 ArrayRef<NamedAttribute> attributes) {
1159 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1161 "", bodyBuild, attributes);
1164void GenericOp::build(
1165 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1167 ArrayRef<utils::IteratorType> iteratorTypes,
1169 ArrayRef<NamedAttribute> attributes) {
1170 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1173 "", bodyBuild, attributes);
1176void GenericOp::print(OpAsmPrinter &p) {
1180 auto genericAttrNames = linalgTraitAttrNames();
1182 llvm::StringSet<> genericAttrNamesSet;
1183 genericAttrNamesSet.insert_range(genericAttrNames);
1184 SmallVector<NamedAttribute, 8> genericAttrs;
1185 for (
auto attr : (*this)->getAttrs()) {
1186 if (attr.getName() == getIteratorTypesAttrName()) {
1187 auto iteratorTypes =
1188 llvm::cast<ArrayAttr>(attr.getValue())
1189 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1194 SmallVector<Attribute> iteratorTypeNames =
1195 llvm::to_vector(llvm::map_range(
1196 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1197 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1200 genericAttrs.emplace_back(
1201 getIteratorTypesAttrName(),
1202 ArrayAttr::get(
getContext(), iteratorTypeNames));
1203 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1204 genericAttrs.push_back(attr);
1207 if (!genericAttrs.empty()) {
1208 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1209 p << genericDictAttr;
1215 genericAttrNames.push_back(
"operandSegmentSizes");
1216 genericAttrNamesSet.insert(genericAttrNames.back());
1218 bool hasExtraAttrs =
false;
1219 for (NamedAttribute n : (*this)->getAttrs()) {
1220 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1223 if (hasExtraAttrs) {
1230 if (!getRegion().empty()) {
1239ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1240 DictionaryAttr dictAttr;
1248 result.attributes.assign(dictAttr.getValue().begin(),
1249 dictAttr.getValue().end());
1255 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1256 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1257 if (!iteratorTypes) {
1258 return parser.
emitError(attributeLocation)
1259 <<
"expected " << getIteratorTypesAttrName(
result.name)
1260 <<
" array attribute";
1263 SmallVector<Attribute> iteratorTypeAttrs;
1265 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1266 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1267 if (!maybeIteratorType.has_value())
1269 <<
"unexpected iterator_type (" << s <<
")";
1271 iteratorTypeAttrs.push_back(
1272 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1274 result.attributes.set(getIteratorTypesAttrName(
result.name),
1278 SmallVector<Type, 1> inputTypes, outputTypes;
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1291 result.addRegion(std::move(region));
1297 SmallVector<Type, 1> outputTensorsTypes;
1300 result.addTypes(outputTensorsTypes);
1308 LinalgOp linalgOp) {
1309 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.
getType()))
1312 effects.emplace_back(
1317 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1341 if (!linalgOp.hasPureTensorSemantics())
1359template <
typename OpTy>
1360struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter)
const override {
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1384 linalgOp,
"expected single input and output to be the same value");
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1390 "cannot fold fill-like op");
1397 if (!linalgOp.hasPureTensorSemantics()) {
1399 linalgOp,
"mixed semantics is not supported yet");
1404 SmallVector<Value> returnedArgs;
1405 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1414 Type returnType = returnedArg.
getType();
1415 if (returnType != resultType) {
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1430 returnedArgs.push_back(returnedArg);
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1435 rewriter.
replaceOp(linalgOp, returnedArgs);
1442void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1447LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1466 for (
Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1472 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1481void MapOp::getAsmBlockArgumentNames(Region ®ion,
1483 for (Value v : getRegionInputArgs())
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v,
"init");
1489void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(),
"mapped");
1497 ArrayRef<NamedAttribute> attributes) {
1499 result.addAttributes(attributes);
1502 Type initType = init.
getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1508 inputs, {init}, bodyBuild);
1515 bool initFirst =
false,
bool mapInit =
true) {
1519 b.setInsertionPointToStart(&block);
1520 for (
auto &operand : operands) {
1522 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1530 payloadOpOperands.push_back(block.
getArguments().back());
1531 for (
const auto &arg : block.
getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1547ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1548 std::optional<OperationName> payloadOpName;
1549 NamedAttrList payloadOpAttrs;
1552 if (
failed(operationName))
1556 payloadOpName = operationName.value();
1564 if (payloadOpName.has_value()) {
1565 if (!
result.operands.empty())
1567 payloadOpAttrs, ArrayRef(
result.operands),
false,
1572 SmallVector<OpAsmParser::Argument> regionArgs;
1577 Region *body =
result.addRegion();
1585 bool mapInit =
true) {
1587 if (initFirst && !mapInit)
1611 for (
const auto &[operand, bbArg] :
1613 if (bbArg != operand)
1617 for (
const auto &[operand, bbArg] :
1620 if (bbArg != operand)
1627 return yieldOp.getNumOperands() == 1 &&
1628 yieldOp.getOperand(0).getDefiningOp() &&
1629 yieldOp.getOperand(0).getDefiningOp() == &payload;
1634 std::string attrToElide;
1636 for (
const auto &attr : payloadOp->
getAttrs()) {
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1660 if (!useShortForm) {
1666 [&](
auto arg) { p.printRegionArgument(arg); });
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() <<
"expects number of operands to match the arity of "
1683 << getInputs().size() + 1 <<
" and "
1684 << blockArgs.size();
1687 for (
const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() <<
"expected element type of input " << inputElemType
1693 <<
" to match bbArg type " << bbArgType;
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType :
TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() <<
"expected shape of input (" << inputElemShape
1703 <<
") to match shape of output (" << outputShape
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1738void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1740 for (Value v : getRegionInputArgs())
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v,
"init");
1746void ReduceOp::getAsmResultNames(
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(),
"reduced");
1752void ReduceOp::build(
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1756 ArrayRef<NamedAttribute> attributes) {
1758 result.addAttributes(attributes);
1761 for (Value init : inits) {
1762 Type initType = init.
getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1769 inputs, inits, bodyBuild);
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1774 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1784 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1788 AffineMap resultMap =
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1808 StringRef attributeName) {
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1821 if (
failed(operationName))
1825 payloadOpName = operationName.value();
1831 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1836 if (payloadOpName.has_value()) {
1838 ArrayRef(
result.operands),
true);
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1846 Region *body =
result.addRegion();
1856 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1869 if (!useShortForm) {
1875 [&](
auto arg) { p.printRegionArgument(arg); });
1883LogicalResult ReduceOp::verify() {
1884 ArrayRef<int64_t> dimensionsRef = getDimensions();
1886 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1889 return emitOpError() <<
"expects all inputs to have the same shapes. "
1890 "Shape at input-index "
1892 <<
" is not equal to the shape at input-index 0.";
1895 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1898 return emitOpError() <<
"expects all outputs to have the same shapes. "
1899 "Shape at output-index "
1901 <<
" is not equal to the shape at output-index 0.";
1904 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1905 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1908 for (int64_t dimension : dimensionsRef) {
1909 if (dimension < 0 || dimension >= inputType.getRank()) {
1911 <<
"dimensions for reduction should be in the range [0, "
1912 << inputType.getRank() - 1 <<
"].";
1914 dimensionsToReduce.insert(dimension);
1917 auto inputDims = inputType.getShape();
1918 auto initDims = initType.getShape();
1921 SmallVector<int64_t> reducedInputDims;
1922 for (
const auto &en : llvm::enumerate(inputDims)) {
1923 if (!dimensionsToReduce.count(en.index()))
1924 reducedInputDims.push_back(en.value());
1927 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1928 return emitOpError() <<
"number of dimensions after reduction "
1929 << reducedInputDims.size()
1930 <<
" doesn't match the init rank "
1931 << initType.getRank();
1934 if (reducedInputDims != initDims)
1935 return emitOpError() <<
"init dimensions [" << initDims
1936 <<
"] doesn't match input dimensions after reduction ["
1937 << reducedInputDims <<
"]";
1939 Block *block = getBody();
1942 <<
"mismatching number of operands and block arguments";
1945 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1946 Type inputElementType =
1947 llvm::cast<ShapedType>(input.getType()).getElementType();
1948 if (inputElementType != bbArg.getType())
1950 <<
"input element type " << inputElementType
1951 <<
" does not match corresponding block argument type "
1956 for (
auto [output, bbArg] : llvm::zip(
1957 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1958 auto outputElementType =
1959 llvm::cast<ShapedType>(output.getType()).getElementType();
1960 if (outputElementType != bbArg.getType())
1962 <<
"output element type " << outputElementType
1963 <<
" does not match corresponding block argument type "
1979 linalg::YieldOp::create(
b, loc, args[0]);
1983void TransposeOp::build(::mlir::OpBuilder &builder,
1984 ::mlir::OperationState &
result, Value input, Value init,
1986 ArrayRef<NamedAttribute> attributes) {
1987 result.addOperands(input);
1988 result.addOperands(init);
1989 result.addAttribute(getPermutationAttrName(
result.name), permutation);
1990 result.addAttributes(attributes);
1993 Type initType = init.
getType();
1994 if (llvm::isa<RankedTensorType>(initType))
1995 result.addTypes(initType);
2001void TransposeOp::build(::mlir::OpBuilder &builder,
2002 ::mlir::OperationState &
result, Value input, Value init,
2003 ArrayRef<int64_t> permutation,
2004 ArrayRef<NamedAttribute> attributes) {
2009ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2011 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2023void TransposeOp::getAsmResultNames(
2025 if (!getResults().empty())
2026 setNameFn(getResults().front(),
"transposed");
2029void TransposeOp::print(OpAsmPrinter &p) {
2035LogicalResult TransposeOp::verify() {
2036 ArrayRef<int64_t> permutationRef = getPermutation();
2041 auto inputType = getInput().getType();
2042 auto initType = getInit().getType();
2044 int64_t rank = inputType.getRank();
2046 if (rank != initType.getRank())
2048 <<
" does not match init rank " << initType.getRank();
2050 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2051 return emitOpError() <<
"size of permutation " << permutationRef.size()
2052 <<
" does not match the argument rank " << rank;
2054 auto inputDims = inputType.getShape();
2055 auto initDims = initType.getShape();
2057 for (int64_t i = 0; i < rank; ++i) {
2058 int64_t inputDim = inputDims[permutationRef[i]];
2059 int64_t initDim = initDims[i];
2061 if (inputDim != initDim) {
2062 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2063 <<
" doesn't match dim(input, permutation[" << i
2064 <<
"]) = " << inputDim;
2071SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2072 int64_t rank = getInit().getType().getRank();
2073 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2076ArrayAttr TransposeOp::getIndexingMaps() {
2078 int64_t rank = getInit().getType().getRank();
2081 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2085void TransposeOp::getEffects(
2086 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2095LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2096 SmallVectorImpl<OpFoldResult> &
result) {
2098 if (!isa<TensorType>(getInput().
getType()))
2102 if (getPermutation().empty()) {
2103 result.push_back(getInput());
2108 result.push_back(getInput());
2121 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2122 if (!defTransposeOp)
2127 foldedPerms.reserve(perms.size());
2129 foldedPerms.push_back(defPerms[perm]);
2132 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2146 Value input = transposeOp.getInput();
2147 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2158 unsigned dimensionSize = dimensions.size();
2159 for (
unsigned i = 0; i < dimensionSize; ++i)
2160 resultDimensions.push_back(invertPerm[dimensions[i]]);
2163 Value broadcastInput = broadcastOp.getInput();
2164 Location loc = transposeOp.getLoc();
2167 auto broadcastInputTy =
2168 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2169 unsigned inputRank = broadcastInputTy.getRank();
2170 for (
unsigned i = 0; i < inputRank; ++i) {
2171 if (broadcastInputTy.isDynamicDim(i)) {
2172 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2175 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2176 broadcastInputTy.getDimSize(i)));
2181 Value transposeInit = tensor::EmptyOp::create(
2182 rewriter, transposeOp.getLoc(), transposeResultShapes,
2183 broadcastInputTy.getElementType());
2186 Value transposeResult =
2187 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2188 transposeInit, resultPerms)
2191 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2196void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2197 MLIRContext *context) {
2198 results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2205void BroadcastOp::build(::mlir::OpBuilder &builder,
2206 ::mlir::OperationState &
result, Value input, Value init,
2208 ArrayRef<NamedAttribute> attributes) {
2209 result.addOperands(input);
2210 result.addOperands(init);
2211 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2212 result.addAttributes(attributes);
2215 Type initType = init.
getType();
2216 if (llvm::isa<RankedTensorType>(initType))
2217 result.addTypes(initType);
2223void BroadcastOp::build(::mlir::OpBuilder &builder,
2224 ::mlir::OperationState &
result, Value input, Value init,
2225 ArrayRef<int64_t> dimensions,
2226 ArrayRef<NamedAttribute> attributes) {
2231ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2233 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2245void BroadcastOp::getAsmResultNames(
2247 if (!getResults().empty())
2248 setNameFn(getResults().front(),
"broadcasted");
2251void BroadcastOp::print(OpAsmPrinter &p) {
2257LogicalResult BroadcastOp::verify() {
2258 ArrayRef<int64_t> dimensionsRef = getDimensions();
2260 auto inputType = getInput().getType();
2261 auto initType = getInit().getType();
2263 int64_t inputRank = inputType.getRank();
2264 int64_t initRank = initType.getRank();
2266 auto inputShape = inputType.getShape();
2267 auto initShape = initType.getShape();
2269 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2270 return emitOpError() <<
"input rank plus added dimensions does not "
2271 "match init rank. input rank: "
2273 <<
", dimensions size: " << dimensionsRef.size()
2274 <<
", init rank: " << initRank;
2276 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2277 if (dim < 0 || dim >= initRank)
2279 <<
" is out of range. expected range: [0, "
2280 << initRank - 1 <<
"], got: " << dim;
2284 SmallVector<int64_t> dimMap;
2285 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2286 if (!llvm::is_contained(dimensionsRef, dim))
2287 dimMap.push_back(dim);
2290 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2293 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2294 return emitOpError() <<
"input dim " << inputDimIdx
2295 <<
" should match init dim " << initDimIdx
2296 <<
". input: " << inputShape[inputDimIdx]
2297 <<
", init: " << initShape[initDimIdx];
2303SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2304 int64_t rank = getInit().getType().getRank();
2305 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2308ArrayAttr BroadcastOp::getIndexingMaps() {
2310 int64_t rank = getInit().getType().getRank();
2316void BroadcastOp::getEffects(
2317 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2332 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2333 if (!defBroadcastOp)
2338 Value init = broadcastOp.getInit();
2342 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2343 if (!llvm::is_contained(dimensions, dim))
2344 dimMap.push_back(dim);
2346 for (
auto dim : defDimensions)
2347 foldedDims.push_back(dimMap[dim]);
2349 llvm::sort(foldedDims);
2351 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2356void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2357 MLIRContext *context) {
2358 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2365void linalg::YieldOp::print(OpAsmPrinter &p) {
2366 if (getNumOperands() > 0)
2367 p <<
' ' << getOperands();
2369 if (getNumOperands() > 0)
2370 p <<
" : " << getOperandTypes();
2373ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2374 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2375 SmallVector<Type, 2> types;
2385static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2386 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2387 return op.emitOpError(
"expected number of yield values (")
2388 << op.getNumOperands()
2389 <<
") to match the number of inits / outs operands of the enclosing "
2390 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2392 for (
OpOperand &opOperand : op->getOpOperands()) {
2394 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2396 if (isa<MemRefType, RankedTensorType>(elementType))
2398 if (opOperand.get().getType() != elementType)
2399 return op.emitOpError(
"type of yield operand ")
2400 << (opOperand.getOperandNumber() + 1) <<
" ("
2401 << opOperand.get().getType() <<
") doesn't match "
2402 <<
"the element type of the enclosing linalg.generic op ("
2403 << elementType <<
")";
2408LogicalResult linalg::YieldOp::verify() {
2409 auto *parentOp = (*this)->getParentOp();
2410 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2411 return emitOpError(
"expected single non-empty parent region");
2413 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2416 return emitOpError(
"expected parent op with LinalgOp interface");
2423LogicalResult IndexOp::verify() {
2424 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2426 return emitOpError(
"expected parent op with LinalgOp interface");
2427 if (linalgOp.getNumLoops() <= getDim())
2429 << getDim() <<
") to be lower than the number of loops ("
2430 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2434OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2435 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2440 return OpFoldResult{};
2443 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2444 uint64_t dim = getDim();
2445 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2446 if (loopBounds[dim] == 1)
2447 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2449 return OpFoldResult{};
2454#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2456#define GET_OP_CLASSES
2457#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2459#define GET_OP_CLASSES
2460#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2461#define GET_OP_CLASSES
2462#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2479 for (
unsigned i = 0; i < num; ++i)
2486 auto rangeA = llvm::make_range(a.begin(), a.end());
2487 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2488 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2489 return llvm::to_vector<4>(concatRanges);
2493 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2495 for (
auto size :
memref.getShape())
2502 if (
auto as =
memref.getMemorySpace()) {
2503 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2504 ss <<
"as" << attr.getInt();
2510 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2513 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2526 assert(isa<LinalgOp>(op));
2528 std::string fun =
"";
2530 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2531 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2532 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2533 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2537 llvm::replace(name,
'.',
'_');
2538 llvm::raw_string_ostream ss(name);
2542 return std::string();
2557 LogicalResult matchAndRewrite(LinalgOp op,
2559 for (
OpOperand &opOperand : op->getOpOperands()) {
2563 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2566 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2577struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2578 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2580 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2581 PatternRewriter &rewriter)
const override {
2585 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2592 if (castOp->getBlock() != linalgOp->getBlock())
2595 OpBuilder::InsertionGuard guard(rewriter);
2598 Location loc = linalgOp.getLoc();
2599 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2602 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2608 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2610 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2611 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2612 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2613 linalgOp.getDpsInits().end());
2614 outputOperands[resultNumber] = newOperand;
2615 newOperands.append(outputOperands.begin(), outputOperands.end());
2617 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2618 linalgOp->result_type_end());
2619 resultTypes[resultNumber] = resultType;
2620 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2623 Value castBack = tensor::CastOp::create(
2627 results[resultNumber] = castBack;
2636static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2637 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2638 for (OpOperand &opOperand : operands) {
2639 if (linalgOp.isScalar(&opOperand))
2641 Value src = opOperand.get();
2642 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2643 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2649 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2651 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2652 Value castSource = castOp.getSource();
2653 auto castSourceType =
2654 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2655 if (castSourceType && castSourceType.hasStaticShape())
2656 sourceShape = castSourceType.getShape();
2662 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2663 if (sourceType.isDynamicDim(i))
2665 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2666 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2676static void createNewOperandWithStaticSizes(
2677 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2678 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2679 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2680 bool &changeNeeded) {
2681 Value src = opOperand->
get();
2682 newOperands.push_back(src);
2683 if (linalgOp.isScalar(opOperand))
2685 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2686 Type resultType = sourceType;
2687 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2688 resultTypes.push_back(resultType);
2691 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2692 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2693 SmallVector<int64_t> newShape;
2696 bool newOperandNeeded =
false;
2697 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2698 int64_t dimShape = sourceShape[i];
2699 AffineExpr dimExpr = sourceMap.
getResult(i);
2700 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2701 newShape.push_back(dimShape);
2707 newShape.push_back(affineExprToSize[dimExpr]);
2708 newOperandNeeded =
true;
2710 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2711 sourceType.getEncoding());
2712 if (newOperandNeeded) {
2713 changeNeeded =
true;
2716 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2718 newOperands[index] = newOperand;
2720 if (linalgOp.isDpsInit(opOperand))
2721 resultTypes.push_back(resultType);
2727struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2728 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2730 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2731 PatternRewriter &rewriter)
const override {
2732 if (!linalgOp.hasPureTensorSemantics())
2736 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2737 return !map.isProjectedPermutation();
2742 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2743 Location loc = linalgOp.getLoc();
2747 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2749 SmallVector<Value> newOperands;
2750 SmallVector<Type> resultTypes;
2754 bool changeNeeded =
false;
2755 newOperands.reserve(linalgOp->getNumOperands());
2756 resultTypes.reserve(linalgOp.getNumDpsInits());
2759 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2760 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2761 affineExprToSize, linalgOp, newOperands,
2762 resultTypes, changeNeeded);
2771 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2772 SmallVector<Value> replacements;
2774 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2775 Value newResult = std::get<1>(it);
2776 Value oldResult = std::get<0>(it);
2777 Type newType = newResult.
getType();
2778 Type oldType = oldResult.
getType();
2779 replacements.push_back(
2780 (newType != oldType)
2781 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2784 rewriter.
replaceOp(linalgOp, replacements);
2798LogicalResult SoftmaxOp::verify() {
2799 ShapedType inputType = getInputOperandType();
2800 ShapedType outputType = getOutputOperandType();
2802 ArrayRef<int64_t> inputShape = inputType.getShape();
2803 ArrayRef<int64_t> outputShape = outputType.getShape();
2807 int64_t inputRank = getInputOperandRank();
2808 int64_t dimension = getDimension();
2809 if ((dimension < 0) || (dimension >= inputRank))
2810 return emitOpError(
"incorrect dimension specified");
2815SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2816 int64_t operandRank = getInputOperandRank();
2817 SmallVector<Range> loopBounds(operandRank);
2818 Location loc = getLoc();
2821 Value source = getInput();
2822 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2823 loopBounds[dim].offset = zero;
2824 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2825 loopBounds[dim].stride = one;
2830SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2831 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2832 utils::IteratorType::parallel);
2833 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2834 return iteratorTypes;
2837FailureOr<TilingResult>
2838SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2839 ArrayRef<OpFoldResult> offsets,
2840 ArrayRef<OpFoldResult> sizes) {
2841 int64_t rank = getInputOperandRank();
2843 SmallVector<OpFoldResult> strides(rank, oneAttr);
2844 SmallVector<Value> tiledOperands;
2845 Operation *inputSlice =
2846 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2848 return emitOpError(
"failed to compute input slice");
2850 tiledOperands.emplace_back(inputSlice->
getResult(0));
2851 Operation *outputSlice =
2852 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2854 return emitOpError(
"failed to compute output slice");
2856 tiledOperands.emplace_back(outputSlice->
getResult(0));
2858 SmallVector<Type, 4> resultTypes;
2859 if (hasPureTensorSemantics())
2860 resultTypes.push_back(tiledOperands[1].
getType());
2861 Operation *tiledOp =
2862 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2864 return TilingResult{
2867 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2870LogicalResult SoftmaxOp::getResultTilePosition(
2871 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2872 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2873 SmallVector<OpFoldResult> &resultSizes) {
2874 if (resultNumber == 0) {
2875 resultOffsets.assign(offsets.begin(), offsets.end());
2876 resultSizes.assign(sizes.begin(), sizes.end());
2883LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2888SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2890 SmallVector<OpFoldResult> shapes;
2891 Location loc = getOperation()->getLoc();
2892 IRRewriter rewriter(
b);
2893 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2894 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2895 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2896 if (!outputShapedType.isDynamicDim(dim)) {
2898 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2905 reifiedReturnShapes.emplace_back(std::move(shapes));
2909void SoftmaxOp::getEffects(
2910 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2912 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2913 if (!llvm::isa<MemRefType>(operand.
getType()))
2916 &getOperation()->getOpOperand(index), 0,
2921 for (OpOperand &operand : getDpsInitsMutable()) {
2922 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2953static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2955 int64_t dim,
bool allParallel =
false) {
2957 utils::IteratorType::parallel);
2959 iteratorTypes[dim] = utils::IteratorType::reduction;
2963 for (
int i = 0; i < inputRank; i++) {
2970 return std::make_tuple(iteratorTypes, indexingMaps);
2975template <
typename T>
2978 auto inputType = cast<ShapedType>(input.
getType());
2980 int64_t inputRank = inputShape.size();
2981 auto [iteratorTypes, indexingMaps] =
2983 assert(indexingMaps.size() == 2 &&
2984 "We should have two maps: 1 for the input, 1 for the output");
2985 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2987 auto genericOp = linalg::GenericOp::create(
2988 builder, loc, output.
getType(), input, output, indexingMaps,
2990 Value result = T::create(b, loc, args[0], args[1]);
2991 linalg::YieldOp::create(b, loc, result);
2993 return genericOp.getResult(0);
3001 auto inputType = cast<ShapedType>(input.
getType());
3003 int64_t inputRank = inputShape.size();
3005 builder, inputRank, dim,
true);
3006 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3007 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3009 indexingMaps.push_back(indexingMaps[0]);
3010 auto genericOp = linalg::GenericOp::create(
3012 indexingMaps, iteratorTypes,
3014 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3015 Value result = math::ExpOp::create(b, loc, diff);
3016 linalg::YieldOp::create(b, loc, result);
3018 return genericOp.getResult(0);
3028 auto inputType = cast<ShapedType>(numerator.
getType());
3030 int64_t inputRank = inputShape.size();
3032 builder, inputRank, dim,
true);
3033 assert(indexingMaps.size() == 2 &&
3034 "We should have one map for each input (2)");
3035 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3037 indexingMaps.push_back(indexingMaps[0]);
3038 auto genericOp = linalg::GenericOp::create(
3040 output, indexingMaps, iteratorTypes,
3042 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3043 linalg::YieldOp::create(b, loc, result);
3045 return genericOp.getResult(0);
3067FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3068 OpBuilder::InsertionGuard guard(
b);
3069 b.setInsertionPoint(*
this);
3070 Location loc = getLoc();
3071 Value input = getInput();
3072 ShapedType inputType = getInputOperandType();
3073 Type elementType = inputType.getElementType();
3074 int64_t reductionDim = getDimension();
3076 Value output = getOutput();
3077 dims.erase(dims.begin() + reductionDim);
3079 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3081 elementType,
b, loc,
3083 Value neutralForMaxFInit =
3084 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3096 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3102 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3103 return SmallVector<Value>{
result};
3110LogicalResult WinogradFilterTransformOp::verify() {
3111 auto filterType = cast<ShapedType>(getFilter().
getType());
3112 ArrayRef<int64_t> filterShape = filterType.getShape();
3113 int64_t filterH = filterShape[getFilterHDim()];
3114 int64_t filterW = filterShape[getFilterWDim()];
3115 WinogradConv2DFmr fmr = getFmr();
3119 if (filterH != r && filterH != 1)
3120 return emitOpError(
"expect filter height either equals to r or 1");
3121 if (filterW != r && filterW != 1)
3122 return emitOpError(
"expect filter width either equals to r or 1");
3123 if (filterH == 1 && filterW == 1)
3124 return emitOpError(
"expect either filter height or width equals to r");
3126 SmallVector<int64_t> expectedOutputShape;
3127 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3128 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3130 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3132 auto outputType = cast<ShapedType>(getOutput().
getType());
3133 ArrayRef<int64_t> outputShape = outputType.getShape();
3135 return emitOpError(
"the output shape is not expected");
3141WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3142 Location loc = getLoc();
3145 Value filter = getFilter();
3146 int64_t filterRank = getFilterOperandRank();
3147 SmallVector<Range> loopBounds(filterRank);
3148 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3149 loopBounds[dim].offset = zeroAttr;
3150 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3151 loopBounds[dim].stride = oneAttr;
3156SmallVector<utils::IteratorType>
3157WinogradFilterTransformOp::getLoopIteratorTypes() {
3158 int64_t filterRank = getFilterOperandRank();
3159 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3160 utils::IteratorType::parallel);
3161 return iteratorTypes;
3164LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3165 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3166 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3167 SmallVector<OpFoldResult> &resultSizes) {
3169 ShapedType filterType = getFilterOperandType();
3170 ArrayRef<int64_t> filterShape = filterType.getShape();
3171 int64_t filterH = filterShape[getFilterHDim()];
3172 int64_t filterW = filterShape[getFilterWDim()];
3173 WinogradConv2DFmr fmr = getFmr();
3176 int64_t alpha = m + r - 1;
3177 int64_t alphaH = filterH != 1 ? alpha : 1;
3178 int64_t alphaW = filterW != 1 ? alpha : 1;
3182 resultOffsets.append(
3183 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3185 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3196FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3197 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3198 ArrayRef<OpFoldResult> sizes) {
3201 ShapedType filterType = getFilterOperandType();
3202 ArrayRef<int64_t> filterShape = filterType.getShape();
3203 int64_t filterH = filterShape[getFilterHDim()];
3204 int64_t filterW = filterShape[getFilterWDim()];
3207 SmallVector<Value> tiledOperands;
3208 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3210 sliceOffsets.append(
3211 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3212 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3213 sizes[getFilterCDim()]});
3214 int64_t filterRank = getFilterOperandRank();
3215 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3216 Location loc = getLoc();
3217 auto filterSlice = tensor::ExtractSliceOp::create(
3218 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3219 tiledOperands.emplace_back(filterSlice);
3221 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3226 int64_t outputRank = getOutputOperandRank();
3227 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3228 auto outputSlice = tensor::ExtractSliceOp::create(
3229 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3230 tiledOperands.emplace_back(outputSlice);
3232 SmallVector<Type> resultTypes;
3233 resultTypes.push_back(tiledOperands[1].
getType());
3234 Operation *tiledOp =
3235 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3237 return TilingResult{
3240 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3247LogicalResult WinogradInputTransformOp::verify() {
3248 auto inputType = cast<ShapedType>(getInput().
getType());
3249 ArrayRef<int64_t> inputShape = inputType.getShape();
3250 int64_t inputH = inputShape[getInputHDim()];
3251 int64_t inputW = inputShape[getInputWDim()];
3252 WinogradConv2DFmr fmr = getFmr();
3255 int64_t tileSize = m + r - 1;
3257 auto outputType = cast<ShapedType>(getOutput().
getType());
3258 ArrayRef<int64_t> outputShape = outputType.getShape();
3259 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3260 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3262 SmallVector<int64_t> expectedOutputShape(6, inputH);
3263 if (ShapedType::isDynamic(inputH)) {
3264 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3265 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3267 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3268 expectedOutputShape[getOutputTileHDim()] =
3269 leftTransform ? (inputH - (r - 1)) / m : inputH;
3271 if (ShapedType::isDynamic(inputW)) {
3272 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3273 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3275 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3276 expectedOutputShape[getOutputTileWDim()] =
3277 rightTransform ? (inputW - (r - 1)) / m : inputW;
3279 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3280 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3283 return emitOpError(
"the output shape is not expected");
3289WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3290 Location loc = getLoc();
3293 Value output = getOutput();
3294 int64_t outputRank = getOutputOperandRank();
3295 SmallVector<Range> loopBounds(outputRank);
3296 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3297 loopBounds[dim].offset = zeroAttr;
3299 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3300 loopBounds[dim].stride = oneAttr;
3305SmallVector<utils::IteratorType>
3306WinogradInputTransformOp::getLoopIteratorTypes() {
3307 int64_t outputRank = getOutputOperandRank();
3308 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3309 utils::IteratorType::parallel);
3310 return iteratorTypes;
3313LogicalResult WinogradInputTransformOp::getResultTilePosition(
3314 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3315 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3316 SmallVector<OpFoldResult> &resultSizes) {
3318 ShapedType outputType = getOutputOperandType();
3319 ArrayRef<int64_t> outputShape = outputType.getShape();
3320 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3321 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3323 WinogradConv2DFmr fmr = getFmr();
3326 int64_t alpha = m + r - 1;
3327 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3328 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3333 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3334 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3335 offsets[getOutputCDim()]});
3336 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3337 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3338 sizes[getOutputCDim()]});
3349FailureOr<TilingResult>
3350WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3351 ArrayRef<OpFoldResult> offsets,
3352 ArrayRef<OpFoldResult> sizes) {
3354 WinogradConv2DFmr fmr = getFmr();
3358 ShapedType outputType = getOutputOperandType();
3359 ArrayRef<int64_t> outputShape = outputType.getShape();
3360 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3361 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3363 Location loc = getLoc();
3365 auto identityAffineMap =
3367 auto offsetAffineMap =
3370 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3371 offsets[getOutputTileHDim()]);
3373 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3374 offsets[getOutputTileWDim()]);
3378 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3380 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3382 SmallVector<Value> tiledOperands;
3383 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3385 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3386 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3387 sliceOffsets.append(
3388 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3389 OpFoldResult sizeH =
3390 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3391 OpFoldResult sizeW =
3392 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3394 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3395 int64_t inputRank = getInputOperandRank();
3396 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3397 auto inputSlice = tensor::ExtractSliceOp::create(
3398 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3399 tiledOperands.emplace_back(inputSlice);
3401 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3406 int64_t outputRank = getOutputOperandRank();
3407 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3408 auto outputSlice = tensor::ExtractSliceOp::create(
3409 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3410 tiledOperands.emplace_back(outputSlice);
3412 SmallVector<Type> resultTypes;
3413 resultTypes.push_back(tiledOperands[1].
getType());
3414 Operation *tiledOp =
3415 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3417 return TilingResult{
3420 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3427LogicalResult WinogradOutputTransformOp::verify() {
3428 auto valueType = cast<ShapedType>(getValue().
getType());
3429 ArrayRef<int64_t> valueShape = valueType.getShape();
3430 int64_t valueH = valueShape[getValueAlphaHDim()];
3431 int64_t valueW = valueShape[getValueAlphaWDim()];
3432 int64_t valueTileH = valueShape[getValueTileHDim()];
3433 int64_t valueTileW = valueShape[getValueTileWDim()];
3434 WinogradConv2DFmr fmr = getFmr();
3437 bool leftTransform = valueH != 1;
3438 bool rightTransform = valueW != 1;
3440 int64_t outputRank = getOutputOperandRank();
3441 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3442 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3443 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3445 if (valueH != (leftTransform ? m + r - 1 : 1))
3446 return emitOpError(
"expect input height equals to input tile size");
3447 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3449 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3450 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3452 if (valueW != (rightTransform ? m + r - 1 : 1))
3453 return emitOpError(
"expect input width equals to input tile size");
3454 expectedOutputShape[getOutputWDim()] =
3455 (rightTransform ? m : 1) * valueTileW;
3457 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3458 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3460 auto outputType = cast<ShapedType>(getOutput().
getType());
3461 ArrayRef<int64_t> outputShape = outputType.getShape();
3463 return emitOpError(
"the output shape is not expected");
3469WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3470 Location loc = getLoc();
3473 Value value = getValue();
3474 int64_t valueRank = getValueOperandRank();
3475 SmallVector<Range> loopBounds(valueRank);
3476 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3477 loopBounds[dim].offset = zeroAttr;
3479 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3480 loopBounds[dim].stride = oneAttr;
3485SmallVector<utils::IteratorType>
3486WinogradOutputTransformOp::getLoopIteratorTypes() {
3487 int64_t valueRank = getValueOperandRank();
3488 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3489 utils::IteratorType::parallel);
3490 return iteratorTypes;
3493LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3494 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3495 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3496 SmallVector<OpFoldResult> &resultSizes) {
3497 WinogradConv2DFmr fmr = getFmr();
3501 Location loc = getLoc();
3503 auto identityAffineMap =
3508 ShapedType valueType = getValueOperandType();
3509 ArrayRef<int64_t> valueShape = valueType.getShape();
3510 int64_t valueH = valueShape[0];
3511 int64_t valueW = valueShape[1];
3513 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3514 offsets[getValueTileHDim()]);
3516 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3517 offsets[getValueTileWDim()]);
3519 builder, loc, affineMap, sizes[getValueTileHDim()]);
3521 builder, loc, affineMap, sizes[getValueTileWDim()]);
3524 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3525 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3526 OpFoldResult sizeH =
3527 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3528 OpFoldResult sizeW =
3529 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3531 resultOffsets.append(
3532 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3534 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3544FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3545 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3546 ArrayRef<OpFoldResult> sizes) {
3549 Location loc = getLoc();
3550 SmallVector<Value> tiledOperands;
3551 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3553 ShapedType valueType = getValueOperandType();
3554 ArrayRef<int64_t> valueShape = valueType.getShape();
3555 int64_t alphaH = valueShape[getValueAlphaHDim()];
3556 int64_t alphaW = valueShape[getValueAlphaWDim()];
3560 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3561 offsets[getValueTileWDim()], offsets[getValueNDim()],
3562 offsets[getValueFDim()]});
3563 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3564 sizes[getValueTileWDim()], sizes[getValueNDim()],
3565 sizes[getValueFDim()]});
3566 int64_t valueRank = getValueOperandRank();
3567 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3568 auto valueSlice = tensor::ExtractSliceOp::create(
3569 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3570 tiledOperands.emplace_back(valueSlice);
3572 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3577 int64_t outputRank = getOutputOperandRank();
3578 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3579 auto outputSlice = tensor::ExtractSliceOp::create(
3580 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3581 tiledOperands.emplace_back(outputSlice);
3583 SmallVector<Type> resultTypes;
3584 resultTypes.push_back(tiledOperands[1].
getType());
3585 Operation *tiledOp =
3586 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3588 return TilingResult{
3591 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3605 llvm::set_union(explicitSet, defaultSet);
3606 return explicitSet == defaultSet;
3626 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3628 auto opIndexingMap = opIndexingMaps[opIndex];
3629 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3632 return matmulOp->emitOpError()
3633 <<
"Unexpected dim expression in map result.";
3636 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3637 return matmulOp->emitOpError()
3638 <<
"Invalid broadcast requested, should be (d2).";
3647template <
typename OpTy>
3650 AffineMap defaultIndexingMap,
bool isLHS) {
3651 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3652 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3653 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3656 return batchVariantMatmulOp->emitOpError()
3657 <<
"Unexpected result dim expression (outside the set of default "
3662 return batchVariantMatmulOp->emitOpError()
3663 <<
"no. of result dim expressions exceeds 3.";
3665 auto hasValidBatchDim = [](
AffineMap map) {
3672 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3673 return batchVariantMatmulOp->emitOpError()
3674 <<
"Invalid broadcast requested.";
3675 }
else if (!hasValidBatchDim(opIndexingMap)) {
3676 return batchVariantMatmulOp->emitOpError()
3677 <<
"Invalid batch dimension expression.";
3685template <
typename OpTy>
3688 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3689 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3690 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3691 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3694 return batchVariantMatmulOp->emitOpError()
3695 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3698 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3700 return batchVariantMatmulOp->emitOpError()
3701 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3705 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3706 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3707 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3708 outputMap.getResult(1).isFunctionOfDim(1) &&
3709 outputMap.getResult(2).isFunctionOfDim(2)
3710 : outputMap.getResult(0).isFunctionOfDim(1) &&
3711 outputMap.getResult(1).isFunctionOfDim(2);
3714 if (!areValidOutputResultDim(opIndexingMap)) {
3715 return batchVariantMatmulOp->emitOpError()
3716 <<
"Invalid output map result dimension.";
3725template <
typename OpTy>
3730 batchVariantMatmulOp.getIndexingMapsArray();
3732 batchVariantMatmulOp.getDefaultIndexingMaps(
3733 batchVariantMatmulOp->getContext());
3735 if (opIndexingMaps.size() != 3)
3736 return batchVariantMatmulOp->emitOpError()
3737 <<
"Indexing_map attribute must have 3 affine maps.";
3739 auto opIndexingMap = opIndexingMaps[opIndex];
3740 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3748 defaultIndexingMap, opIndex == 0)))
3758 if (m == 2 && r == 3)
3759 return WinogradConv2DFmr::F_2_3;
3760 if (m == 4 && r == 3)
3761 return WinogradConv2DFmr::F_4_3;
3762 if (m == 2 && r == 5)
3763 return WinogradConv2DFmr::F_2_5;
3764 return std::nullopt;
3769 case WinogradConv2DFmr::F_2_3:
3771 case WinogradConv2DFmr::F_4_3:
3773 case WinogradConv2DFmr::F_2_5:
3782static FailureOr<SmallVector<SmallVector<int64_t>>>
3785 for (
auto map : maps) {
3786 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3790 for (
auto result : attr.getAffineMap().getResults()) {
3791 auto dim = dyn_cast<AffineDimExpr>(
result);
3794 pos.push_back(dim.getPosition());
3796 positions.push_back(pos);
3809 return indexingMaps;
3812bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3813 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3816 if (maps.size() != 3)
3821 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3822 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3823 (*positions)[2] == SmallVector<int64_t>{0, 1};
3826SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3827 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3828 utils::IteratorType::parallel,
3829 utils::IteratorType::reduction};
3832unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3834std::string MatmulOp::getLibraryCallName() {
3838bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3842bool MatmulOp::hasUserDefinedMaps() {
3843 SmallVector<AffineMap, 3> defaultMaps =
3845 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3846 return defaultMaps != explicitMaps;
3851void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3852 ArrayRef<NamedAttribute> attrs,
3855 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3860 "MatmulOp regionBuilder expects 3 args");
3861 RegionBuilderHelper helper(
b, block);
3862 SmallVector<Value> yields;
3864 TypeFn castVal = TypeFn::cast_signed;
3865 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3866 return attr.
getName() ==
"cast";
3868 if (castIter != attrs.end()) {
3869 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3877 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3880 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3884 yields.push_back(value4);
3885 helper.yieldOutputs(yields);
3895bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3896 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3897 AffineExpr expr = bcastMap.
getResult(0);
3907 ArrayAttr arrayAttr;
3911 if (llvm::any_of(arrayAttr,
3912 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3914 <<
"element of indexing_maps array is not an affine_map";
3921 if (failed(indexingMapsAttr))
3924 if (*indexingMapsAttr ==
nullptr) {
3925 auto indexingMapAttrs = llvm::map_to_vector(
3926 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3931 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
3933 MatmulOp::getRegionBuilder());
3936void MatmulOp::print(OpAsmPrinter &p) {
3937 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3938 MatmulOp::getDefaultIndexingMaps(
getContext()),
3939 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
3940 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3941 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3943 std::array<StringRef, 3> elidedAttrs = {
3944 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3950LogicalResult MatmulOp::verify() {
3952 if (!hasUserDefinedMaps())
3955 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3962LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3966void MatmulOp::getEffects(
3967 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3969 if (hasPureTensorSemantics())
3978SmallVector<AffineMap>
3979MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3980 AffineExpr d0, d1, d2;
3986 return {mapLHS, mapRHS, mapOut};
3990 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3993 if (maps.size() != 3)
3996 if (failed(positions))
4008 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4016 build(builder, state, inputs, outputs, attributes);
4017 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4018 assert(res &&
"builder didn't return the right type");
4028 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4037 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4038 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4039 assert(res &&
"builder didn't return the right type");
4049 result.addAttribute(
"cast", cast);
4051 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4060 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4061 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4062 assert(res &&
"builder didn't return the right type");
4067 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4069 op->
getAttr(
"indexing_maps"));
4073MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4080 return {mapLHS, mapRHS, mapOut};
4084 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4087 if (maps.size() != 3)
4090 if (failed(positions))
4102 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4110 build(builder, state, inputs, outputs, attributes);
4111 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4112 assert(res &&
"builder didn't return the right type");
4122 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4131 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4132 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4133 assert(res &&
"builder didn't return the right type");
4143 result.addAttribute(
"cast", cast);
4145 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4154 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4155 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4156 assert(res &&
"builder didn't return the right type");
4161 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4163 op->
getAttr(
"indexing_maps"));
4167BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4174 return {mapLHS, mapRHS, mapOut};
4178 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4181 if (maps.size() != 3)
4184 if (failed(positions))
4195 BatchMatmulOp::getRegionBuilder(),
4196 getDefaultIndexingMaps(builder));
4204 build(builder, state, inputs, outputs, attributes);
4205 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4206 assert(res &&
"builder didn't return the right type");
4215 BatchMatmulOp::getRegionBuilder(),
4216 getDefaultIndexingMaps(builder));
4225 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4226 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4227 assert(res &&
"builder didn't return the right type");
4235 result.addAttribute(
"cast", cast);
4237 BatchMatmulOp::getRegionBuilder(),
4238 getDefaultIndexingMaps(builder));
4247 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4248 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4249 assert(res &&
"builder didn't return the right type");
4254 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4256 op->
getAttr(
"indexing_maps"));
4260BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4267 return {mapLHS, mapRHS, mapOut};
4271 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4274 if (maps.size() != 3)
4277 if (failed(positions))
4288 BatchMatmulOp::getRegionBuilder(),
4289 getDefaultIndexingMaps(builder));
4297 build(builder, state, inputs, outputs, attributes);
4298 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4299 assert(res &&
"builder didn't return the right type");
4308 BatchMatmulOp::getRegionBuilder(),
4309 getDefaultIndexingMaps(builder));
4318 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4319 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4320 assert(res &&
"builder didn't return the right type");
4328 result.addAttribute(
"cast", cast);
4330 BatchMatmulOp::getRegionBuilder(),
4331 getDefaultIndexingMaps(builder));
4340 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4341 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4342 assert(res &&
"builder didn't return the right type");
4347 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4349 op->
getAttr(
"indexing_maps"));
4357 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4368 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4369 assert(dimExpr &&
"affine_map is a projected permutation");
4370 dimsInOutput[dimExpr.getPosition()] =
true;
4374 for (
auto dimOccursInOutput : dimsInOutput)
4375 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4376 : utils::IteratorType::reduction);
4378 return iteratorTypes;
4381unsigned ContractOp::getNumRegionArgs() {
return 3; }
4384void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4385 ArrayRef<NamedAttribute> attrs,
4388 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4393 "ContractOp regionBuilder expects 3 args");
4394 RegionBuilderHelper helper(
b, block);
4396 TypeFn castSignedness = TypeFn::cast_signed;
4397 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4398 return attr.
getName() ==
"cast";
4400 if (castIter != attrs.end()) {
4401 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4407 Value lhsAtOutType =
4408 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4409 Value rhsAtOutType =
4410 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4411 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4413 if (!productAtOutType)
4419 helper.yieldOutputs({
result});
4422ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4424 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4426 "expected 'indexing_maps' attribute");
4427 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4433void ContractOp::print(OpAsmPrinter &p) {
4434 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4436 p, getOperation(), getInputs(), getOutputs(),
4437 {
"indexing_maps",
"operandSegmentSizes"});
4440LogicalResult ContractOp::verify() {
4441 int iterationSpaceDims = -1;
4446 SmallVector<size_t> inOccurrences;
4447 SmallVector<size_t> outOccurrences;
4450 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4451 bool isInput) -> LogicalResult {
4454 return emitError(
"provided affine_map is not a projected permutation");
4457 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4459 return emitError(
"ranks of shaped operand and results of corresponding "
4460 "affine_map differ");
4462 return emitError(
"affine_map specifies shaped access while operand has "
4467 if (iterationSpaceDims == -1) {
4469 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4470 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4471 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4472 return emitError(
"iteration spaces of provided affine_maps differ");
4476 for (AffineExpr affineExpr : affineMap.
getResults()) {
4477 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4479 llvm_unreachable(
"affine_map is a projected permutation");
4482 inOccurrences[affineDimExpr.getPosition()] += 1;
4484 outOccurrences[affineDimExpr.getPosition()] += 1;
4490 for (
auto &&[affineMap, operandType, isInput] :
4491 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4492 SmallVector<bool>{
true,
true,
false})) {
4493 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4497 bool hasContractingDim =
false;
4498 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4499 size_t inOccCount = inOccurrences[dimIndex];
4500 size_t outOccCount = outOccurrences[dimIndex];
4503 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4505 if (inOccCount == 0 && outOccCount == 0)
4506 return emitError() <<
"iteration space dim at index " << dimIndex
4507 <<
" not used to access any operand";
4518 if (inOccCount == 1 && outOccCount != 1)
4520 <<
"iteration space dim at index " << dimIndex
4521 <<
" is neither a contracting dim nor of parallel iteration type";
4524 if (!hasContractingDim)
4525 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4530LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4534void ContractOp::getEffects(
4535 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4537 if (hasPureTensorSemantics())
4549SmallVector<AffineMap>
4550BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4551 AffineExpr d0, d1, d2, d3;
4552 SmallVector<AffineMap> indexingMaps;
4554 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4555 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4556 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4557 return indexingMaps;
4560bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4561 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4564 if (maps.size() != 3)
4569 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4570 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4571 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4574SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4575 return SmallVector<utils::IteratorType>{
4576 utils::IteratorType::parallel, utils::IteratorType::parallel,
4577 utils::IteratorType::parallel, utils::IteratorType::reduction};
4580unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4582std::string BatchMatmulOp::getLibraryCallName() {
4588bool BatchMatmulOp::hasUserDefinedMaps() {
4589 SmallVector<AffineMap, 3> defaultMaps =
4591 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4592 return defaultMaps != explicitMaps;
4602bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4604 "Expected less than 3 result dim expr.");
4605 bool isValid =
false;
4606 enum Indices { batchPos, mPos, nPos, kPos };
4608 AffineExpr expr = bcastMap.
getResult(0);
4611 AffineExpr expr0 = bcastMap.
getResult(0);
4612 AffineExpr expr1 = bcastMap.
getResult(1);
4617 : ((expr0.isFunctionOfDim(batchPos) &&
4618 expr1.isFunctionOfDim(kPos)) ||
4619 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4624void BatchMatmulOp::regionBuilder(
4625 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4628 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4633 "BatchMatmulOp regionBuilder expects 3 args");
4634 RegionBuilderHelper helper(
b, block);
4635 SmallVector<Value> yields;
4637 TypeFn castVal = TypeFn::cast_signed;
4638 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4639 return attr.
getName() ==
"cast";
4641 if (castIter != attrs.end()) {
4642 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4647 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4648 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4649 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4651 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
4652 yields.push_back(addVal);
4653 helper.yieldOutputs(yields);
4656ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4657 SmallVector<Attribute, 3> indexingMapsAttr;
4669 if (!isa<AffineMapAttr>(mapAttr)) {
4671 "expected affine map attribute");
4673 indexingMapsAttr.push_back(mapAttr);
4683 if (indexingMapsAttr.empty()) {
4684 indexingMapsAttr = llvm::map_to_vector(
4685 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4686 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4688 result.addAttribute(
"indexing_maps",
4691 return ::parseNamedStructuredOp(parser,
result,
4692 BatchMatmulOp::getNumRegionArgs(),
4693 BatchMatmulOp::getRegionBuilder());
4696void BatchMatmulOp::print(OpAsmPrinter &p) {
4697 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4698 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4699 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4700 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4701 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4703 std::array<StringRef, 3> elidedAttrs = {
4704 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4710LogicalResult BatchMatmulOp::verify() {
4713 if (!hasUserDefinedMaps())
4716 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4723LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4724 SmallVectorImpl<OpFoldResult> &) {
4728void BatchMatmulOp::getEffects(
4729 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4731 if (hasPureTensorSemantics())
4745struct ArityGroupAndKind {
4747 ElementwiseArityGroup arityGroup;
4753 TernaryFn ternaryFn;
4757unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4758 return static_cast<unsigned>(arityGroup);
4763 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4764 constexpr int lastBinary =
4765 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4766 constexpr int lastTernary =
4767 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4769 int val =
static_cast<int>(kind);
4770 ArityGroupAndKind
result;
4772 if (val < lastUnary) {
4773 result.arityGroup = ElementwiseArityGroup::Unary;
4774 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4777 if (val < lastBinary) {
4778 result.arityGroup = ElementwiseArityGroup::Binary;
4779 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4782 if (val >= lastTernary) {
4783 llvm_unreachable(
"unhandled ElementwiseFn");
4785 result.arityGroup = ElementwiseArityGroup::Ternary;
4786 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4791 auto rank = getResultRank();
4796ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4802ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4805 mlir::linalg::ElementwiseKind elemwiseKindVal;
4810 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4811 if (!elemwiseKindAttr)
4813 "expected ElementwiseKind attribute");
4814 elemwiseKindVal = elemwiseKindAttr.getValue();
4817 "expected operation 'kind' attribute");
4820 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4823 SmallVector<Attribute, 3> indexingMapsAttr;
4833 if (!isa<AffineMapAttr>(mapAttr))
4835 "expected affine map attribute");
4836 indexingMapsAttr.push_back(mapAttr);
4847 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4849 ElementwiseOp::getRegionBuilder())) {
4851 "unable to parse elemwise op");
4855 if (indexingMapsAttr.empty()) {
4858 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4859 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4862 "return type needs to be shaped type");
4863 auto numDims = shapedType.getRank();
4864 indexingMapsAttr = llvm::map_to_vector(
4865 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4867 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4870 result.addAttribute(
"indexing_maps",
4875void ElementwiseOp::print(OpAsmPrinter &p) {
4878 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4882 unsigned numDims = getResultRank();
4884 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4885 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4887 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4889 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4890 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4898void ElementwiseOp::regionBuilder(
4899 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4901 ElementwiseKind elemwiseKind;
4902 for (
auto attr : attrs) {
4903 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4904 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4905 assert(kindAttr &&
"op kind attribute incorrectly set");
4906 elemwiseKind = kindAttr.getValue();
4912 auto arityGroup = groupAndKind.arityGroup;
4913 auto kind = groupAndKind.kind;
4915 getArityGroupAsUInt(arityGroup) + 1 ) {
4916 emitError() <<
"Elementwise regionBuilder expects "
4917 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
4922 getArityGroupAsUInt(arityGroup) + 1
4923 &&
"Elementwise regionBuilder number of block args mismatch");
4925 RegionBuilderHelper helper(
b, block);
4926 SmallVector<Value> yields;
4929 if (arityGroup == ElementwiseArityGroup::Unary) {
4932 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
4936 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
4941 assert(
false &&
"found unhandled category in elemwise");
4944 yields.push_back(
result);
4945 helper.yieldOutputs(yields);
4948LogicalResult ElementwiseOp::fold(FoldAdaptor,
4949 SmallVectorImpl<OpFoldResult> &) {
4953void ElementwiseOp::getEffects(
4954 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4956 if (hasPureTensorSemantics())
4969template <
typename OpTy,
typename>
4972 RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4973 ? packOrUnPack.getDestType()
4974 : packOrUnPack.getSourceType();
4975 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4976 ? packOrUnPack.getSourceType()
4977 : packOrUnPack.getDestType();
4979 packedType.getShape().take_front(unpackedType.getRank()));
4980 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5002 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5004 .take_back(mixedTiles.size()),
5007 if (
shape == ShapedType::kDynamic) {
5008 newMixedTileSizes.push_back(std::get<1>(it));
5015 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
5017 newMixedTileSizes.push_back(
tile);
5020 "tile size and dim size don't match!");
5021 newMixedTileSizes.push_back(
5026 return newMixedTileSizes;
5029template <
typename OpTy>
5033 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5034 "applies to only pack or unpack operations");
5035 int64_t destRank = op.getDestRank();
5037 reifiedReturnShapes[0] =
5042template <
typename OpTy>
5044 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5045 "applies to only pack or unpack operations");
5049 assert(tiles.size() == dimsToTile.size() &&
5050 "tiles must match indices of dimension to block");
5052 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5053 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5054 return dimAndTileMapping;
5057template <
typename OpTy>
5059 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5060 "applies to only pack or unpack operations");
5063 unsigned dynamicValIndex = 0;
5064 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5065 if (ShapedType::isStatic(staticTile))
5068 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5070 return mixedInnerTiles;
5073template <
typename OpTy>
5075 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5076 "applies to only pack or unpack operations");
5089 size_t dimsPosSize = dimsPos.size();
5090 if (dimsPosSize > rank)
5093 if (dimsPosSize != uniqued.size())
5095 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5096 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5100template <
typename OpTy>
5102 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5103 "applies to only pack or unpack operations");
5104 Operation *op = packOrUnPack.getOperation();
5113 if (hasZeros(mixedTiles))
5114 return op->
emitError(
"invalid zero tile factor");
5117 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
5118 ? packOrUnPack.getSourceType()
5119 : packOrUnPack.getDestType();
5120 size_t unpackedRank = unpackedType.getRank();
5124 return op->
emitError(
"invalid inner_dims_pos vector");
5126 return op->
emitError(
"invalid outer_dims_perm vector");
5127 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5128 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5132 if (mixedTiles.size() > unpackedRank) {
5133 return op->
emitError(
"tiling factors must be less than or equal to the "
5134 "input rank for pack or output rank for unpack");
5136 if (mixedTiles.size() != innerDimsPos.size()) {
5138 "tiling factors must equal the number of dimensions to tile");
5141 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5142 ? packOrUnPack.getDestType()
5143 : packOrUnPack.getSourceType();
5144 size_t packedRank = packedType.getRank();
5146 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5147 if (expectedPackedRank != packedRank) {
5149 "packed rank != (unpacked rank + num tiling factors), got ")
5150 << packedRank <<
" != " << expectedPackedRank;
5156 RankedTensorType expectedPackedType = PackOp::inferPackedType(
5157 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
5159 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5161 [](std::tuple<int64_t, OpFoldResult> it) {
5162 int64_t shape = std::get<0>(it);
5163 if (Attribute attr =
5164 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5165 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5166 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5167 return shape == staticTileSize;
5169 return ShapedType::isDynamic(
shape);
5171 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
5172 "tiled dimension in the packed type");
5175 packedType.getShape()))) {
5176 return op->emitError(
"expected ")
5177 << expectedPackedType <<
" for the packed domain value, got "
5190struct PackOrUnPackTransposeResult {
5197template <
typename OpTy>
5198static PackOrUnPackTransposeResult
5202 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5203 "applies to only pack or unpack operations");
5204 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5205 "some permutation must be non-empty");
5206 PackOrUnPackTransposeResult metadata;
5207 metadata.innerDimsPos =
5209 metadata.innerTiles =
5211 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5212 ? packOrUnPackOp.getSourceRank()
5213 : packOrUnPackOp.getDestRank();
5214 metadata.outerDimsPerm =
5215 packOrUnPackOp.getOuterDimsPerm().empty()
5216 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5218 if (!innerPermutation.empty()) {
5219 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5221 "invalid inner permutation");
5225 if (!outerPermutation.empty()) {
5226 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5228 "invalid outer permutation");
5239 setNameFn(getResult(),
"pack");
5245 std::optional<Value> paddingValue,
5247 assert(innerDimsPos.size() == innerTiles.size() &&
5248 "number of tile sizes specified must match the specified number of "
5249 "original dimensions to be tiled");
5253 build(builder, state, dest.
getType(), source, dest,
5254 paddingValue ? *paddingValue :
nullptr,
5255 outerDimsPerm.empty() ?
nullptr
5262PackOp::reifyResultShapes(
OpBuilder &builder,
5280 ShapedType inputType = getSourceType();
5281 int64_t inputRank = inputType.getRank();
5282 return getDestType().getShape().take_front(inputRank);
5286 auto innerDimsPos = getInnerDimsPos();
5293 if (!outerDimPermInv.empty())
5297 for (
auto index : innerDimsPos)
5298 res.push_back(outerDims[
index]);
5309 outputShape.take_front(inputShape.size()));
5310 if (!outerDimsPerm.empty()) {
5311 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5312 "expected output and outer_dims_perm to have same size");
5316 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5317 if (ShapedType::isDynamic(inputShape[pos]))
5321 if (!constantTile) {
5322 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5323 (inputShape[pos] % outputTileSizes[pos] != 0))
5325 }
else if (inputShape[pos] % (*constantTile) != 0) {
5338 outputShape.take_front(inputShape.size()));
5339 if (!outerDimsPerm.empty()) {
5340 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5341 "expected output and outer_dims_perm to have same size");
5345 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5346 if (ShapedType::isDynamic(inputShape[pos]) ||
5347 ShapedType::isDynamic(outputTileSizes[pos]))
5352 if (inputShape[pos] % (*constantTile) != 0)
5358LogicalResult PackOp::verify() {
5365 auto paddingValue = getPaddingValue();
5369 << getSourceType().getElementType()
5370 <<
" but got: " << paddingValue.getType();
5373 if (!paddingValue &&
5374 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5375 getDestType().
getShape(), getOuterDimsPerm(),
5378 "invalid tile factor or output size provided. Only full tiles are "
5379 "supported when padding_value is not set");
5389 for (
auto o : ofrs) {
5391 if (llvm::dyn_cast_if_present<Value>(o))
5392 result.push_back(ShapedType::kDynamic);
5406 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5407 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5409 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5410 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5413 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5414 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5418 if (!outerDimsPerm.empty())
5422 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5435 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5437 builder, loc, ceilDivExpr,
5438 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5440 if (!outerDimsPerm.empty())
5442 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5447 innerDimsPos, outerDimsPerm);
5453 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5454 if (ShapedType::isStatic(resultTypeShape[i]))
5465RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
5470 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5471 return RankedTensorType::get(resultShape, sourceType.getElementType());
5486 for (
auto [
index, value] : llvm::enumerate(
5487 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5488 if (ShapedType::isDynamic(value))
5489 mixedSizes.push_back(
5490 tensor::DimOp::create(
b, loc, source,
index).getResult());
5492 mixedSizes.push_back(
b.getIndexAttr(value));
5494 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5495 int64_t dimPos = std::get<0>(it);
5497 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5499 if (!outerDimsPerm.empty())
5502 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5503 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5504 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5511 *
this, innerPermutation, outerPermutation);
5512 Value transposedDest =
5513 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5514 metadata.innerDimsPos, metadata.outerDimsPerm);
5515 return PackOp::create(
b, loc, getSource(), transposedDest,
5516 metadata.innerDimsPos, metadata.innerTiles,
5517 getPaddingValue(), metadata.outerDimsPerm);
5521template <
typename OpTy>
5523 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5524 "applies to only pack or unpack operations");
5525 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5527 : op.getSourceType();
5529 for (
auto [dimDest,
tile] : llvm::zip(
5530 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5532 if (!constTileSize || ShapedType::isDynamic(dimDest))
5539 if (getPaddingValue())
5554 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5556 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5568 auto packTiles = packOp.getMixedTiles();
5569 auto unPackTiles = unPackOp.getMixedTiles();
5570 if (packTiles.size() != unPackTiles.size())
5572 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5581 auto srcType = op.getSourceType();
5582 if (llvm::any_of(op.getInnerDimsPos(),
5583 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5585 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5587 return !PackOp::requirePaddingValue(
5588 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5589 op.getOuterDimsPerm(), op.getMixedTiles());
5596 bool changeNeeded =
false;
5597 srcShape.assign(packOp.getSourceType().getShape().begin(),
5598 packOp.getSourceType().getShape().end());
5599 destShape.assign(packOp.getDestType().getShape().begin(),
5600 packOp.getDestType().getShape().end());
5601 llvm::SmallSetVector<int64_t, 4> innerDims;
5602 innerDims.insert_range(packOp.getInnerDimsPos());
5604 if (!packOp.getOuterDimsPerm().empty())
5606 int srcRank = packOp.getSourceRank();
5607 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5608 if (innerDims.contains(i))
5612 if (!inverseOuterDimsPerm.empty())
5613 destPos = inverseOuterDimsPerm[srcPos];
5614 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5615 ShapedType::isDynamic(destShape[destPos])) {
5618 int64_t size = srcShape[srcPos];
5619 if (ShapedType::isDynamic(size))
5620 size = destShape[destPos];
5621 srcShape[srcPos] = size;
5622 destShape[destPos] = size;
5623 changeNeeded =
true;
5625 return changeNeeded;
5628LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5630 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5631 if (unPackOp.getSourceType() == packOp.getDestType() &&
5632 !packOp.getPaddingValue() &&
5635 rewriter.
replaceOp(packOp, unPackOp.getSource());
5643 packOp.getPaddingValueMutable().clear();
5652 Value source = packOp.getSource();
5653 if (srcShape != packOp.getSourceType().getShape()) {
5654 auto newSrcType = packOp.getSourceType().clone(srcShape);
5656 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5658 Value dest = packOp.getDest();
5659 RankedTensorType originalResultType = packOp.getDestType();
5660 bool needUpdateDestType = (destShape != originalResultType.getShape());
5661 if (needUpdateDestType) {
5662 auto newDestType = packOp.getDestType().clone(destShape);
5664 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5667 packOp.getSourceMutable().assign(source);
5668 packOp.getDestMutable().assign(dest);
5669 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5672 if (needUpdateDestType) {
5675 tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5684template <
typename PackOrUnpackOp>
5686 RankedTensorType packedTensorType) {
5687 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5688 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5689 "Function meant for pack/unpack");
5694 int64_t numPackedDims = innerDimsPos.size();
5695 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5696 if (orderedDims != innerDimsPos) {
5702 int64_t packedRank = packedTensorType.getRank();
5712 return llvm::all_of(
5713 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5714 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
5717bool PackOp::isLikePad() {
5718 auto packedTensorType =
5719 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5724 std::optional<Attribute> paddingValue;
5725 if (
auto pad = adaptor.getPaddingValue())
5727 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5728 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5729 getDestType(), paddingValue))
5730 return reshapedSource;
5769 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5770 op.getInnerDimsPos(), newMixedTileSizes,
5771 op.getPaddingValue(), op.getOuterDimsPerm());
5772 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5775 Value oldResult = op.getResult();
5776 Value newResult = newOp.getResult();
5779 ? tensor::CastOp::create(rewriter, op->getLoc(),
5780 oldResult.
getType(), newResult)
5793void UnPackOp::getAsmResultNames(
5795 setNameFn(getResult(),
"unpack");
5799UnPackOp::reifyResultShapes(
OpBuilder &builder,
5817 ShapedType destType = getDestType();
5818 int64_t destRank = destType.getRank();
5819 return getSourceType().getShape().take_front(destRank);
5823 auto innerDimsPos = getInnerDimsPos();
5830 if (!outerDimPermInv.empty())
5834 for (
auto index : innerDimsPos)
5835 res.push_back(outerDims[
index]);
5840LogicalResult UnPackOp::verify() {
5856 assert(innerDimsPos.size() == innerTiles.size() &&
5857 "number of tile sizes specified must match the specified number of "
5858 "original dimensions to be tiled");
5862 build(builder, state, dest.
getType(), source, dest,
5863 outerDimsPerm.empty() ?
nullptr
5881 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
5883 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5884 if (srcType.isDynamicDim(i))
5885 mixedSizes.push_back(
5886 tensor::DimOp::create(
b, loc, source, i).getResult());
5888 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
5890 if (!outerDimsPerm.empty()) {
5895 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5896 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5898 auto elemType = srcType.getElementType();
5899 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5903 Value transposedSource,
5907 *
this, innerPermutation, outerPermutation);
5908 return UnPackOp::create(
b, loc, transposedSource, getDest(),
5909 metadata.innerDimsPos, metadata.innerTiles,
5910 metadata.outerDimsPerm);
5917 bool changeNeeded =
false;
5918 srcShape.assign(op.getSourceType().getShape().begin(),
5919 op.getSourceType().getShape().end());
5920 destShape.assign(op.getDestType().getShape().begin(),
5921 op.getDestType().getShape().end());
5922 llvm::SmallSetVector<int64_t, 4> innerDims;
5923 innerDims.insert_range(op.getInnerDimsPos());
5925 if (!op.getOuterDimsPerm().empty())
5927 int destRank = op.getDestRank();
5928 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
5929 if (innerDims.contains(i))
5933 if (!inverseOuterDimsPerm.empty())
5934 srcPos = inverseOuterDimsPerm[destPos];
5935 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5936 ShapedType::isDynamic(destShape[destPos])) {
5939 int64_t size = srcShape[srcPos];
5940 if (ShapedType::isDynamic(size))
5941 size = destShape[destPos];
5942 srcShape[srcPos] = size;
5943 destShape[destPos] = size;
5944 changeNeeded =
true;
5946 return changeNeeded;
5949LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5952 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5953 if (packOp.getSourceType() != unPackOp.getDestType())
5955 if (packOp.getPaddingValue() ||
5959 rewriter.
replaceOp(unPackOp, packOp.getSource());
5963 if (
auto dstStyleOp =
5964 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5965 auto destValue = cast<OpResult>(unPackOp.getDest());
5966 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5968 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5972 if (unPackOp->hasOneUse()) {
5973 auto extractSliceUser =
5974 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5975 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5978 auto newDest = tensor::ExtractSliceOp::create(
5979 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5980 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5981 extractSliceUser.getMixedStrides());
5983 unPackOp.setDpsInitOperand(0, newDest);
5984 unPackOp.getResult().setType(newDest.
getType());
5986 rewriter.
replaceOp(extractSliceUser, unPackOp);
5995 Value source = unPackOp.getSource();
5996 if (srcShape != unPackOp.getSourceType().getShape()) {
5997 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5998 source = tensor::CastOp::create(rewriter, loc, newSrcType,
5999 unPackOp.getSource());
6001 Value dest = unPackOp.getDest();
6002 if (destShape != unPackOp.getDestType().getShape()) {
6003 auto newDestType = unPackOp.getDestType().clone(destShape);
6004 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6005 unPackOp.getDest());
6007 Value newOp = UnPackOp::create(
6008 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6009 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6011 unPackOp, unPackOp.getResult().
getType(), newOp);
6018bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6020 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6025 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6028 for (
auto [pos, tileSize] :
6029 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6030 if (unpackedTypeAfterFold.isDynamicDim(pos))
6032 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6034 if (ShapedType::isDynamic(tileSize))
6036 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6037 unpackedTypeAfterFold.getDimSize(pos);
6038 if (paddingSize >= tileSize)
6044bool UnPackOp::isLikeUnPad() {
6045 RankedTensorType packedTensorType = getSourceType();
6050 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6051 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6053 return reshapedSource;
6082 Value sourceTensor = newOperands[0];
6086 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6092 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6093 newOperands[1], op.getInnerDimsPos(),
6094 newMixedTileSizes, op.getOuterDimsPerm());
6095 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6098 Value oldResult = op.getResult();
6099 Value newResult = newOp.getResult();
6102 ? tensor::CastOp::create(rewriter, op->getLoc(),
6103 oldResult.
getType(), newResult)
6117 utils::IteratorType::reduction, utils::IteratorType::parallel,
6118 utils::IteratorType::parallel, utils::IteratorType::reduction};
6122BatchReduceMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
6126 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6127 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6129 return indexingMaps;
6132bool BatchReduceMatmulOp::isDefaultIndexingMaps(
Attribute attr) {
6133 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6136 if (maps.size() != 3)
6145unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6147std::string BatchReduceMatmulOp::getLibraryCallName() {
6153bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6157 return defaultMaps != explicitMaps;
6167bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
6170 "Expected less than 3 result dim expr.");
6171 bool isValid =
false;
6172 enum Indices { batchPos, mPos, nPos, kPos };
6183 : ((expr0.isFunctionOfDim(batchPos) &&
6184 expr1.isFunctionOfDim(kPos)) ||
6185 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6190void BatchReduceMatmulOp::regionBuilder(
6194 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6199 "BatchReduceMatmulOp regionBuilder expects 3 args");
6200 RegionBuilderHelper helper(
b, block);
6205 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6207 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6208 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6210 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6211 yields.push_back(addVal);
6212 helper.yieldOutputs(yields);
6215ParseResult BatchReduceMatmulOp::parse(
OpAsmParser &parser,
6228 if (!isa<AffineMapAttr>(mapAttr)) {
6230 "expected affine map attribute");
6232 indexingMapsAttr.push_back(mapAttr);
6242 if (indexingMapsAttr.empty()) {
6243 indexingMapsAttr = llvm::map_to_vector(
6244 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6247 result.addAttribute(
"indexing_maps",
6249 return ::parseNamedStructuredOp(parser,
result,
6250 BatchReduceMatmulOp::getNumRegionArgs(),
6251 BatchReduceMatmulOp::getRegionBuilder());
6256 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6259 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6260 p <<
" indexing_maps = [";
6261 llvm::interleaveComma(getIndexingMaps(), p,
6267 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6273LogicalResult BatchReduceMatmulOp::verify() {
6276 if (!hasUserDefinedMaps())
6279 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6285LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6289void BatchReduceMatmulOp::getEffects(
6292 if (hasPureTensorSemantics())
6308void LinalgDialect::getCanonicalizationPatterns(
6317 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location parameter.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work to remain sorted.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N-1].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static than the consumer.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static than the consumer.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override