#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
auto type = cast<ShapedType>(v.getType());
if (!type.isDynamicDim(dim))
  return builder.getIndexAttr(type.getDimSize(dim));
// ...
    .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
      return builder.create<tensor::DimOp>(loc, v, dim);
    })
    .Case<MemRefType>([&](MemRefType t) -> Value {
      return builder.create<memref::DimOp>(loc, v, dim);
    });
    .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
      return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                              strides);
    })
    .Case<MemRefType>([&](MemRefType type) -> Operation * {
      return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                         strides);
    });
if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
  // ...
if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
  // ...
llvm_unreachable("Expected MemRefType or TensorType");
auto shapedType = llvm::cast<ShapedType>(source.getType());
if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
  // ...
for (auto containers : {inputTypes, outputTypes}) {
  for (auto t : containers) {
    // ...
  }
}
// ...
opBuilder.createBlock(&region, {}, argTypes, argLocs);
// ...
regionBuilder(b, *body, attrs, emitError);
std::optional<TypeRange> resultTensorTypes,
// ...
if (!resultTensorTypes)
  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
          llvm::IsaPred<RankedTensorType>);

state.addOperands(inputs);
state.addOperands(outputs);
state.addTypes(derivedResultTypes);
// ...
state.addAttributes(attributes);
state.addAttribute(
    "operandSegmentSizes",
    b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                            static_cast<int32_t>(outputs.size())}));
// ...
Region &region = *state.addRegion();
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                       state.attributes.getAttrs(), {}, regionBuilder);
std::optional<TypeRange> resultTensorTypes,
// ...
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);

std::optional<TypeRange> resultTensorTypes,
// ...
  indexingMapsAttrVal = /* ... */;
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);

std::optional<TypeRange> resultTensorTypes,
// ...
  indexingMapsAttrVal = /* ... */;
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
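// Illustrative note (not from the original source): for a plain matmul the
// default indexing maps built above are, in affine_map form,
//   affine_map<(d0, d1, d2) -> (d0, d2)>   // LHS
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)>   // accumulator/result
// with d0 = M, d1 = N, and d2 = K (the contracted dimension).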
    bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  // ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // ...
    attrs.append("operandSegmentSizes",
                 parser.getBuilder().getDenseI32ArrayAttr(
                     {static_cast<int32_t>(inputsOperands.size()),
                      static_cast<int32_t>(outputsOperands.size())}));
    // ...
    result.addAttribute("operandSegmentSizes",
                        parser.getBuilder().getDenseI32ArrayAttr(
                            {static_cast<int32_t>(inputsOperands.size()),
                             static_cast<int32_t>(outputsOperands.size())}));
  }
std::optional<RegisteredOperationName> info =
    result.name.getRegisteredInfo();
// ...
if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
      return parser.emitError(attrsLoc)
             << "'" << result.name.getStringRef() << "' op ";
    })))
  return failure();
if (!inputs.empty())
  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
if (!outputs.empty())
  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
  return parser.emitError(
      parser.getCurrentLocation(),
      llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                    "region expects {0} args, got {1}",
                    numRegionArgs, inputTypes.size() + outputTypes.size()));
}
ParseResult result = success();
// ...
fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                       /* ... */);
// ...
unsigned numRegionArgs,
// ...
result.addTypes(outputTensorsTypes);
// ...
std::unique_ptr<Region> region = std::make_unique<Region>();
// ...
if (resultTypes.empty())
  // ...
class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // ...
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger &&
                   arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, /*isUnsignedCast=*/false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, /*isUnsignedCast=*/true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  // ...

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      // ...
    }
  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());
    return success();
  }
// ...
results.add<EraseSelfCopy>(context);
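// Illustrative example (not from the original source): EraseSelfCopy removes
// copies whose source and destination coincide, e.g.
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// is erased entirely on buffers; on tensors the result is replaced by the
// copied input value.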
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    // ...
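// Illustrative example (not from the original source): the pattern rewrites
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<16xf32>) -> tensor<16xf32>
//   %e = tensor.expand_shape %f [[0, 1]] output_shape [4, 4]
//       : tensor<16xf32> into tensor<4x4xf32>
// into a reshape of the init followed by a single linalg.fill of the reshaped
// value, pushing the reshape past the fill.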
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value matches the filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    // ...
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
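// Illustrative example (not from the original source): padding a filled
// tensor with the same constant,
//   %f = linalg.fill ins(%c0 : f32) outs(%t : tensor<4x4xf32>) -> tensor<4x4xf32>
//   %p = tensor.pad %f low[1, 1] high[1, 1] { ... yield %c0 ... }
// becomes a single linalg.fill of %c0 into a 6x6 tensor.empty.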
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the insert_slice chain to the first destination value.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If any offset/size/stride is dynamic, this dimension cannot prove
        // disjointness.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }
      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }
    // ...
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();
    // ...
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
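// Illustrative note (not from the original source): with static offset, size,
// and stride, the inclusive range touched in dimension i is
//   [start, start + (size - 1) * stride].
// E.g. offset 0, size 4, stride 2 touches indices {0, 2, 4, 6}, i.e. [0, 6];
// a later insert at offset 7, size 2, stride 1 touches [7, 8], so the two
// inserts are disjoint and the walk above may continue past the previous one.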
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the extract is the result of a linalg.fill.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the extract with the scalar used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    // ...

  Value packOpDest = packOp.getDest();
  // ...
  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}
// ...
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
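// Illustrative example (not from the original source): packing a filled
// tensor,
//   %f = linalg.fill ins(%cst : f32) outs(%src : tensor<16x16xf32>)
//   %p = linalg.pack %f ... into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
// is replaced by filling %dest directly, provided any padding value of the
// pack equals the fill value.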
  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // ...
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }
      // ...
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
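// Illustrative example (not from the original source): two fills of the same
// value concatenated along dim 0,
//   %a = linalg.fill ins(%cst : f32) outs(%i0 : tensor<4x8xf32>)
//   %b = linalg.fill ins(%cst : f32) outs(%i1 : tensor<4x8xf32>)
//   %c = tensor.concat dim(0) %a, %b : (...) -> tensor<8x8xf32>
// become a tensor.concat of %i0/%i1 followed by one linalg.fill of the result.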
results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
            FoldFillWithPack, FoldFillWithPad,
            FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
            FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
for (ValueRange container : {inputs, outputs}) {
  for (Value v : container) {
    Type t = v.getType();
    blockArgTypes.push_back(
        isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
    blockArgLocs.push_back(v.getLoc());
  }
}
// ...
builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}
void GenericOp::build(
    // ...
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  // ...
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    // ...
    StringRef libraryCall,
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    // ...
    StringRef libraryCall,
    // ...
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    // ...
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
auto genericAttrNames = linalgTraitAttrNames();

llvm::StringSet<> genericAttrNamesSet;
genericAttrNamesSet.insert_range(genericAttrNames);
SmallVector<NamedAttribute, 8> genericAttrs;
for (auto attr : (*this)->getAttrs()) {
  if (attr.getName() == getIteratorTypesAttrName()) {
    auto iteratorTypes =
        llvm::cast<ArrayAttr>(attr.getValue())
            .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
    // ...
    SmallVector<Attribute> iteratorTypeNames =
        llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](utils::IteratorType t) -> Attribute {
              return StringAttr::get(getContext(), stringifyIteratorType(t));
            }));

    genericAttrs.emplace_back(
        getIteratorTypesAttrName(),
        ArrayAttr::get(getContext(), iteratorTypeNames));
  } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
    genericAttrs.push_back(attr);
  }
}
if (!genericAttrs.empty()) {
  auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
  p << genericDictAttr;
}
// ...
genericAttrNames.push_back("operandSegmentSizes");
genericAttrNamesSet.insert(genericAttrNames.back());

bool hasExtraAttrs = false;
for (NamedAttribute n : (*this)->getAttrs()) {
  if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
    break;
}
if (hasExtraAttrs) {
  // ...
}
// ...
if (!getRegion().empty()) {
  // ...
}
DictionaryAttr dictAttr;
// ...
result.attributes.append(dictAttr.getValue().begin(),
                         dictAttr.getValue().end());
// ...
auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
    result.attributes.get(getIteratorTypesAttrName(result.name)));
if (!iteratorTypes) {
  return parser.emitError(attributeLocation)
         << "expected " << getIteratorTypesAttrName(result.name)
         << " array attribute";
}
// ...
for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
  auto maybeIteratorType = utils::symbolizeIteratorType(s);
  if (!maybeIteratorType.has_value())
    return parser.emitError(parser.getCurrentLocation())
           << "unexpected iterator_type (" << s << ")";

  iteratorTypeAttrs.push_back(
      IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
}
std::unique_ptr<Region> region = std::make_unique<Region>();
// ...
result.addTypes(outputTensorsTypes);
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(/* ... */);
  }
  // ...
  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      // ...
    }
  }
void GenericOp::getEffects(
    // ...
if (!linalgOp.hasPureTensorSemantics())
  // ...
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // The body must consist of a single yield.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }
      // ...
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }
    // ...
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // ...
      if (returnType != resultType) {
        // Sparse results use sparse_tensor.convert; dense results use
        // tensor.cast.
        // ...
        returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
            linalgOp.getLoc(), resultType, returnedArg);
        // ...
        if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                               resultType))
          return failure();
        returnedArg = rewriter.create<tensor::CastOp>(
            linalgOp.getLoc(), resultType, returnedArg);
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    // ...

results.add<EraseIdentityLinalgOp<GenericOp>>(context);
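// Illustrative example (not from the original source): a generic that only
// forwards its input,
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%x : tensor<8xf32>) outs(%y : tensor<8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<8xf32>
// is replaced by %x directly (with a tensor.cast if only the static/dynamic
// dimensions of the types differ).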
for (Type outputType : outputTypes) {
  if (llvm::isa<RankedTensorType>(outputType))
    // ...
}
// ...
if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
  return failure();
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}
// ...
if (!getResults().empty())
  setNameFn(getResults().front(), "mapped");
// ...
build(builder, result, TypeRange{}, inputs, init);
// ...
if (llvm::isa<RankedTensorType>(initType))
  // ...
    inputs, {}, bodyBuild);
    bool initFirst = false) {
  // ...
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  // ...
  payloadOpOperands.push_back(block.getArguments().back());
  for (const auto &arg : block.getArguments().drop_back())
    payloadOpOperands.push_back(arg);
  // ...
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
  // ...
std::optional<OperationName> payloadOpName;
// ...
if (failed(operationName))
  return failure();
// ...
payloadOpName = operationName.value();
// ...
if (payloadOpName.has_value()) {
  // ...
}
// ...
for (const auto &[operand, bbArg] :
     llvm::zip(/* ... */)) {
  if (bbArg != operand)
    return false;
}
for (const auto &[operand, bbArg] :
     llvm::zip(/* ... */)) {
  if (bbArg != operand)
    return false;
}
// ...
std::string attrToElide;
// ...
for (const auto &attr : payloadOp->getAttrs()) {
  auto fastAttr =
      llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
    attrToElide = attr.getName().str();
    elidedAttrs.push_back(attrToElide);
    break;
  }
}
Block *mapper = getBody();
// ...
llvm::interleaveComma(mapper->getArguments(), p,
                      [&](auto arg) { p.printRegionArgument(arg); });
auto *bodyBlock = getBody();
auto blockArgs = bodyBlock->getArguments();

// Checks that the number of inputs matches the arity of the mapper region.
if (getInputs().size() != blockArgs.size())
  return emitOpError() << "expects number of operands to match the arity of "
                          "mapper, but got: "
                       << getInputs().size() << " and " << blockArgs.size();

// The parameters of the mapper must all match the element types of inputs.
for (const auto &[bbArgType, inputArg] :
     llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
  auto inputElemType =
      llvm::cast<ShapedType>(inputArg.getType()).getElementType();
  if (bbArgType != inputElemType) {
    return emitOpError() << "expected element type of input " << inputElemType
                         << " to match bbArg type " << bbArgType;
  }
}

// The shape of each input must match the shape of the output.
auto outputShape = getInit().getType().getShape();
for (Type inputArgType : TypeRange{getInputs()}) {
  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
  if (inputElemShape != outputShape) {
    return emitOpError() << "expected shape of input (" << inputElemShape
                         << ") to match shape of output (" << outputShape
                         << ")";
  }
}
int64_t rank = getInit().getType().getRank();
// ...

ArrayAttr MapOp::getIndexingMaps() {
  // ...
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  // ...
}

void MapOp::getEffects(
    // ...
void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}

void ReduceOp::build(
    // ...
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  // ...
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      // ...
  }
  // ...
      inputs, inits, bodyBuild);
int64_t inputRank =
    llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                               utils::IteratorType::parallel);
for (int64_t reductionDim : getDimensions())
  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
return iteratorTypes;

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // ...
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
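  // Illustrative note (not from the original source): for a rank-3 input
  // reduced over dimensions = [1], the maps built here are
  //   inputs: affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  //   inits:  affine_map<(d0, d1, d2) -> (d0, d2)>
  // i.e. the identity map, with the reduced dimensions dropped for each init.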
void ReduceOp::getEffects(
    // ...

    StringRef attributeName) {
  // ...
}

std::optional<OperationName> payloadOpName;
// ...
if (failed(operationName))
  return failure();
// ...
payloadOpName = operationName.value();
// ...
if (payloadOpName.has_value()) {
  // ...
}
// ...
p << ' ' << attributeName << " = [" << attributeValue << "] ";
Block *mapper = getBody();
// ...
llvm::interleaveComma(mapper->getArguments(), p,
                      [&](auto arg) { p.printRegionArgument(arg); });
for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
      llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
    return emitOpError() << "expects all inputs to have the same shapes. "
                            "Shape at input-index "
                         << i
                         << " is not equal to the shape at input-index 0.";
  }
}

for (int64_t i = 1; i < getNumDpsInits(); ++i) {
  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
      llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
    return emitOpError() << "expects all outputs to have the same shapes. "
                            "Shape at output-index "
                         << i
                         << " is not equal to the shape at output-index 0.";
  }
}

auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
DenseSet<int64_t> dimensionsToReduce;
for (int64_t dimension : dimensionsRef) {
  if (dimension < 0 || dimension >= inputType.getRank()) {
    return emitOpError()
           << "dimensions for reduction should be in the range [0, "
           << inputType.getRank() - 1 << "].";
  }
  dimensionsToReduce.insert(dimension);
}

auto inputDims = inputType.getShape();
auto initDims = initType.getShape();

// Input dimensions that will be left after the reduction.
SmallVector<int64_t> reducedInputDims;
for (const auto &en : llvm::enumerate(inputDims)) {
  if (!dimensionsToReduce.count(en.index()))
    reducedInputDims.push_back(en.value());
}

if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
  return emitOpError() << "number of dimensions after reduction "
                       << reducedInputDims.size()
                       << " doesn't match the init rank "
                       << initType.getRank();
}

if (reducedInputDims != initDims)
  return emitOpError() << "init dimensions [" << initDims
                       << "] doesn't match input dimensions after reduction ["
                       << reducedInputDims << "]";
Block *block = getBody();
if (block->getNumArguments() != this->getNumOperands())
  return emitOpError()
         << "mismatching number of operands and block arguments";

// Check that the first block arguments match the element type of the inputs.
for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
  Type inputElementType =
      llvm::cast<ShapedType>(input.getType()).getElementType();
  if (inputElementType != bbArg.getType())
    return emitOpError()
           << "input element type " << inputElementType
           << " does not match corresponding block argument type "
           << bbArg.getType();
}

// Check that the last block arguments match the element type of the outputs.
for (auto [output, bbArg] : llvm::zip(
         getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
  auto outputElementType =
      llvm::cast<ShapedType>(output.getType()).getElementType();
  if (outputElementType != bbArg.getType())
    return emitOpError()
           << "output element type " << outputElementType
           << " does not match corresponding block argument type "
           << bbArg.getType();
}
b.create<linalg::YieldOp>(loc, args[0]);
// ...
if (llvm::isa<RankedTensorType>(initType))
  // ...
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}
// ...
return emitOpError("permutation is not valid");
auto inputType = getInput().getType();
auto initType = getInit().getType();

int64_t rank = inputType.getRank();

if (rank != initType.getRank())
  return emitOpError() << "input rank " << rank
                       << " does not match init rank " << initType.getRank();

if (rank != static_cast<int64_t>(permutationRef.size()))
  return emitOpError() << "size of permutation " << permutationRef.size()
                       << " does not match the argument rank " << rank;

auto inputDims = inputType.getShape();
auto initDims = initType.getShape();

for (int64_t i = 0; i < rank; ++i) {
  int64_t inputDim = inputDims[permutationRef[i]];
  int64_t initDim = initDims[i];

  if (inputDim != initDim) {
    return emitOpError() << "dim(result, " << i << ") = " << initDim
                         << " doesn't match dim(input, permutation[" << i
                         << "]) = " << inputDim;
  }
}
int64_t rank = getInit().getType().getRank();
// ...

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    // ...

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    return failure();

  // Single dimension transpose.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  // ...
    result.push_back(getInput());
auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
if (!defTransposeOp)
  return failure();
ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
ArrayRef<int64_t> perms = transposeOp.getPermutation();
SmallVector<int64_t> foldedPerms;
foldedPerms.reserve(perms.size());
for (int64_t perm : perms)
  foldedPerms.push_back(defPerms[perm]);

rewriter.replaceOpWithNewOp<TransposeOp>(
    transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
    foldedPerms);
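// Illustrative note (not from the original source): permutations compose as
// foldedPerms[i] = defPerms[perms[i]]. E.g. defPerms = [1, 2, 0] followed by
// perms = [2, 0, 1] gives foldedPerms = [0, 1, 2]: the two transposes cancel
// and the pair collapses into a single (here identity) transpose.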
Value input = transposeOp.getInput();
BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
if (!broadcastOp)
  return failure();
// ...
unsigned dimensionSize = dimensions.size();
for (unsigned i = 0; i < dimensionSize; ++i)
  resultDimensions.push_back(invertPerm[dimensions[i]]);

// Compute the sizes for the tensor.empty holding the transpose result.
Value broadcastInput = broadcastOp.getInput();
Location loc = transposeOp.getLoc();
SmallVector<OpFoldResult> dims;
auto broadcastInputTy =
    mlir::cast<RankedTensorType>(broadcastInput.getType());
unsigned inputRank = broadcastInputTy.getRank();
for (unsigned i = 0; i < inputRank; ++i) {
  if (broadcastInputTy.isDynamicDim(i)) {
    dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                       .getResult());
  } else {
    dims.push_back(rewriter.getIndexAttr(broadcastInputTy.getDimSize(i)));
  }
}
// ...
Value transposeInit = rewriter.create<tensor::EmptyOp>(
    transposeOp.getLoc(), transposeResultShapes,
    broadcastInputTy.getElementType());

// Create broadcast(transpose(input)).
Value transposeResult =
    rewriter
        .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
                             /* ... */)
        .getResults()[0];
rewriter.replaceOpWithNewOp<BroadcastOp>(
    transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
if (llvm::isa<RankedTensorType>(initType))
  // ...

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}
auto inputType = getInput().getType();
auto initType = getInit().getType();

int64_t inputRank = inputType.getRank();
int64_t initRank = initType.getRank();

auto inputShape = inputType.getShape();
auto initShape = initType.getShape();

if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
  return emitOpError() << "input rank plus added dimensions does not "
                          "match init rank. input rank: "
                       << inputRank
                       << ", dimensions size: " << dimensionsRef.size()
                       << ", init rank: " << initRank;

for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
  if (dim < 0 || dim >= initRank)
    return emitOpError() << "dimension " << idx
                         << " is out of range. expected range: [0, "
                         << initRank - 1 << "], got: " << dim;
}

// Mapping from input dims to init dims.
SmallVector<int64_t> dimMap;
for (auto dim : llvm::seq<int64_t>(0, initRank)) {
  if (!llvm::is_contained(dimensionsRef, dim))
    dimMap.push_back(dim);
}

for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
  // This dimension is mapped from the input; input and init dims must match.
  if (inputShape[inputDimIdx] != initShape[initDimIdx])
    return emitOpError() << "input dim " << inputDimIdx
                         << " should match init dim " << initDimIdx
                         << ". input: " << inputShape[inputDimIdx]
                         << ", init: " << initShape[initDimIdx];
}
int64_t rank = getInit().getType().getRank();
// ...

ArrayAttr BroadcastOp::getIndexingMaps() {
  int64_t rank = getInit().getType().getRank();
  // ...
}

void BroadcastOp::getEffects(
    // ...

results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
// ...
if (getNumOperands() > 0)
  p << ' ' << getOperands();
// ...
if (getNumOperands() > 0)
  p << " : " << getOperandTypes();
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
auto *parentOp = (*this)->getParentOp();
if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
  return emitOpError("expected single non-empty parent region");

if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
  return verifyYield(*this, linalgOp);

return emitOpError("expected parent op with LinalgOp interface");
auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
if (!linalgOp)
  return emitOpError("expected parent op with LinalgOp interface");
if (linalgOp.getNumLoops() <= getDim())
  return emitOpError("expected dim (")
         << getDim() << ") to be lower than the number of loops ("
         << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
// ...
auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
// ...
uint64_t dim = getDim();
assert(dim < loopBounds.size() && "Dim is out of bounds");
if (loopBounds[dim] == 1)
  // ...
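// Illustrative note (not from the original source): if the enclosing LinalgOp
// has loop bounds [4, 1, 8], then linalg.index 1 can only ever yield 0, so
// the fold above replaces it with the constant index 0; indices of the other
// dimensions are left untouched.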
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
for (unsigned i = 0; i < num; ++i)
  // ...

auto rangeA = llvm::make_range(a.begin(), a.end());
auto rangeB = llvm::make_range(b.begin(), b.end());
auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
return llvm::to_vector<4>(concatRanges);
if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
  // ...
  for (auto size : memref.getShape())
    // ...
  if (auto as = memref.getMemorySpace()) {
    if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
      ss << "as" << attr.getInt();
    // ...
  }
  // ...
}
if (auto vec = llvm::dyn_cast<VectorType>(t)) {
  // ...
  llvm::interleave(
      vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
  // ...
}
assert(isa<LinalgOp>(op));
// ...
std::string fun = "";
for (NamedAttribute kv : op->getAttrs()) {
  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
    fun = stringifyEnum(ufa.getValue()).str() + "_";
  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
    fun = stringifyEnum(bfa.getValue()).str() + "_";
  }
}
llvm::replace(name, '.', '_');
llvm::raw_string_ostream ss(name);
// ...
return std::string();
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // ...
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // ...
    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();
    // ...
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();
    // ...
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // ...
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    // ...
    results[resultNumber] = castBack;
    // ...
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // If the operand is produced by a tensor.cast whose source has a static
    // shape, use that shape instead.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // Map each static dimension's affine dim expression to its known size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
static void createNewOperandWithStaticSizes(
    // ...
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The affine dim expression has a known static size; use it.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
                                     sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Cast the operand to the new static type.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();
    // ...
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` stays false if no operand type requires a change.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }
    // ...
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    // ...
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    // ...
ShapedType inputType = getInputOperandType();
ShapedType outputType = getOutputOperandType();

ArrayRef<int64_t> inputShape = inputType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(inputShape, outputShape)))
  return emitOpError("incompatible output shape");

int64_t inputRank = getInputOperandRank();
int64_t dimension = getDimension();
if ((dimension < 0) || (dimension >= inputRank))
  return emitOpError("incorrect dimension specified");
int64_t operandRank = getInputOperandRank();
SmallVector<Range> loopBounds(operandRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = getInput();
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
  loopBounds[dim].offset = zero;
  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
  loopBounds[dim].stride = one;
}
return loopBounds;

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}
FailureOr<TilingResult>
// ...
  int64_t rank = getInputOperandRank();
  // ...
  FailureOr<Operation *> inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (failed(inputSlice))
    return emitOpError("failed to compute input slice");
  tiledOperands.emplace_back(inputSlice->getResult(0));

  FailureOr<Operation *> outputSlice =
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  if (failed(outputSlice))
    return emitOpError("failed to compute output slice");
  tiledOperands.emplace_back(outputSlice->getResult(0));

  SmallVector<Type> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
if (resultNumber == 0) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
}
// ...
Location loc = getOperation()->getLoc();
// ...
auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
  if (!outputShapedType.isDynamicDim(dim)) {
    // Static dim: use the size from the input type.
    shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
  } else {
    // ...
  }
}
reifiedReturnShapes.emplace_back(std::move(shapes));
void SoftmaxOp::getEffects(
    // ...
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    // ...
        &getOperation()->getOpOperand(index), 0,
    // ...
  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    // ...
  }
    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  // ...
  for (int i = 0; i < inputRank; i++) {
    // ...
  }
  // ...
  return std::make_tuple(iteratorTypes, indexingMaps);
template <typename T>
// ...
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
auto inputType = cast<ShapedType>(input.getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputRank = inputShape.size();
auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
    builder, inputRank, dim, /*allParallel=*/true);
assert(indexingMaps.size() == 2 && "We should have one map for each input");
assert(indexingMaps[0].isIdentity() && "input map should be identity");
// Add the affine map for the output argument.
indexingMaps.push_back(indexingMaps[0]);
auto genericOp = builder.create<linalg::GenericOp>(
    // ...
    [&](OpBuilder &b, Location loc, ValueRange args) {
      Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
      Value result = b.create<math::ExpOp>(loc, diff);
      b.create<linalg::YieldOp>(loc, result);
    });
return genericOp.getResult(0);
    Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      // ...
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  // ...
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  // ...
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/false);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
  // Step 2: Subtract max from input and exponentiate.
  // ...
  // Step 3: Compute sum along dim.
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
  // Step 4: Divide.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
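  // Illustrative note (not from the original source): the four steps above
  // correspond to the usual numerically stable softmax
  //   max     = reduce_max(x, dim)
  //   num     = exp(x - max)
  //   den     = reduce_sum(num, dim)
  //   softmax = num / den
  // materialized with linalg.fill for the reduction inits and linalg.generic
  // for the element-wise and reduction computations.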
auto filterType = cast<ShapedType>(getFilter().getType());
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

if (filterH != r && filterH != 1)
  return emitOpError("expect filter height either equals to r or 1");
if (filterW != r && filterW != 1)
  return emitOpError("expect filter width either equals to r or 1");
if (filterH == 1 && filterW == 1)
  return emitOpError("expect either filter height or width equals to r");

SmallVector<int64_t> expectedOutputShape;
expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
expectedOutputShape.push_back(filterShape[getFilterCDim()]);
expectedOutputShape.push_back(filterShape[getFilterFDim()]);

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
  return emitOpError("the output shape is not expected");
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  SmallVector<Range> loopBounds(filterRank);
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
ShapedType filterType = getFilterOperandType();
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = filterH != 1 ? alpha : 1;
int64_t alphaW = filterW != 1 ? alpha : 1;
// ...
resultOffsets.append(
    {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
resultSizes.append(
    {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
ShapedType filterType = getFilterOperandType();
ArrayRef<int64_t> filterShape = filterType.getShape();
int64_t filterH = filterShape[getFilterHDim()];
int64_t filterW = filterShape[getFilterWDim()];
// ...
sliceOffsets.append(
    {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                   sizes[getFilterCDim()]});
int64_t filterRank = getFilterOperandRank();
SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
// ...
auto filterSlice = builder.create<tensor::ExtractSliceOp>(
    loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
tiledOperands.emplace_back(filterSlice);
// ...
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
    loc, getOutput(), resultOffsets, resultSizes, outputStrides);
tiledOperands.emplace_back(outputSlice);
// ...
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
auto inputType = cast<ShapedType>(getInput().getType());
ArrayRef<int64_t> inputShape = inputType.getShape();
int64_t inputH = inputShape[getInputHDim()];
int64_t inputW = inputShape[getInputWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t tileSize = m + r - 1;

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

SmallVector<int64_t> expectedOutputShape(6, inputH);
if (ShapedType::isDynamic(inputH)) {
  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
} else {
  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
  expectedOutputShape[getOutputTileHDim()] =
      leftTransform ? (inputH - (r - 1)) / m : inputH;
}
if (ShapedType::isDynamic(inputW)) {
  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
} else {
  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
  expectedOutputShape[getOutputTileWDim()] =
      rightTransform ? (inputW - (r - 1)) / m : inputW;
}
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
  return emitOpError("the output shape is not expected");
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  SmallVector<Range> loopBounds(outputRank);
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // ...
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
ShapedType outputType = getOutputOperandType();
ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
int64_t alpha = m + r - 1;
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
// ...
resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                      offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                      offsets[getOutputCDim()]});
resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                    sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                    sizes[getOutputCDim()]});
FailureOr<TilingResult>
// ...
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];
  // ...
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // ...
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
  // ...
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  // ...
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);
  // ...
  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);
  // ...
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
auto valueType = cast<ShapedType>(getValue().getType());
ArrayRef<int64_t> valueShape = valueType.getShape();
int64_t valueH = valueShape[getValueAlphaHDim()];
int64_t valueW = valueShape[getValueAlphaWDim()];
int64_t valueTileH = valueShape[getValueTileHDim()];
int64_t valueTileW = valueShape[getValueTileWDim()];
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
bool leftTransform = valueH != 1;
bool rightTransform = valueW != 1;

int64_t outputRank = getOutputOperandRank();
SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
} else {
  if (valueH != (leftTransform ? m + r - 1 : 1))
    return emitOpError("expect input height equals to input tile size");
  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
}
if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
} else {
  if (valueW != (rightTransform ? m + r - 1 : 1))
    return emitOpError("expect input width equals to input tile size");
  expectedOutputShape[getOutputWDim()] =
      (rightTransform ? m : 1) * valueTileW;
}
expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
  return emitOpError("the output shape is not expected");
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  SmallVector<Range> loopBounds(valueRank);
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // ...
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
// ...
auto identityAffineMap =
    AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
// ...
ShapedType valueType = getValueOperandType();
ArrayRef<int64_t> valueShape = valueType.getShape();
int64_t valueH = valueShape[0];
int64_t valueW = valueShape[1];
Value mappedOffsetH = affine::makeComposedAffineApply(
    builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
    offsets[getValueTileHDim()]);
Value mappedOffsetW = affine::makeComposedAffineApply(
    builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
    offsets[getValueTileWDim()]);
Value mappedSizeH = affine::makeComposedAffineApply(
    builder, loc, affineMap, sizes[getValueTileHDim()]);
Value mappedSizeW = affine::makeComposedAffineApply(
    builder, loc, affineMap, sizes[getValueTileWDim()]);
// ...
resultOffsets.append(
    {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
resultSizes.append(
    {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
ShapedType valueType = getValueOperandType();
ArrayRef<int64_t> valueShape = valueType.getShape();
int64_t alphaH = valueShape[getValueAlphaHDim()];
int64_t alphaW = valueShape[getValueAlphaWDim()];
// ...
sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                     offsets[getValueTileWDim()], offsets[getValueNDim()],
                     offsets[getValueFDim()]});
sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                   sizes[getValueTileWDim()], sizes[getValueNDim()],
                   sizes[getValueFDim()]});
int64_t valueRank = getValueOperandRank();
SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
auto valueSlice = builder.create<tensor::ExtractSliceOp>(
    loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
tiledOperands.emplace_back(valueSlice);
// ...
int64_t outputRank = getOutputOperandRank();
SmallVector<OpFoldResult> strides(outputRank, oneAttr);
auto outputSlice = builder.create<tensor::ExtractSliceOp>(
    loc, getOutput(), resultOffsets, resultSizes, strides);
tiledOperands.emplace_back(outputSlice);
// ...
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
    mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
llvm::set_union(explicitSet, defaultSet);
return explicitSet == defaultSet;
// ...
SmallVector<AffineMap, 3> defaultIndexingMaps =
    matmulOp.getDefaultIndexingMaps(matmulOp->getContext());

auto opIndexingMap = opIndexingMaps[opIndex];
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
// ...
  return matmulOp->emitOpError()
         << "Unexpected dim expression in map result.";
// ...
if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
  return matmulOp->emitOpError()
         << "Invalid broadcast requested, should be (d2).";
}
template <typename OpTy>
// ...
    AffineMap defaultIndexingMap, bool isLHS) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  // ...
    return batchVariantMatmulOp->emitOpError()
           << "Unexpected result dim expression (outside the set of default "
              "result dims).";
  // ...
    return batchVariantMatmulOp->emitOpError()
           << "no. of result dim expressions exceeds 3.";

  auto hasValidBatchDim = [](AffineMap map) {
    // ...
  };
  // ...
  if (/* ... */) {
    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
      return batchVariantMatmulOp->emitOpError()
             << "Invalid broadcast requested.";
  } else if (!hasValidBatchDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid batch dimension expression.";
  }
template <typename OpTy>
// ...
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 3) {
    return batchVariantMatmulOp->emitOpError()
           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }
  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 2) {
    return batchVariantMatmulOp->emitOpError()
           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }

  auto areValidOutputResultDim = [&](AffineMap outputMap) {
    return isa<BatchMatmulOp>(batchVariantMatmulOp)
               ? outputMap.getResult(0).isFunctionOfDim(0) &&
                     outputMap.getResult(1).isFunctionOfDim(1) &&
                     outputMap.getResult(2).isFunctionOfDim(2)
               : outputMap.getResult(0).isFunctionOfDim(1) &&
                     outputMap.getResult(1).isFunctionOfDim(2);
  };

  if (!areValidOutputResultDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid output map result dimension.";
  }
/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map for BatchMatmulOp/BatchReduceMatmulOp.
template <typename OpTy>
static LogicalResult
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
                                         unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps =
      batchVariantMatmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      batchVariantMatmulOp.getDefaultIndexingMaps(
          batchVariantMatmulOp->getContext());

  if (opIndexingMaps.size() != 3)
    return batchVariantMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];

  if (opIndex == 2 &&
      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
    return failure();

  if (opIndex != 2 &&
      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
                             defaultIndexingMap, opIndex == 0)))
    return failure();

  return success();
}
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
  if (m == 2 && r == 3)
    return WinogradConv2DFmr::F_2_3;
  if (m == 4 && r == 3)
    return WinogradConv2DFmr::F_4_3;
  if (m == 2 && r == 5)
    return WinogradConv2DFmr::F_2_5;
  return std::nullopt;
}

std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
  switch (fmr) {
  case WinogradConv2DFmr::F_2_3:
    return {2, 3};
  case WinogradConv2DFmr::F_4_3:
    return {4, 3};
  case WinogradConv2DFmr::F_2_5:
    return {2, 5};
  }
  llvm_unreachable("unsupported WinogradConv2DFmr");
}
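// In Winograd F(m, r) notation, m is the output tile size and r the filter
// size; each transformed tile then has alpha = m + r - 1 samples per
// dimension. For example, F_4_3 turns a 3x3 filter into 6x6 tiles that each
// produce a 4x4 output block.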
//===----------------------------------------------------------------------===//
// MatmulOp
//===----------------------------------------------------------------------===//

/// Returns the default matmul indexing maps: (d0, d2), (d2, d1) and (d0, d1).
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2;
  bindDims(context, d0, d1, d2);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
                                          utils::IteratorType::parallel,
                                          utils::IteratorType::reduction};
}

unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }

/// Returns true if the user-defined indexing maps differ from the default
/// maps, i.e. the op carries broadcast and/or transpose semantics.
bool MatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Block region builder, called by 'fillStructuredOpRegion'.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Cast both inputs to the output element type, multiply, then accumulate.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  Value value3 =
      helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
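// The generated region is equivalent to the body of a linalg.generic
// (illustrative, assuming f32 operands and the default cast_signed):
//
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32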
/// Returns true if the given broadcast map is valid, i.e. its single result
/// dim is the shared reduction dim of matmul.
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  AffineExpr expr = bcastMap.getResult(0);
  // Invalid map if the common dimension of matmul is not found.
  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
}

FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
    return ArrayAttr{nullptr}; // Success: not present.
  ArrayAttr arrayAttr;
  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
    return failure();
  if (llvm::any_of(arrayAttr,
                   [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";
  return arrayAttr;
}

ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());
}

void MatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  // Elide indexing_maps when they match the defaults.
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user-defined indexing maps.
LogicalResult MatmulOp::verify() {
  // Verification of a pure matmul is handled by verifyStructuredOpInterface().
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // On well-formed IR, indexing_maps is non-empty, the contained affine maps
  // share one domain, and each is a projected permutation. A dim is a
  // contraction/reduction dim iff it does not occur in the output's map. Use
  // this fact for fast inference:
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
                                              : utils::IteratorType::reduction);

  return iteratorTypes;
}
unsigned ContractOp::getNumRegionArgs() { return 3; }

/// Block region builder, called by 'fillStructuredOpRegion'.
void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                               ArrayRef<NamedAttribute> attrs,
                               function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "ContractOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }

  // Cast both inputs to the output type, multiply, then accumulate.
  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
                                                rhsAtOutType, emitError);
  if (!productAtOutType)
    return;
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType, emitError);
  if (!result)
    return;
  helper.yieldOutputs({result});
}
ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
                                getRegionBuilder());
}

void ContractOp::print(OpAsmPrinter &p) {
  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         {"indexing_maps", "operandSegmentSizes"});
}
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map each iteration space dim to the number of times it occurs in the
  // inputs' maps (inOccurrences) and in the output's map (outOccurrences).
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // Check that the affine map is a projected permutation.
    if (!affineMap.isProjectedPermutation())
      return emitError("provided affine_map is not a projected permutation");

    // Check the operand's rank matches the map's number of results.
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError("ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError("affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // All maps must cover the same iteration space.
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // Update the per-dim occurrence counts.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");
      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }
    return success();
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // checkAffineMapAndType has emitted the error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims;
       dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // A contracting dim occurs in both inputs but not in the output.
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");

  return success();
}
void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// BatchMatmulOp
//===----------------------------------------------------------------------===//

SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  bindDims(context, d0, d1, d2, d3);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}

SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}
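// With (d0, d1, d2, d3) = (batch, m, n, k), the default maps above read
// LHS: (batch, m, k), RHS: (batch, k, n), OUT: (batch, m, n), i.e. the
// conventional batched C += A * B with k as the only reduction dim.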
unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Returns true if the user-defined indexing maps differ from the defaults,
/// i.e. the op carries broadcast and/or transpose semantics.
bool BatchMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Returns true if the given broadcast map is valid for this op.
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ... accept rank-1 maps covering k, and rank-2 maps covering (batch, k),
  // (m, k) for the LHS, or (k, n) for the RHS ...
  return isValid;
}
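// Example (illustrative): a valid LHS broadcast collapses the operand onto
// the reduction dim k = d3:
//
//   linalg.batch_matmul indexing_maps = [
//       affine_map<(d0, d1, d2, d3) -> (d3)>,         // broadcast LHS
//       affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
//       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>]
//       ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>)
//       outs(%arg2 : memref<2x3x7xf32>)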
void BatchMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  auto toType = block.getArgument(2).getType();
  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}
ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    // ... parse a '['-delimited, comma-separated list of attributes ...
    if (!isa<AffineMapAttr>(mapAttr)) {
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    }
    indexingMapsAttr.push_back(mapAttr);
    // ...
  }
  // Initialize indexing_maps to the defaults, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return parseNamedStructuredOp(parser, result,
                                BatchMatmulOp::getNumRegionArgs(),
                                BatchMatmulOp::getRegionBuilder());
}

void BatchMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());

  std::array<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user-defined indexing maps.
LogicalResult BatchMatmulOp::verify() {
  // Verification of a pure batch_matmul is handled by the structured op
  // interface; only user-defined maps need the extended checks.
  if (!hasUserDefinedMaps())
    return success();
  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// ElementwiseOp
//===----------------------------------------------------------------------===//

namespace {
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}.
  ElementwiseArityGroup arityGroup;
  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};
} // namespace

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}

static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
}
SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
  auto rank = getResultRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
  return SmallVector<AffineMap>(numMaps, map);
}

ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // `kind` is a mandatory attribute.
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (succeeded(parser.parseOptionalKeyword("kind"))) {
    Attribute attr;
    if (parser.parseEqual() || parser.parseAttribute(attr))
      return failure();
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  }
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  // ... parse the optional `indexing_maps` list into indexingMapsAttr,
  // rejecting any element that is not an affine map:
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    // ...
    if (!isa<AffineMapAttr>(mapAttr))
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    indexingMapsAttr.push_back(mapAttr);
    // ...
  }

  // The number of region args is the arity group (number of inputs) plus one
  // output.
  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
  auto numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }

  // Initialize indexing_maps to rank-many identity maps, if not supplied.
  if (indexingMapsAttr.empty()) {
    auto shapedType = llvm::dyn_cast<ShapedType>(result.types[0]);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return success();
}

void ElementwiseOp::print(OpAsmPrinter &p) {
  // ... print the mandatory `kind` attribute ...
  unsigned arity =
      getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
  unsigned numDims = getResultRank();
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
                                            getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  // Elide indexing_maps when they match the (identity) defaults.
  if (!llvm::equal(getIndexingMaps(), indexingMaps))
    p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
  // ...
}
void ElementwiseOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    if (attr.getName() == b.getStringAttr("kind")) {
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
      break;
    }
  }

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  if (emitError && block.getNumArguments() !=
                       getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
    emitError() << "Elementwise regionBuilder expects "
                << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/
         && "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
                                  block.getArgument(1));
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
                                   block.getArgument(1), block.getArgument(2));
  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);
}

LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// PackOp/UnPackOp Common
//===----------------------------------------------------------------------===//

static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // If the current result dim is static, replace a dynamic mixed tile size
    // with a constant attribute; an already-static tile is kept as-is.
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // Already a constant.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }
  return newMixedTileSizes;
}
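// Example (illustrative): if cast folding turns the packed type's trailing
// tile dims from {?, 32} into {16, 32} while mixedTiles is {%c16, 32}, the
// result is {16, 32}: the Value %c16 is replaced by an index attribute, and
// the already-static tile is kept unchanged.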
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}

template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind each tiled dimension to its tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}

template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}
/// Returns true if `dimsPos` is invalid: it contains duplicates, has more
/// elements than `rank`, or any position is out of [0, rank).
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  llvm::SmallDenseSet<int64_t> uniqued(dimsPos.begin(), dimsPos.end());
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}

/// Returns true if each dimension of `sourceShape` fits within the
/// corresponding dimension of `limitShape`.
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must not exceed the unpacked rank and must match the
  // number of inner_dims_pos entries.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // The packed rank is the unpacked rank plus the number of tiling factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the packed shape can hold the data and that static inner
  // tiles match the trailing packed dims.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shape of "
                         "tiled dimension in the packed type");
  }
  return success();
}
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//

void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
ArrayRef<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify the padding value's type, and reject partial tiles when no padding
  // value is provided.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Have to do this first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}

/// Helper for PackOp::{getResultShape,inferPackedType}: ceil-divide each
/// tiled source dim by its tile size, apply outer_dims_perm, then append the
/// inner tile sizes.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}

// PackOp::getResultShape performs the same computation on OpFoldResults,
// using a composed affine ceildiv per tiled dim:
//
//     resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
//         builder, loc, ceilDivExpr,
//         {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
//     // ... apply outer_dims_perm, then:
//     resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
//
// and finally fixes `resultDims` up so an entry is a Value exactly when the
// inferred result type marks that dim dynamic:
//
//     for (unsigned i = 0; i < resultDims.size(); ++i) {
//       if (!ShapedType::isDynamic(resultTypeShape[i]))
//         continue;
//       resultDims[i] =
//           getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
//     }
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}

Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          b.create<tensor::DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;
  return Speculation::Speculatable;
}

/// Returns true if `inner_dims_pos` and `outer_dims_perm` of the pack and
/// unpack ops target the same dims (treating an absent outer_dims_perm as
/// the identity permutation).
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // ... compare against the identity permutation when only one op carries an
  // outer_dims_perm ...
}

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if the srcShape or destShape differs from the one in packOp,
/// and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast back to the original result type if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dims.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // We take numPackedDims outer dims and push them all the way to the inner
  // most positions; what's left on the outer most dims must be a sequence of
  // ones of length packedRank - numPackedDims.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
    // ...
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op: cast back to the original result type if cast folding
    // changed the result's static shape.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//

void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

// ...

ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}

void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}

Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}

UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}
/// Returns true if the srcShape or destShape differs from the one in op, and
/// populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y)),
  /// when the slice keeps full tiles (zero offsets, unit strides, same rank).
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
    if (extractSliceUser &&
        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
        extractSliceUser.getSourceType().getRank() ==
            extractSliceUser.getResultType().getRank()) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
          unPackOp->getLoc(), unPackOp.getDest(),
          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
          extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(unPackOp, [&]() {
        unPackOp.setDpsInitOperand(0, newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(extractSliceUser, unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    Value newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}

// ...
    Value sourceTensor = newOperands[0];

    // Get the updated mixed tile sizes.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());

    UnPackOp newOp = rewriter.create<UnPackOp>(
        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op: cast back to the original result type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  op->getLoc(), oldResult.getType(), newResult)
                            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
//===----------------------------------------------------------------------===//
// BatchReduceMatmulOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::reduction, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

SmallVector<AffineMap>
BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  bindDims(context, d0, d1, d2, d3);
  SmallVector<AffineMap> indexingMaps;
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  // The batch dim is reduced away, so the output map drops d0.
  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
  return indexingMaps;
}

unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

/// Returns true if the user-defined indexing maps differ from the defaults.
bool BatchReduceMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

/// Returns true if the given broadcast map is valid for this op.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ... same shape of checks as the BatchMatmulOp variant above ...
  return isValid;
}

void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  auto toType = block.getArgument(2).getType();
  Value castValA =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
  Value castValB =
      helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
  Value addVal =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
  yields.push_back(addVal);
  helper.yieldOutputs(yields);
}
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
    // ... parse a '['-delimited, comma-separated list of attributes ...
    if (!isa<AffineMapAttr>(mapAttr)) {
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    }
    indexingMapsAttr.push_back(mapAttr);
    // ...
  }
  // Initialize indexing_maps to the defaults, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  result.addAttribute("indexing_maps",
                      parser.getBuilder().getArrayAttr(indexingMapsAttr));
  return parseNamedStructuredOp(parser, result,
                                BatchReduceMatmulOp::getNumRegionArgs(),
                                BatchReduceMatmulOp::getRegionBuilder());
}

void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         elidedAttrs);
}

/// Verify the user-defined indexing maps.
LogicalResult BatchReduceMatmulOp::verify() {
  // Verification of a pure batch_reduce_matmul is handled by the structured
  // op interface.
  if (!hasUserDefinedMaps())
    return success();
  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
}

LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
union mlir::linalg::@1216::ArityGroupAndKind::Kind kind
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
SmallVector< int64_t > outerDimsPerm
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
SmallVector< OpFoldResult > innerTiles
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static Operation * findPayloadOp(Block *body, bool initFirst=false)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
ElementwiseArityGroup arityGroup
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
SmallVector< int64_t > innerDimsPos
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
result_iterator result_begin()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name used to disambiguate between different overloads at the C level...
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set; otherwise returns the symbol-less identity map of the given rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
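A sketch of how these two expression helpers compose, assuming ctx is an MLIRContext*:
unsigned pos = 0;
auto batch = makeAffineDimExprs(1, pos, ctx);   // [d0], pos advances to 1
auto spatial = makeAffineDimExprs(2, pos, ctx); // [d1, d2], pos advances to 3
auto allDims = concat(batch, spatial);          // [d0, d1, d2]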
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
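Both variants hide the tensor/memref distinction behind one call; a sketch, with b, loc, and val assumed:
// tensor.dim for tensor values, memref.dim for memref values; the op
// folds to a constant when dimension 0 is statically known.
Value d0 = createOrFoldDimOp(b, loc, val, /*dim=*/0);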
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
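A sketch of the round trip between the enum and the (m, r) pair; the optional is empty for unsupported combinations:
if (std::optional<WinogradConv2DFmr> fmr = getWinogradConv2DFmr(/*m=*/4, /*r=*/3)) {
  auto [m, r] = getFmrFromWinogradConv2DFmr(*fmr);
  assert(m == 4 && r == 3);
}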
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
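A sketch of the usual fold-hook usage, with a hypothetical SomeOp standing in for a concrete op:
LogicalResult SomeOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Replaces operands produced by memref.cast with the cast's source and
  // succeeds iff at least one operand was updated.
  return foldMemRefCast(*this);
}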
Kind
An enumeration of the kinds of predicates.
DynamicAPInt floor(const Fraction &f)
DynamicAPInt ceil(const Fraction &f)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
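The usual pattern branches on whether the OpFoldResult is statically known (ofr assumed):
if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
  // Static case: use *cst directly, e.g. to fold or check bounds.
} else {
  // Dynamic case: fall back to the underlying SSA Value.
}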
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
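For example (a sketch; ctx assumed):
Type f32 = Float32Type::get(ctx);
Type tensorTy = RankedTensorType::get({4, 8}, f32);
assert(getElementTypeOrSelf(tensorTy) == f32); // element type of a shaped type
assert(getElementTypeOrSelf(f32) == f32);      // scalars map to themselves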
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
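A sketch combining both binders to build an affine map (ctx assumed):
AffineExpr d0, d1, s0;
bindDims(ctx, d0, d1); // d0 -> dim 0, d1 -> dim 1
bindSymbols(ctx, s0);  // s0 -> symbol 0
auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, {d0 * s0 + d1}, ctx);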
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
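Together these bridge mixed OpFoldResult lists and the static/dynamic operand split many ops store; a sketch with mixedSizes, b, and loc assumed:
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes; // Values are recorded as ShapedType::kDynamic
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
// The other direction: force one OpFoldResult into an SSA Value, creating
// an arith.constant index op only when it currently holds an attribute.
Value size0 = getValueOrCreateConstantIndexOp(b, loc, mixedSizes[0]);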
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
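Compatible here means equal rank with dimensions that agree wherever both are static; a sketch:
SmallVector<int64_t> shapeA = {4, ShapedType::kDynamic};
SmallVector<int64_t> shapeB = {4, 8};
assert(succeeded(verifyCompatibleShape(shapeA, shapeB))); // dynamic matches anything
SmallVector<int64_t> shapeC = {5, 8};
assert(failed(verifyCompatibleShape(shapeB, shapeC)));    // 4 vs. 5 conflict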
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions let clients of the API avoid using the detail classes directly.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the dims listed in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
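A sketch tying the permutation helpers together:
SmallVector<int64_t> perm = {2, 0, 1};
assert(isPermutationVector(perm));
SmallVector<int64_t> inv = invertPermutationVector(perm); // {1, 2, 0}
SmallVector<StringRef> v = {"a", "b", "c"};
applyPermutationToVector(v, perm); // v == {"c", "a", "b"}
applyPermutationToVector(v, inv);  // v == {"a", "b", "c"} again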
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp if the tensor.cast has a source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override