//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
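
// For example (illustrative, not from the original source): for a
// `tensor<8x16xf32>` source this creates a `tensor.extract_slice`, for a
// `memref<8x16xf32>` source a `memref.subview`, and for any other type it
// returns nullptr.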

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
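
// For example (illustrative, not from the original source): for a
// matmul-style op on f32 tensors, the block created above has one scalar
// argument per operand and the `regionBuilder` fills it roughly as:
//   ^bb0(%in: f32, %in_0: f32, %out: f32):
//     %0 = arith.mulf %in, %in_0 : f32
//     %1 = arith.addf %out, %0 : f32
//     linalg.yield %1 : f32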

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
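
// Illustrative (not from the original source): for a plain matmul, the
// `indexingMaps` passed to this builder would typically be
//   affine_map<(m, n, k) -> (m, k)>,  // A
//   affine_map<(m, n, k) -> (k, n)>,  // B
//   affine_map<(m, n, k) -> (m, n)>   // C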

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
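
// The common part parsed above covers, e.g. (illustrative):
//   ins(%A, %B : memref<4x8xf32>, memref<8x16xf32>)
//   outs(%C : memref<4x16xf32>)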

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}
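
// Putting the pieces together, a named structured op in its custom textual
// form looks like (illustrative):
//   %0 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//                      outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
// The region is elided in this form and rebuilt by `regionBuilder`.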

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers build the unary, binary, and type conversion functions defined
// by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that
// uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If `emitError` is provided, an error is emitted for unsupported operations
  // and nullptr is returned; otherwise an assertion is raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned max not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned max not on uint");
      }
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned min not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned min not on uint");
      }
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }
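
  // For example (illustrative): BinaryFn::add lowers to complex.add for
  // complex operands, arith.addf for floats, arith.ori for i1, and
  // arith.addi for other integers.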

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

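/// Erase self-copies, e.g. (illustrative)
///   linalg.copy ins(%a : memref<8xf32>) outs(%a : memref<8xf32>)
/// which are no-ops: the op is erased in the buffer case and replaced by its
/// input in the tensor case.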
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
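///
/// Example (illustrative):
///   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>) -> tensor<6x4xf32>
///   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
/// folds to a fill of the reshaped init:
///   %0 = tensor.collapse_shape %init [[0, 1]] : tensor<6x4xf32> into tensor<24xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<24xf32>) -> tensor<24xf32>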
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
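///
/// Example (illustrative):
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %pad = tensor.pad %fill low[1] high[1] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
///   } : tensor<4xf32> to tensor<6xf32>
/// folds to a fill of an appropriately sized tensor.empty:
///   %empty = tensor.empty() : tensor<6xf32>
///   %pad = linalg.fill ins(%cst : f32) outs(%empty : tensor<6xf32>) -> tensor<6xf32>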
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
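///
/// Illustrative sketch: if the pad fills the border with %cst and the insert
/// destination is a fill with the same %cst, the padded border is
/// indistinguishable from the fill, so <input> can be inserted directly at
/// offsets shifted by the low padding.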
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
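///
/// Example (illustrative):
///   %0 = linalg.fill ins(%cst : f32) outs(%src : tensor<8x8xf32>) -> tensor<8x8xf32>
///   %1 = linalg.pack %0 inner_dims_pos = [0, 1] inner_tiles = [4, 4]
///        into %dest : tensor<8x8xf32> -> tensor<2x2x4x4xf32>
/// folds to:
///   %1 = linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x4x4xf32>) -> tensor<2x2x4x4xf32>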
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
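///
/// Example (illustrative):
///   %0 = linalg.fill ins(%cst : f32) outs(%a : tensor<4xf32>) -> tensor<4xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%b : tensor<8xf32>) -> tensor<8xf32>
///   %2 = tensor.concat dim(0) %0, %1 : (tensor<4xf32>, tensor<8xf32>) -> tensor<12xf32>
/// folds to a concat of the outs followed by a single fill:
///   %0 = tensor.concat dim(0) %a, %b : (tensor<4xf32>, tensor<8xf32>) -> tensor<12xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<12xf32>) -> tensor<12xf32>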
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must be checked into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
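///
/// Example (illustrative):
///   #id = affine_map<(d0) -> (d0)>
///   %0 = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<8xf32>) outs(%arg1 : tensor<8xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<8xf32>
/// is replaced by %arg0 (inserting a tensor.cast or sparse_tensor.convert
/// when the result type differs).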
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnedArg.getType()) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

1455static ParseResult parseDstStyleOp(
1457 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1458 nullptr) {
1459 // Parse `ins` and `outs`.
1460 SmallVector<Type, 4> inputTypes, outputTypes;
1461 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1462 /*addOperandSegmentSizes=*/false))
1463 return failure();
1464
1465 // Add result types.
1466 for (Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1469 }
1470
1471 // Parse required attributes.
1472 if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1473 return failure();
1474
1475 // Parse optional attributes.
1476 if (parser.parseOptionalAttrDict(result.attributes))
1477 return failure();
1478 return success();
1479}
1480
1481void MapOp::getAsmBlockArgumentNames(Region &region,
1482 OpAsmSetValueNameFn setNameFn) {
1483 for (Value v : getRegionInputArgs())
1484 setNameFn(v, "in");
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v, "init");
1487}
1488
1489void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(), "mapped");
1492}
1493
1494void MapOp::build(
1495 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1496 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1497 ArrayRef<NamedAttribute> attributes) {
1498 build(builder, result, TypeRange{}, inputs, init);
1499 result.addAttributes(attributes);
1500
1501 // Add output types for `RankedTensorType` output arguments.
1502 Type initType = init.getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1505
1506 if (bodyBuild)
1507 buildGenericRegion(builder, result.location, *result.regions.front(),
1508 inputs, /*outputs=*/{init}, bodyBuild);
1509}
1510
1512 const OperationName &payloadOpName,
1513 const NamedAttrList &payloadOpAttrs,
1514 ArrayRef<Value> operands,
1515 bool initFirst = false, bool mapInit = true) {
1516 OpBuilder b(parser.getContext());
1517 Region *body = result.addRegion();
1518 Block &block = body->emplaceBlock();
1519 b.setInsertionPointToStart(&block);
1520 for (auto &operand : operands) {
1521 block.addArgument(
1522 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1523 b.getUnknownLoc());
1524 }
1525 SmallVector<Value> payloadOpOperands;
1526 // If initFirst flag is enabled, we consider init as the first position of
1527 // payload operands.
1528 if (initFirst) {
1529 if (mapInit)
1530 payloadOpOperands.push_back(block.getArguments().back());
1531 for (const auto &arg : block.getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1533 } else {
1534 payloadOpOperands = {block.getArguments().begin(),
1535 block.getArguments().end() - int(!mapInit)};
1536 }
1537
1538 Operation *payloadOp = b.create(
1539 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1540 payloadOpOperands,
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1542 .getElementType()},
1543 payloadOpAttrs);
1544 YieldOp::create(b, result.location, payloadOp->getResults());
1545}
1546
1547ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1548 std::optional<OperationName> payloadOpName;
1549 NamedAttrList payloadOpAttrs;
1550 if (succeeded(parser.parseOptionalLBrace())) {
1551 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1552 if (failed(operationName))
1553 return failure();
1554 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1555 return failure();
1556 payloadOpName = operationName.value();
1557 if (parser.parseRBrace())
1558 return failure();
1559 }
1560
1561 if (parseDstStyleOp(parser, result))
1562 return failure();
1563
1564 if (payloadOpName.has_value()) {
1565 if (!result.operands.empty())
1566 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1567 payloadOpAttrs, ArrayRef(result.operands),
1568 /*initFirst=*/false, /*mapInit=*/false);
1569 else
1570 result.addRegion();
1571 } else {
1572 SmallVector<OpAsmParser::Argument> regionArgs;
1573 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1574 /*allowType=*/true, /*allowAttrs=*/true)) {
1575 return failure();
1576 }
1577 Region *body = result.addRegion();
1578 if (parser.parseRegion(*body, regionArgs))
1579 return failure();
1580 }
1581 return success();
1582}
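// For orientation, a sketch of the two accepted forms (operand names and
// types are illustrative):
//
//   // Short form: the payload op is named inline and the region is elided.
//   %mapped = linalg.map { math.absf } ins(%in : tensor<64xf32>)
//                 outs(%init : tensor<64xf32>)
//
//   // Long form: explicit block arguments and an explicit region.
//   %mapped = linalg.map ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
//                 outs(%init : tensor<64xf32>)
//     (%l: f32, %r: f32) {
//       %0 = arith.addf %l, %r : f32
//       linalg.yield %0 : f32
//     }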
1583
1584static bool canUseShortForm(Block *body, bool initFirst = false,
1585 bool mapInit = true) {
1586 // `initFirst == true` implies that the init arg is mapped.
1587 if (initFirst && !mapInit)
1588 return false;
1589 // Check if the body can be printed in short form. The following 4 conditions
1590 // must be satisfied:
1591
1592 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1593 if (body->getOperations().size() != 2)
1594 return false;
1595 Operation &payload = body->getOperations().front();
1596
1597 // 2) The payload op must have the same number of operands as the number of
1598 // block arguments (excluding the init argument when it is not mapped).
1599 if (payload.getNumOperands() == 0 ||
1600 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1601 return false;
1602
1603 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1604 // argument must be the first operand of the payload op; otherwise, the
1605 // operands must match the block arguments in order.
1606 if (initFirst) {
1607 // check init
1608 if (payload.getOperands().back() != body->getArgument(0))
1609 return false;
1610 // check rest
1611 for (const auto &[operand, bbArg] :
1612 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1613 if (bbArg != operand)
1614 return false;
1615 }
1616 } else {
1617 for (const auto &[operand, bbArg] :
1618 llvm::zip(payload.getOperands(),
1619 body->getArguments().drop_back(int(!mapInit)))) {
1620 if (bbArg != operand)
1621 return false;
1622 }
1623 }
1624
1625 // 4) The `yield` operand must be the result of the payload op.
1626 auto yieldOp = cast<YieldOp>(body->getTerminator());
1627 return yieldOp.getNumOperands() == 1 &&
1628 yieldOp.getOperand(0).getDefiningOp() &&
1629 yieldOp.getOperand(0).getDefiningOp() == &payload;
1630}
1631
1632static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1633 SmallVector<StringRef> elidedAttrs;
1634 std::string attrToElide;
1635 p << " { " << payloadOp->getName().getStringRef();
1636 for (const auto &attr : payloadOp->getAttrs()) {
1637 auto fastAttr =
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1642 break;
1643 }
1644 }
1645 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1646 p << " }";
1647}
1648
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1651 bool useShortForm =
1652 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit=*/false);
1653 if (useShortForm) {
1654 printShortForm(p, &mapper->getOperations().front());
1655 }
1656
1657 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1658 p.printOptionalAttrDict((*this)->getAttrs());
1659
1660 if (!useShortForm) {
1661 // Print region if the payload op was not detected.
1662 p.increaseIndent();
1663 p.printNewline();
1664 p << "(";
1665 llvm::interleaveComma(mapper->getArguments(), p,
1666 [&](auto arg) { p.printRegionArgument(arg); });
1667 p << ") ";
1668
1669 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1670 p.decreaseIndent();
1671 }
1672}
1673
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1677
1678 // Check that the number of `inputs` + `init` matches the arity of the
1679 // `mapper` region.
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() << "expects number of operands to match the arity of "
1682 "mapper, but got: "
1683 << getInputs().size() + 1 << " and "
1684 << blockArgs.size();
1685
1686 // The parameters of the mapper should all match the element types of the inputs.
1687 for (const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() << "expected element type of input " << inputElemType
1693 << " to match bbArg type " << bbArgType;
1694 }
1695 }
1696
1697 // The shape of each input must match the shape of the output.
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType : TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() << "expected shape of input (" << inputElemShape
1703 << ") to match shape of output (" << outputShape
1704 << ")";
1705 }
1706 }
1707
1708 return success();
1709}
1710
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1714}
1715
1716ArrayAttr MapOp::getIndexingMaps() {
1717 Builder builder(getContext());
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1720 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1721 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1722}
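// E.g., for a rank-2 map with two inputs and one init, this returns three
// copies of the identity map (d0, d1) -> (d0, d1), one per operand.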
1723
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1726 &effects) {
1727 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1728}
1729
1730Speculation::Speculatability MapOp::getSpeculatability() {
1731 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1732}
1733
1734//===----------------------------------------------------------------------===//
1735// ReduceOp
1736//===----------------------------------------------------------------------===//
1737
1738void ReduceOp::getAsmBlockArgumentNames(Region &region,
1739 OpAsmSetValueNameFn setNameFn) {
1740 for (Value v : getRegionInputArgs())
1741 setNameFn(v, "in");
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v, "init");
1744}
1745
1746void ReduceOp::getAsmResultNames(
1747 function_ref<void(Value, StringRef)> setNameFn) {
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(), "reduced");
1750}
1751
1752void ReduceOp::build(
1753 OpBuilder &builder, OperationState &result, ValueRange inputs,
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1755 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1756 ArrayRef<NamedAttribute> attributes) {
1757 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1758 result.addAttributes(attributes);
1759
1760 // Add output types for `RankedTensorType` output arguments.
1761 for (Value init : inits) {
1762 Type initType = init.getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1765 }
1766
1767 if (bodyBuild)
1768 buildGenericRegion(builder, result.location, *result.regions.front(),
1769 inputs, inits, bodyBuild);
1770}
1771
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1773 int64_t inputRank =
1774 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1780}
1781
1782ArrayAttr ReduceOp::getIndexingMaps() {
1783 int64_t inputRank =
1784 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1786 getNumDpsInputs(),
1787 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1788 AffineMap resultMap =
1789 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1790 .dropResults(getDimensions());
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1794}
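// E.g., for rank-3 inputs and dimensions = [1], every input keeps the
// identity map (d0, d1, d2) -> (d0, d1, d2) while every init receives the
// reduced map (d0, d1, d2) -> (d0, d2).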
1795
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1798 &effects) {
1799 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1800}
1801
1802Speculation::Speculatability ReduceOp::getSpeculatability() {
1803 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1804}
1805
1806static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1807 NamedAttrList &attributes,
1808 StringRef attributeName) {
1809 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1810 return failure();
1811
1812 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1813 return success();
1814}
1815
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1819 if (succeeded(parser.parseOptionalLBrace())) {
1820 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1821 if (failed(operationName))
1822 return failure();
1823 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1824 return failure();
1825 payloadOpName = operationName.value();
1826 if (parser.parseRBrace())
1827 return failure();
1828 }
1829
1830 if (parseDstStyleOp(
1831 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1832 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1833 }))
1834 return failure();
1835
1836 if (payloadOpName.has_value()) {
1837 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1838 ArrayRef(result.operands), /*initFirst=*/true);
1839 } else {
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1841 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1842 /*allowType=*/true, /*allowAttrs=*/true)) {
1843 return failure();
1844 }
1845
1846 Region *body = result.addRegion();
1847 if (parser.parseRegion(*body, regionArgs))
1848 return failure();
1849 }
1850
1851 return success();
1852}
1853
1854static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1855 ArrayRef<int64_t> attributeValue) {
1856 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1857}
1858
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1861 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1862 if (useShortForm) {
1863 printShortForm(p, &mapper->getOperations().front());
1864 }
1865
1866 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1867 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1868 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1869 if (!useShortForm) {
1870 // Print region if the payload op was not detected.
1871 p.increaseIndent();
1872 p.printNewline();
1873 p << "(";
1874 llvm::interleaveComma(mapper->getArguments(), p,
1875 [&](auto arg) { p.printRegionArgument(arg); });
1876 p << ") ";
1877
1878 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1879 p.decreaseIndent();
1880 }
1881}
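// A sketch of the short form printed above (types illustrative):
//   %reduced = linalg.reduce { arith.addf } ins(%in : tensor<16x32xf32>)
//                  outs(%init : tensor<16xf32>) dimensions = [1]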
1882
1883LogicalResult ReduceOp::verify() {
1884 ArrayRef<int64_t> dimensionsRef = getDimensions();
1885
1886 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1887 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1888 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1889 return emitOpError() << "expects all inputs to have the same shapes. "
1890 "Shape at input-index "
1891 << i
1892 << " is not equal to the shape at input-index 0.";
1893 }
1894 }
1895 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1896 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1897 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1898 return emitOpError() << "expects all outputs to have the same shapes. "
1899 "Shape at output-index "
1900 << i
1901 << " is not equal to the shape at output-index 0.";
1902 }
1903 }
1904 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1905 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1906
1907 DenseSet<int64_t> dimensionsToReduce;
1908 for (int64_t dimension : dimensionsRef) {
1909 if (dimension < 0 || dimension >= inputType.getRank()) {
1910 return emitOpError()
1911 << "dimensions for reduction should be in the range [0, "
1912 << inputType.getRank() - 1 << "].";
1913 }
1914 dimensionsToReduce.insert(dimension);
1915 }
1916
1917 auto inputDims = inputType.getShape();
1918 auto initDims = initType.getShape();
1919
1920 // Input dimensions that will be left after the reduction.
1921 SmallVector<int64_t> reducedInputDims;
1922 for (const auto &en : llvm::enumerate(inputDims)) {
1923 if (!dimensionsToReduce.count(en.index()))
1924 reducedInputDims.push_back(en.value());
1925 }
1926
1927 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1928 return emitOpError() << "number of dimensions after reduction "
1929 << reducedInputDims.size()
1930 << " doesn't match the init rank "
1931 << initType.getRank();
1932 }
1933
1934 if (reducedInputDims != initDims)
1935 return emitOpError() << "init dimensions [" << initDims
1936 << "] doesn't match input dimensions after reduction ["
1937 << reducedInputDims << "]";
1938
1939 Block *block = getBody();
1940 if (block->getNumArguments() != this->getNumOperands())
1941 return emitOpError()
1942 << "mismatching number of operands and block arguments";
1943
1944 // Check that the first block arguments match the element type of the inputs.
1945 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1946 Type inputElementType =
1947 llvm::cast<ShapedType>(input.getType()).getElementType();
1948 if (inputElementType != bbArg.getType())
1949 return emitOpError()
1950 << "input element type " << inputElementType
1951 << " does not match corresponding block argument type "
1952 << bbArg.getType();
1953 }
1954
1955 // Check that the last block arguments match the element type of the outputs.
1956 for (auto [output, bbArg] : llvm::zip(
1957 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1958 auto outputElementType =
1959 llvm::cast<ShapedType>(output.getType()).getElementType();
1960 if (outputElementType != bbArg.getType())
1961 return emitOpError()
1962 << "output element type " << outputElementType
1963 << " does not match corresponding block argument type "
1964 << bbArg.getType();
1965 }
1966 return success();
1967}
1968
1969//===----------------------------------------------------------------------===//
1970// TransposeOp
1971//===----------------------------------------------------------------------===//
1972
1973static void buildIdentityRegion(OpBuilder &builder, Location loc,
1974 Region &region, ValueRange inputs,
1975 ValueRange outputs) {
1976 buildGenericRegion(builder, loc, region, inputs, outputs,
1977 [](OpBuilder &b, Location loc, ValueRange args) {
1978 if (!args.empty())
1979 linalg::YieldOp::create(b, loc, args[0]);
1980 });
1981}
1982
1983void TransposeOp::build(::mlir::OpBuilder &builder,
1984 ::mlir::OperationState &result, Value input, Value init,
1985 DenseI64ArrayAttr permutation,
1986 ArrayRef<NamedAttribute> attributes) {
1987 result.addOperands(input);
1988 result.addOperands(init);
1989 result.addAttribute(getPermutationAttrName(result.name), permutation);
1990 result.addAttributes(attributes);
1991
1992 // Add output types for `RankedTensorType` output arguments.
1993 Type initType = init.getType();
1994 if (llvm::isa<RankedTensorType>(initType))
1995 result.addTypes(initType);
1996
1997 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1998 init);
1999}
2000
2001void TransposeOp::build(::mlir::OpBuilder &builder,
2002 ::mlir::OperationState &result, Value input, Value init,
2003 ArrayRef<int64_t> permutation,
2004 ArrayRef<NamedAttribute> attributes) {
2005 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
2006 attributes);
2007}
2008
2009ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2010  if (failed(parseDstStyleOp(
2011 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2012 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2013 })))
2014 return failure();
2015
2016 OpBuilder builder(parser.getContext());
2017 buildIdentityRegion(builder, result.location, *result.addRegion(),
2018 /*inputs=*/result.operands,
2019 /*outputs=*/{});
2020 return success();
2021}
2022
2023void TransposeOp::getAsmResultNames(
2024 function_ref<void(Value, StringRef)> setNameFn) {
2025 if (!getResults().empty())
2026 setNameFn(getResults().front(), "transposed");
2027}
2028
2029void TransposeOp::print(OpAsmPrinter &p) {
2030 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2031 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2032 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2033}
2034
2035LogicalResult TransposeOp::verify() {
2036 ArrayRef<int64_t> permutationRef = getPermutation();
2037
2038 if (!isPermutationVector(permutationRef))
2039 return emitOpError("permutation is not valid");
2040
2041 auto inputType = getInput().getType();
2042 auto initType = getInit().getType();
2043
2044 int64_t rank = inputType.getRank();
2045
2046 if (rank != initType.getRank())
2047 return emitOpError() << "input rank " << rank
2048 << " does not match init rank " << initType.getRank();
2049
2050 if (rank != static_cast<int64_t>(permutationRef.size()))
2051 return emitOpError() << "size of permutation " << permutationRef.size()
2052 << " does not match the argument rank " << rank;
2053
2054 auto inputDims = inputType.getShape();
2055 auto initDims = initType.getShape();
2056
2057 for (int64_t i = 0; i < rank; ++i) {
2058 int64_t inputDim = inputDims[permutationRef[i]];
2059 int64_t initDim = initDims[i];
2060
2061 if (inputDim != initDim) {
2062 return emitOpError() << "dim(result, " << i << ") = " << initDim
2063 << " doesn't match dim(input, permutation[" << i
2064 << "]) = " << inputDim;
2065 }
2066 }
2067
2068 return success();
2069}
2070
2071SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2072 int64_t rank = getInit().getType().getRank();
2073 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2074}
2075
2076ArrayAttr TransposeOp::getIndexingMaps() {
2077 Builder builder(getContext());
2078 int64_t rank = getInit().getType().getRank();
2079 return builder.getAffineMapArrayAttr(
2080 {inversePermutation(AffineMap::getPermutationMap(
2081 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2082 builder.getMultiDimIdentityMap(rank)});
2083}
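// E.g., for permutation = [2, 0, 1] the input map is the inverse permutation
// (d0, d1, d2) -> (d1, d2, d0) and the init map is the rank-3 identity.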
2084
2085void TransposeOp::getEffects(
2086 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2087 &effects) {
2088 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2089}
2090
2091Speculation::Speculatability TransposeOp::getSpeculatability() {
2092 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2093}
2094
2095LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2096 SmallVectorImpl<OpFoldResult> &result) {
2097 // Only the tensor type is supported.
2098 if (!isa<TensorType>(getInput().getType()))
2099 return failure();
2100
2101 // Single dimension transpose.
2102 if (getPermutation().empty()) {
2103 result.push_back(getInput());
2104 return success();
2105 }
2106 // Identity permutation.
2107 if (isIdentityPermutation(getPermutation())) {
2108 result.push_back(getInput());
2109 return success();
2110 }
2111
2112 return failure();
2113}
2114
2115/// Fold transpose with transpose.
2116struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2117 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2118
2119 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2120 PatternRewriter &rewriter) const override {
2121 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2122 if (!defTransposeOp)
2123 return failure();
2124 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2125 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2126 SmallVector<int64_t> foldedPerms;
2127 foldedPerms.reserve(perms.size());
2128 for (int64_t perm : perms)
2129 foldedPerms.push_back(defPerms[perm]);
2130
2131 rewriter.replaceOpWithNewOp<TransposeOp>(
2132 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2133 foldedPerms);
2134 return success();
2135 }
2136};
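// Worked example for the composition above: with defPerms = [1, 0, 2] and
// perms = [2, 1, 0], the folded permutation is
// [defPerms[2], defPerms[1], defPerms[0]] = [2, 0, 1].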
2137
2138/// This pattern canonicalizes transpose by swapping the order of
2139/// broadcast and transpose:
2140///   transpose(broadcast(input)) -> broadcast(transpose(input))
2141struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2142 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2143
2144 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2145 PatternRewriter &rewriter) const override {
2146 Value input = transposeOp.getInput();
2147 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2148 if (!input.hasOneUse() || !broadcastOp)
2149 return failure();
2150
2151 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2152 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2153
2154 // Get new perms and new dimensions.
2155    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2156    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2157 SmallVector<int64_t> resultDimensions;
2158 unsigned dimensionSize = dimensions.size();
2159 for (unsigned i = 0; i < dimensionSize; ++i)
2160 resultDimensions.push_back(invertPerm[dimensions[i]]);
2161
2162 // Create transpose result.
2163 Value broadcastInput = broadcastOp.getInput();
2164 Location loc = transposeOp.getLoc();
2165    MLIRContext *ctx = transposeOp.getContext();
2166    SmallVector<OpFoldResult> dims;
2167 auto broadcastInputTy =
2168 mlir::cast<RankedTensorType>(broadcastInput.getType());
2169 unsigned inputRank = broadcastInputTy.getRank();
2170 for (unsigned i = 0; i < inputRank; ++i) {
2171 if (broadcastInputTy.isDynamicDim(i)) {
2172 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2173 ->getResult(0));
2174 } else {
2175 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2176 broadcastInputTy.getDimSize(i)));
2177 }
2178 }
2179 SmallVector<OpFoldResult> transposeResultShapes =
2180 applyPermutation(dims, resultPerms);
2181 Value transposeInit = tensor::EmptyOp::create(
2182 rewriter, transposeOp.getLoc(), transposeResultShapes,
2183 broadcastInputTy.getElementType());
2184
2185 // Create broadcast(transpose(input)).
2186 Value transposeResult =
2187 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2188 transposeInit, resultPerms)
2189 ->getResult(0);
2190 rewriter.replaceOpWithNewOp<BroadcastOp>(
2191 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2192 return success();
2193 }
2194};
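// Worked example (shapes illustrative): broadcasting tensor<2x3xf32> with
// dimensions = [1] gives tensor<2x5x3xf32>; transposing that with
// permutation = [2, 0, 1] gives tensor<3x2x5xf32>. The swapped form first
// transposes the small input with resultPerms = [1, 0] (tensor<3x2xf32>) and
// then broadcasts with resultDimensions = [2], producing the same
// tensor<3x2x5xf32> while transposing less data.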
2195
2196void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2197 MLIRContext *context) {
2198 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2199}
2200
2201//===----------------------------------------------------------------------===//
2202// BroadcastOp
2203//===----------------------------------------------------------------------===//
2204
2205void BroadcastOp::build(::mlir::OpBuilder &builder,
2206 ::mlir::OperationState &result, Value input, Value init,
2207 DenseI64ArrayAttr dimensions,
2208 ArrayRef<NamedAttribute> attributes) {
2209 result.addOperands(input);
2210 result.addOperands(init);
2211 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2212 result.addAttributes(attributes);
2213
2214 // Add output types for `RankedTensorType` output arguments.
2215 Type initType = init.getType();
2216 if (llvm::isa<RankedTensorType>(initType))
2217 result.addTypes(initType);
2218
2219 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2220 init);
2221}
2222
2223void BroadcastOp::build(::mlir::OpBuilder &builder,
2224 ::mlir::OperationState &result, Value input, Value init,
2225 ArrayRef<int64_t> dimensions,
2226 ArrayRef<NamedAttribute> attributes) {
2227 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2228 attributes);
2229}
2230
2231ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2232  if (failed(parseDstStyleOp(
2233 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2234 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2235 })))
2236 return failure();
2237
2238 OpBuilder builder(parser.getContext());
2239 buildIdentityRegion(builder, result.location, *result.addRegion(),
2240 /*inputs=*/result.operands,
2241 /*outputs=*/{});
2242 return success();
2243}
2244
2245void BroadcastOp::getAsmResultNames(
2246 function_ref<void(Value, StringRef)> setNameFn) {
2247 if (!getResults().empty())
2248 setNameFn(getResults().front(), "broadcasted");
2249}
2250
2251void BroadcastOp::print(OpAsmPrinter &p) {
2252 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2253 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2254 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2255}
2256
2257LogicalResult BroadcastOp::verify() {
2258 ArrayRef<int64_t> dimensionsRef = getDimensions();
2259
2260 auto inputType = getInput().getType();
2261 auto initType = getInit().getType();
2262
2263 int64_t inputRank = inputType.getRank();
2264 int64_t initRank = initType.getRank();
2265
2266 auto inputShape = inputType.getShape();
2267 auto initShape = initType.getShape();
2268
2269 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2270 return emitOpError() << "input rank plus added dimensions does not "
2271 "match init rank. input rank: "
2272 << inputRank
2273 << ", dimensions size: " << dimensionsRef.size()
2274 << ", init rank: " << initRank;
2275
2276 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2277 if (dim < 0 || dim >= initRank)
2278 return emitOpError() << "dimension " << idx
2279 << " is out of range. expected range: [0, "
2280 << initRank - 1 << "], got: " << dim;
2281 }
2282
2283 // Mapping from input dims to init dims.
2284 SmallVector<int64_t> dimMap;
2285 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2286 if (!llvm::is_contained(dimensionsRef, dim))
2287 dimMap.push_back(dim);
2288 }
2289
2290 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2291 // This dimension is mapped from the input. Init and input dims should
2292 // match.
2293 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2294 return emitOpError() << "input dim " << inputDimIdx
2295 << " should match init dim " << initDimIdx
2296 << ". input: " << inputShape[inputDimIdx]
2297 << ", init: " << initShape[initDimIdx];
2298 }
2299
2300 return success();
2301}
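// Example of a configuration accepted by this verifier: input
// tensor<4x8xf32>, init tensor<4x16x8xf32>, dimensions = [1]. Init dim 1 is
// added by the broadcast; input dims 0 and 1 map to init dims 0 and 2 and
// match their sizes.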
2302
2303SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2304 int64_t rank = getInit().getType().getRank();
2305 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2306}
2307
2308ArrayAttr BroadcastOp::getIndexingMaps() {
2309 Builder builder(getContext());
2310 int64_t rank = getInit().getType().getRank();
2311 return builder.getAffineMapArrayAttr(
2312 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2313 builder.getMultiDimIdentityMap(rank)});
2314}
2315
2316void BroadcastOp::getEffects(
2317 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2318 &effects) {
2319 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2320}
2321
2322Speculation::Speculatability BroadcastOp::getSpeculatability() {
2323 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2324}
2325
2326/// Fold back-to-back broadcasts together.
2327struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2328 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2329
2330 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2331 PatternRewriter &rewriter) const override {
2332 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2333 if (!defBroadcastOp)
2334 return failure();
2335 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2336 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2337 SmallVector<int64_t> foldedDims(dimensions);
2338 Value init = broadcastOp.getInit();
2339 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2340 // Mapping from input dims to init dims.
2341 SmallVector<int64_t> dimMap;
2342 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2343 if (!llvm::is_contained(dimensions, dim))
2344 dimMap.push_back(dim);
2345 }
2346 for (auto dim : defDimensions)
2347 foldedDims.push_back(dimMap[dim]);
2348
2349 llvm::sort(foldedDims);
2350 rewriter.replaceOpWithNewOp<BroadcastOp>(
2351 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2352 return success();
2353 }
2354};
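// Worked example: broadcast(broadcast(x, dimensions = [0]), dimensions = [2])
// with a rank-3 result. The outer dimMap sends input dims {0, 1} to init dims
// {0, 1}, so the inner dimension 0 stays 0 and the folded, sorted dimensions
// become [0, 2].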
2355
2356void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2357 MLIRContext *context) {
2358 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2359}
2360
2361//===----------------------------------------------------------------------===//
2362// YieldOp
2363//===----------------------------------------------------------------------===//
2364
2365void linalg::YieldOp::print(OpAsmPrinter &p) {
2366 if (getNumOperands() > 0)
2367 p << ' ' << getOperands();
2368 p.printOptionalAttrDict((*this)->getAttrs());
2369 if (getNumOperands() > 0)
2370 p << " : " << getOperandTypes();
2371}
2372
2373ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2374 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2375 SmallVector<Type, 2> types;
2376 SMLoc loc = parser.getCurrentLocation();
2377 return failure(parser.parseOperandList(opInfo) ||
2378 parser.parseOptionalAttrDict(result.attributes) ||
2379 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2380 parser.resolveOperands(opInfo, types, loc, result.operands));
2381}
2382
2383// Check that the operand count and types match the element types of the
2384// LinalgOp interface's shaped operands.
2385static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2386 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2387 return op.emitOpError("expected number of yield values (")
2388 << op.getNumOperands()
2389 << ") to match the number of inits / outs operands of the enclosing "
2390 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2391
2392 for (OpOperand &opOperand : op->getOpOperands()) {
2393 OpOperand *outputOperand =
2394 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2395 Type elementType = outputOperand->get().getType();
2396 if (isa<MemRefType, RankedTensorType>(elementType))
2397 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2398 if (opOperand.get().getType() != elementType)
2399 return op.emitOpError("type of yield operand ")
2400 << (opOperand.getOperandNumber() + 1) << " ("
2401 << opOperand.get().getType() << ") doesn't match "
2402 << "the element type of the enclosing linalg.generic op ("
2403 << elementType << ")";
2404 }
2405 return success();
2406}
2407
2408LogicalResult linalg::YieldOp::verify() {
2409 auto *parentOp = (*this)->getParentOp();
2410 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2411 return emitOpError("expected single non-empty parent region");
2412
2413 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2414 return verifyYield(*this, linalgOp);
2415
2416 return emitOpError("expected parent op with LinalgOp interface");
2417}
2418
2419//===----------------------------------------------------------------------===//
2420// IndexOp
2421//===----------------------------------------------------------------------===//
2422
2423LogicalResult IndexOp::verify() {
2424 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2425 if (!linalgOp)
2426 return emitOpError("expected parent op with LinalgOp interface");
2427 if (linalgOp.getNumLoops() <= getDim())
2428 return emitOpError("expected dim (")
2429 << getDim() << ") to be lower than the number of loops ("
2430 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2431 return success();
2432}
2433
2434OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2435 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2436 // Bail out if `linalg.index` does not have a proper parent yet at this
2437 // point, e.g., when calling `createOrFold` during IR construction in
2438 // `GenericOp::build`.
2439 if (!linalgOp)
2440 return OpFoldResult{};
2441
2442 // Index of unit dims is always 0.
2443 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2444 uint64_t dim = getDim();
2445 assert(dim < loopBounds.size() && "Dim is out of bounds");
2446 if (loopBounds[dim] == 1)
2447 return IntegerAttr::get(IndexType::get(getContext()), 0);
2448
2449 return OpFoldResult{};
2450}
2451
2452/////// Operations corresponding to library calls defined with Tablegen ////////
2453
2454#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2455
2456#define GET_OP_CLASSES
2457#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2458
2459#define GET_OP_CLASSES
2460#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2461#define GET_OP_CLASSES
2462#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2463
2464AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2465 unsigned rank,
2466 MLIRContext *context) {
2467 if (maybeMap)
2468 return *maybeMap;
2469 if (rank == 0)
2470 return AffineMap::get(context);
2471 return AffineMap::getMultiDimIdentityMap(rank, context);
2472}
2473
2474SmallVector<AffineExpr, 4>
2475mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2476 MLIRContext *context) {
2477  SmallVector<AffineExpr, 4> res;
2478  res.reserve(num);
2479 for (unsigned i = 0; i < num; ++i)
2480 res.push_back(getAffineDimExpr(startIdx++, context));
2481 return res;
2482}
2483
2484SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2485                                                ArrayRef<AffineExpr> b) {
2486  auto rangeA = llvm::make_range(a.begin(), a.end());
2487 auto rangeB = llvm::make_range(b.begin(), b.end());
2488 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2489 return llvm::to_vector<4>(concatRanges);
2490}
2491
2492static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2493 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2494 ss << "view";
2495 for (auto size : memref.getShape())
2496 if (size < 0)
2497 ss << "sx";
2498 else
2499 ss << size << "x";
2500 if (failed(appendMangledType(ss, memref.getElementType())))
2501 return failure();
2502 if (auto as = memref.getMemorySpace()) {
2503 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2504 ss << "as" << attr.getInt();
2505 else
2506 return failure();
2507 }
2508 return success();
2509 }
2510 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2511 ss << "vector";
2512 llvm::interleave(
2513 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2514 if (failed(appendMangledType(ss, vec.getElementType())))
2515 return failure();
2516 return success();
2517 }
2518  if (t.isSignlessIntOrIndexOrFloat()) {
2519    ss << t;
2520 return success();
2521 }
2522 return failure();
2523}
2524
2525std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2526  assert(isa<LinalgOp>(op));
2527 std::string name(op->getName().getStringRef().str());
2528 std::string fun = "";
2529 for (NamedAttribute kv : op->getAttrs()) {
2530 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2531 fun = stringifyEnum(ufa.getValue()).str() + "_";
2532 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2533 fun = stringifyEnum(bfa.getValue()).str() + "_";
2534 }
2535 }
2536 name.reserve(128);
2537 llvm::replace(name, '.', '_');
2538 llvm::raw_string_ostream ss(name);
2539 ss << "_" << fun;
2540 for (Type t : op->getOperandTypes()) {
2541 if (failed(appendMangledType(ss, t)))
2542 return std::string();
2543 ss << "_";
2544 }
2545 name.pop_back();
2546 return name;
2547}
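// E.g., a linalg.matmul on three memref<?x?xf32> operands mangles to
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32": dots become
// underscores and each operand contributes "view", "sx" per dynamic
// dimension, and the element type.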
2548
2549//===----------------------------------------------------------------------===//
2550// Canonicalizers and Folders.
2551//===----------------------------------------------------------------------===//
2552
2553namespace {
2554struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2555  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2556
2557 LogicalResult matchAndRewrite(LinalgOp op,
2558 PatternRewriter &rewriter) const override {
2559 for (OpOperand &opOperand : op->getOpOperands()) {
2560 // Linalg "inputs" may be either tensor or memref type.
2561 // tensor<0xelt_type> is a convention that may not always mean
2562 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2563 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2564 if (!mt)
2565 continue;
2566 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2567 rewriter.eraseOp(op);
2568 return success();
2569 }
2570 }
2571 return failure();
2572 }
2573};
2574
2575/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
2576/// result that is more static than the linalg op.
2577struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2578 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2579
2580 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2581 PatternRewriter &rewriter) const override {
2582 if (!tensor::canFoldIntoProducerOp(castOp))
2583 return failure();
2584
2585 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2586 if (!linalgOp)
2587 return failure();
2588
2589 // The cast can be in a conditionally reachable region, in which case
2590 // folding will generate invalid code. Only conservatively fold ops in the
2591 // same block for now.
2592 if (castOp->getBlock() != linalgOp->getBlock())
2593 return failure();
2594
2595 OpBuilder::InsertionGuard guard(rewriter);
2596 rewriter.setInsertionPoint(linalgOp);
2597
2598 Location loc = linalgOp.getLoc();
2599 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2600 unsigned resultNumber = resultValue.getResultNumber();
2601 auto resultType =
2602 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2603 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2604 // going from a more dynamic shape to a less dynamic shape. If the producer
2605 // for this cast, i.e. producer of the out operand, is also an operation
2606 // that folds with tensor.cast consumer (like this pattern), the cast will
2607 // continue to propagate as far up the stack as it can go.
2608 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2609 Value newOperand =
2610 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2611 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2612 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2613 linalgOp.getDpsInits().end());
2614 outputOperands[resultNumber] = newOperand;
2615 newOperands.append(outputOperands.begin(), outputOperands.end());
2616
2617 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2618 linalgOp->result_type_end());
2619 resultTypes[resultNumber] = resultType;
2620 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2621
2622 // Create a tensor.cast operation back to the original type.
2623 Value castBack = tensor::CastOp::create(
2624 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2625
2626 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2627 results[resultNumber] = castBack;
2628 rewriter.replaceOp(linalgOp, results);
2629 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2630 return success();
2631 }
2632};
2633
2634/// For each operand in `operands`, this function maps the static sizes of
2635/// dimensions to their affine dim expressions.
2636static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2637 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2638 for (OpOperand &opOperand : operands) {
2639 if (linalgOp.isScalar(&opOperand))
2640 continue;
2641 Value src = opOperand.get();
2642 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2643 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2644
2645 // Get the `sourceShape` of the `sourceType`. If the operand is a result
2646 // of a `tensor.cast` operation and the source of the cast operation has a
2647 // static shape, then assign it to the `sourceShape`.
2648 auto *parentOp = src.getDefiningOp();
2649 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2650 if (parentOp) {
2651 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2652 Value castSource = castOp.getSource();
2653 auto castSourceType =
2654 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2655 if (castSourceType && castSourceType.hasStaticShape())
2656 sourceShape = castSourceType.getShape();
2657 }
2658 }
2659
2660 // If the source dimension has a static size, map the affine dim
2661 // expression to the known static size.
2662 for (unsigned i = 0; i < sourceShape.size(); i++) {
2663 if (sourceType.isDynamicDim(i))
2664 continue;
2665 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2666 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2667 }
2668 }
2669}
2670
2671/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2672/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2673/// their result types are stored in `resultTypes`. If `opOperand` requires no
2674/// change, then `changeNeeded` stays false and the same operand is added to
2675/// the `newOperands` list.
2676static void createNewOperandWithStaticSizes(
2677 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2678 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2679 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2680 bool &changeNeeded) {
2681 Value src = opOperand->get();
2682 newOperands.push_back(src);
2683 if (linalgOp.isScalar(opOperand))
2684 return;
2685 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2686 Type resultType = sourceType;
2687 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2688 resultTypes.push_back(resultType);
2689 return;
2690 }
2691 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2692 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2693 SmallVector<int64_t> newShape;
2694 // If the operand is updated with a new shape, `newOperandNeeded` will be
2695 // set to true.
2696 bool newOperandNeeded = false;
2697 for (unsigned i = 0; i < sourceShape.size(); i++) {
2698 int64_t dimShape = sourceShape[i];
2699 AffineExpr dimExpr = sourceMap.getResult(i);
2700 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2701 newShape.push_back(dimShape);
2702 continue;
2703 }
2704 // The dimension has a dynamic shape and the corresponding affine dim
2705 // expression is present in the map, so assign the size for the given
2706 // affine dim expression to the dimension.
2707 newShape.push_back(affineExprToSize[dimExpr]);
2708 newOperandNeeded = true;
2709 }
2710 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2711 sourceType.getEncoding());
2712 if (newOperandNeeded) {
2713 changeNeeded = true;
2714 // Get the new operand value given its size and element type by
2715 // casting it.
2716 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2717 unsigned index = opOperand->getOperandNumber();
2718 newOperands[index] = newOperand;
2719 }
2720 if (linalgOp.isDpsInit(opOperand))
2721 resultTypes.push_back(resultType);
2722}
2723
2724/// Static shapes for the operands can be inferred if any one of the operands
2725/// has a static shape. This can be done by referring to the affine dim
2726/// expressions for the operand.
2727struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2728 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2729
2730 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2731 PatternRewriter &rewriter) const override {
2732 if (!linalgOp.hasPureTensorSemantics())
2733 return failure();
2734
2735 // Maps must be projected permutations.
2736 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2737 return !map.isProjectedPermutation();
2738 }))
2739 return failure();
2740
2741 // Maps affine dim expressions to the static size of that dimension.
2742 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2743 Location loc = linalgOp.getLoc();
2744
2745 // For each affine dim expression, check if the size is known. If it is
2746 // known, add it to the map.
2747 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2748
2749 SmallVector<Value> newOperands;
2750 SmallVector<Type> resultTypes;
2751
2752 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2753 // change in their types.
2754 bool changeNeeded = false;
2755 newOperands.reserve(linalgOp->getNumOperands());
2756 resultTypes.reserve(linalgOp.getNumDpsInits());
2757
2758 // Iterate over all the operands and update the static sizes.
2759 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2760 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2761 affineExprToSize, linalgOp, newOperands,
2762 resultTypes, changeNeeded);
2763 }
2764
2765 // If the generic op has all the required static information, no
2766 // canonicalization needed.
2767 if (!changeNeeded)
2768 return failure();
2769
2770 // Clone op.
2771 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2772 SmallVector<Value> replacements;
2773 replacements.reserve(newOp->getNumResults());
2774 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2775 Value newResult = std::get<1>(it);
2776 Value oldResult = std::get<0>(it);
2777 Type newType = newResult.getType();
2778 Type oldType = oldResult.getType();
2779 replacements.push_back(
2780 (newType != oldType)
2781 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2782 : newResult);
2783 }
2784 rewriter.replaceOp(linalgOp, replacements);
2785 return success();
2786 }
2787};
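// Worked example for the pattern above: a linalg.generic with identity maps
// on operands tensor<8x?xf32> and tensor<?x16xf32> lets d0 = 8 and d1 = 16 be
// inferred, so both operands are cast to tensor<8x16xf32>, the op is cloned
// with static types, and a tensor.cast restores the original result type for
// the users.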
2788
2789} // namespace
2790
2791// All named ops canonicalizers and folders are auto-generated in the
2792// .cpp.inc.
2793
2794//===----------------------------------------------------------------------===//
2795// SoftmaxOp
2796//===----------------------------------------------------------------------===//
2797
2798LogicalResult SoftmaxOp::verify() {
2799 ShapedType inputType = getInputOperandType();
2800 ShapedType outputType = getOutputOperandType();
2801
2802 ArrayRef<int64_t> inputShape = inputType.getShape();
2803 ArrayRef<int64_t> outputShape = outputType.getShape();
2804 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2805 return emitOpError("incompatible output shape");
2806
2807 int64_t inputRank = getInputOperandRank();
2808 int64_t dimension = getDimension();
2809 if ((dimension < 0) || (dimension >= inputRank))
2810 return emitOpError("incorrect dimension specified");
2811
2812 return success();
2813}
2814
2815SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2816 int64_t operandRank = getInputOperandRank();
2817 SmallVector<Range> loopBounds(operandRank);
2818 Location loc = getLoc();
2819 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2820 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2821 Value source = getInput();
2822 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2823 loopBounds[dim].offset = zero;
2824 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2825 loopBounds[dim].stride = one;
2826 }
2827 return loopBounds;
2828}
2829
2830SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2831 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2832 utils::IteratorType::parallel);
2833 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2834 return iteratorTypes;
2835}
2836
2837FailureOr<TilingResult>
2838SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2839 ArrayRef<OpFoldResult> offsets,
2840 ArrayRef<OpFoldResult> sizes) {
2841 int64_t rank = getInputOperandRank();
2842 auto oneAttr = builder.getI64IntegerAttr(1);
2843 SmallVector<OpFoldResult> strides(rank, oneAttr);
2844 SmallVector<Value> tiledOperands;
2845 Operation *inputSlice =
2846 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2847 if (!inputSlice) {
2848 return emitOpError("failed to compute input slice");
2849 }
2850 tiledOperands.emplace_back(inputSlice->getResult(0));
2851 Operation *outputSlice =
2852 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2853 if (!outputSlice) {
2854 return emitOpError("failed to compute output slice");
2855 }
2856 tiledOperands.emplace_back(outputSlice->getResult(0));
2857
2858 SmallVector<Type, 4> resultTypes;
2859 if (hasPureTensorSemantics())
2860 resultTypes.push_back(tiledOperands[1].getType());
2861 Operation *tiledOp =
2862 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2863
2864 return TilingResult{
2865 {tiledOp},
2866 SmallVector<Value>(tiledOp->getResults()),
2867 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2868}
2869
2870LogicalResult SoftmaxOp::getResultTilePosition(
2871 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2872 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2873 SmallVector<OpFoldResult> &resultSizes) {
2874 if (resultNumber == 0) {
2875 resultOffsets.assign(offsets.begin(), offsets.end());
2876 resultSizes.assign(sizes.begin(), sizes.end());
2877 return success();
2878 }
2879 return failure();
2880}
2881
2882// cast(dynamic) -> static.
2883LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2884 return memref::foldMemRefCast(*this);
2885}
2886
2887LogicalResult
2888SoftmaxOp::reifyResultShapes(OpBuilder &b,
2889 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2890 SmallVector<OpFoldResult> shapes;
2891 Location loc = getOperation()->getLoc();
2892 IRRewriter rewriter(b);
2893 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2894 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2895 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2896 if (!outputShapedType.isDynamicDim(dim)) {
2897 // Static dim: Return IntegerAttr.
2898 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2899 } else {
2900 // Dynamic dim: Return Value.
2901 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2902 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2903 }
2904 }
2905 reifiedReturnShapes.emplace_back(std::move(shapes));
2906 return success();
2907}
2908
2909void SoftmaxOp::getEffects(
2910 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2911 &effects) {
2912 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2913 if (!llvm::isa<MemRefType>(operand.getType()))
2914 continue;
2915 effects.emplace_back(MemoryEffects::Read::get(),
2916 &getOperation()->getOpOperand(index), /*stage=*/0,
2917 /*effectOnFullRegion=*/true,
2918 SideEffects::DefaultResource::get());
2919 }
2920
2921 for (OpOperand &operand : getDpsInitsMutable()) {
2922 if (!llvm::isa<MemRefType>(operand.get().getType()))
2923 continue;
2924 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2925 /*effectOnFullRegion=*/true,
2926 SideEffects::DefaultResource::get());
2927 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2928 /*effectOnFullRegion=*/true,
2929 SideEffects::DefaultResource::get());
2930 }
2931}
2932
2933// Helper functions for softmax decomposition.
2934// @{
2935
2936// Helper function to produce the iterator types (reduction or parallel) and
2937// affine maps for the iterators used in the decomposition of softmax.
2938// This method creates:
2939// If allParallel == true:
2940// - iterator type: {parallel, ..., parallel}
2941// - affine maps:
2942// -- identity with inputRank dimensions.
2943// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2944// where N == inputRank.
2945//
2946// If allParallel == false:
2947// - iterator type at dim(i) == parallel for i != \p dim and
2948// dim(dim) == reduction.
2949// - affine map:
2950// -- identity with inputRank dimensions.
2951// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2952// where N == inputRank.
2953static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2954computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2955                                    int64_t dim, bool allParallel = false) {
2956 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2957 utils::IteratorType::parallel);
2958 if (!allParallel)
2959 iteratorTypes[dim] = utils::IteratorType::reduction;
2960 MLIRContext *ctxt = builder.getContext();
2961 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2962 SmallVector<AffineExpr, 2> affineExprs;
2963 for (int i = 0; i < inputRank; i++) {
2964 if (i != dim)
2965 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2966 }
2967 auto reductionMap =
2968 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2969 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2970 return std::make_tuple(iteratorTypes, indexingMaps);
2971}
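// E.g., for inputRank = 2 and dim = 1 this returns the iterator types
// {parallel, reduction} (or all-parallel when requested) together with the
// maps (d0, d1) -> (d0, d1) and (d0, d1) -> (d0).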
2972
2973// Helper function to produce a linalg.generic that computes a reduction on
2974// dimension \p dim with the operation type \p T.
2975template <typename T>
2976static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2977 int64_t dim) {
2978 auto inputType = cast<ShapedType>(input.getType());
2979 ArrayRef<int64_t> inputShape = inputType.getShape();
2980 int64_t inputRank = inputShape.size();
2981 auto [iteratorTypes, indexingMaps] =
2982 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2983 assert(indexingMaps.size() == 2 &&
2984 "We should have two maps: 1 for the input, 1 for the output");
2985 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2986
2987 auto genericOp = linalg::GenericOp::create(
2988 builder, loc, output.getType(), input, output, indexingMaps,
2989 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2990 Value result = T::create(b, loc, args[0], args[1]);
2991 linalg::YieldOp::create(b, loc, result);
2992 });
2993 return genericOp.getResult(0);
2994}
2995
2996/// Produce a linalg generic that computes the second step of the softmax
2997/// decomposition: res = exp(input - max), where \p max is the max of \p input
2998/// on dimension \p dim.
2999static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
3000 Value max, Value output, int64_t dim) {
3001 auto inputType = cast<ShapedType>(input.getType());
3002 ArrayRef<int64_t> inputShape = inputType.getShape();
3003 int64_t inputRank = inputShape.size();
3004 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3005 builder, inputRank, dim, /*allParallel=*/true);
3006 assert(indexingMaps.size() == 2 && "We should have one map for each input");
3007 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3008 // Add the affine map for the output argument.
3009 indexingMaps.push_back(indexingMaps[0]);
3010 auto genericOp = linalg::GenericOp::create(
3011 builder, loc, input.getType(), ValueRange{input, max}, output,
3012 indexingMaps, iteratorTypes,
3013 [&](OpBuilder &b, Location loc, ValueRange args) {
3014 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3015 Value result = math::ExpOp::create(b, loc, diff);
3016 linalg::YieldOp::create(b, loc, result);
3017 });
3018 return genericOp.getResult(0);
3019}
3020
3021/// Produce a linalg generic that computes the final step of the softmax
3022/// decomposition.
3023/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3024/// yield n / d
3025/// }
3026static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3027 Value denominator, Value output, int64_t dim) {
3028 auto inputType = cast<ShapedType>(numerator.getType());
3029 ArrayRef<int64_t> inputShape = inputType.getShape();
3030 int64_t inputRank = inputShape.size();
3031 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3032 builder, inputRank, dim, /*allParallel=*/true);
3033 assert(indexingMaps.size() == 2 &&
3034 "We should have one map for each input (2)");
3035 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3036 // Add the affine map for the output tensor.
3037 indexingMaps.push_back(indexingMaps[0]);
3038 auto genericOp = linalg::GenericOp::create(
3039 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3040 output, indexingMaps, iteratorTypes,
3041 [&](OpBuilder &b, Location loc, ValueRange args) {
3042 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3043 linalg::YieldOp::create(b, loc, result);
3044 });
3045 return genericOp.getResult(0);
3046}
3047// @} End helper functions for softmax decomposition.
3048
3049/// Given an N-dimensional tensor x, this method converts
3050/// softmax(x) to the following sequence of operations:
3051///
3052/// 1. Compute the max of x along dimension d. This results
3053/// in a N-1 dimensional tensor m.
3054/// m = max(x, dim = d)
3055///
3056/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3057/// a N dimensional tensor z.
3058/// z = exp(x - m)
3059///
3060/// 3. Compute the sum of z along dimension d. This results in
3061/// a N-1 dimensional tensor l.
3062/// l = sum(z, dim = d)
3063///
3064/// 4. Divide z and l. This gives the N-dimensional softmax.
3065/// softmax = z / l
3066///
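/// For example (illustrative), decomposing softmax on dimension 1 of a
/// tensor<2x16x32xf32> yields, in order:
///   %m   = <maxnumf-reduce of %input over dim 1>  : tensor<2x32xf32>
///   %z   = <exp(%input - broadcast %m)>           : tensor<2x16x32xf32>
///   %l   = <addf-reduce of %z over dim 1>         : tensor<2x32xf32>
///   %res = <%z / broadcast %l>                    : tensor<2x16x32xf32>
/// where steps 1 and 3 are a linalg.fill followed by a reducing
/// linalg.generic, and steps 2 and 4 are all-parallel linalg.generic ops.
///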
3067FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3068 OpBuilder::InsertionGuard guard(b);
3069 b.setInsertionPoint(*this);
3070 Location loc = getLoc();
3071 Value input = getInput();
3072 ShapedType inputType = getInputOperandType();
3073 Type elementType = inputType.getElementType();
3074 int64_t reductionDim = getDimension();
3075 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3076 Value output = getOutput();
3077 dims.erase(dims.begin() + reductionDim);
3078 // Step 1: Compute max along dim.
3079 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3080 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3081 elementType, b, loc,
3082 /*useOnlyFiniteValue=*/true);
3083 Value neutralForMaxFInit =
3084 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3085 .result();
3086 Value max =
3087 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3088
3089 // Step 2: Subtract max from input and exponentiate.
3090 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3091
3092 // Step 3: Compute sum along dim.
3093 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3094 b, loc, /*useOnlyFiniteValue=*/true);
3095 Value zeroInit =
3096 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3097 Value denominator =
3098 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3099
3100 // Step 4: Compute softmax.
3101 Value result =
3102 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3103 return SmallVector<Value>{result};
3104}
3105
3106//===----------------------------------------------------------------------===//
3107// WinogradFilterTransformOp
3108//===----------------------------------------------------------------------===//
3109
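/// Verify shapes, e.g. (illustrative): with fmr = F_4_3 (m = 4, r = 3) and a
/// filter of type tensor<64x3x3x32xf32> laid out as (F, KH, KW, C), we have
/// alpha = m + r - 1 = 6 and the expected output type is
/// tensor<6x6x32x64xf32>, laid out as (alphaH, alphaW, C, F).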
3110LogicalResult WinogradFilterTransformOp::verify() {
3111 auto filterType = cast<ShapedType>(getFilter().getType());
3112 ArrayRef<int64_t> filterShape = filterType.getShape();
3113 int64_t filterH = filterShape[getFilterHDim()];
3114 int64_t filterW = filterShape[getFilterWDim()];
3115 WinogradConv2DFmr fmr = getFmr();
3116 int64_t m, r;
3117 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3118
3119 if (filterH != r && filterH != 1)
3120 return emitOpError("expected filter height to be either r or 1");
3121 if (filterW != r && filterW != 1)
3122 return emitOpError("expected filter width to be either r or 1");
3123 if (filterH == 1 && filterW == 1)
3124 return emitOpError("expected either filter height or width to equal r");
3125
3126 SmallVector<int64_t> expectedOutputShape;
3127 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3128 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3130 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3131
3132 auto outputType = cast<ShapedType>(getOutput().getType());
3133 ArrayRef<int64_t> outputShape = outputType.getShape();
3134 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3135 return emitOpError("unexpected output shape");
3136 }
3137 return success();
3138}
3139
3140SmallVector<Range>
3141WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3142 Location loc = getLoc();
3143 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3144 IntegerAttr oneAttr = builder.getIndexAttr(1);
3145 Value filter = getFilter();
3146 int64_t filterRank = getFilterOperandRank();
3147 SmallVector<Range> loopBounds(filterRank);
3148 for (unsigned dim = 0; dim < filterRank; ++dim) {
3149 loopBounds[dim].offset = zeroAttr;
3150 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3151 loopBounds[dim].stride = oneAttr;
3152 }
3153 return loopBounds;
3154}
3155
3156SmallVector<utils::IteratorType>
3157WinogradFilterTransformOp::getLoopIteratorTypes() {
3158 int64_t filterRank = getFilterOperandRank();
3159 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3160 utils::IteratorType::parallel);
3161 return iteratorTypes;
3162}
3163
3164LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3165 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3166 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3167 SmallVector<OpFoldResult> &resultSizes) {
3168 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3169 ShapedType filterType = getFilterOperandType();
3170 ArrayRef<int64_t> filterShape = filterType.getShape();
3171 int64_t filterH = filterShape[getFilterHDim()];
3172 int64_t filterW = filterShape[getFilterWDim()];
3173 WinogradConv2DFmr fmr = getFmr();
3174 int64_t m, r;
3175 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3176 int64_t alpha = m + r - 1;
3177 int64_t alphaH = filterH != 1 ? alpha : 1;
3178 int64_t alphaW = filterW != 1 ? alpha : 1;
3179 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3180 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3181
3182 resultOffsets.append(
3183 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3184 resultSizes.append(
3185 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3186
3187 return success();
3188}
3189
3190/// Implement tiling for winograd_filter_transform.
3191/// The input of winograd_filter_transform is (F, KH, KW, C).
3192/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3193/// Users can specify the tile sizes of F and C.
3194/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3195/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
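///
/// For example (illustrative): with fmr = F_4_3, tiling F by 16 and C by 8
/// extracts a tensor<16x3x3x8xf32> slice of the filter and computes the
/// matching tensor<6x6x8x16xf32> slice of the output.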
3196FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3197 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3198 ArrayRef<OpFoldResult> sizes) {
3199 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3200 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3201 ShapedType filterType = getFilterOperandType();
3202 ArrayRef<int64_t> filterShape = filterType.getShape();
3203 int64_t filterH = filterShape[getFilterHDim()];
3204 int64_t filterW = filterShape[getFilterWDim()];
3205 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3206 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3207 SmallVector<Value> tiledOperands;
3208 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3209
3210 sliceOffsets.append(
3211 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3212 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3213 sizes[getFilterCDim()]});
3214 int64_t filterRank = getFilterOperandRank();
3215 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3216 Location loc = getLoc();
3217 auto filterSlice = tensor::ExtractSliceOp::create(
3218 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3219 tiledOperands.emplace_back(filterSlice);
3220
3221 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3222 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3223 resultSizes)))
3224 return failure();
3225
3226 int64_t outputRank = getOutputOperandRank();
3227 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3228 auto outputSlice = tensor::ExtractSliceOp::create(
3229 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3230 tiledOperands.emplace_back(outputSlice);
3231
3232 SmallVector<Type> resultTypes;
3233 resultTypes.push_back(tiledOperands[1].getType());
3234 Operation *tiledOp =
3235 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3236
3237 return TilingResult{
3238 {tiledOp},
3239 SmallVector<Value>(tiledOp->getResults()),
3240 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3241}
3242
3243//===----------------------------------------------------------------------===//
3244// WinogradInputTransformOp
3245//===----------------------------------------------------------------------===//
3246
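/// Verify shapes, e.g. (illustrative): with fmr = F_4_3 (m = 4, r = 3) and an
/// input of type tensor<2x14x14x32xf32> laid out as (N, H, W, C), each
/// spatial dimension yields (14 - (r - 1)) / m = 3 tiles of size
/// alpha = m + r - 1 = 6, so the expected output type is
/// tensor<6x6x3x3x2x32xf32>, laid out as (alphaH, alphaW, tileH, tileW, N, C).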
3247LogicalResult WinogradInputTransformOp::verify() {
3248 auto inputType = cast<ShapedType>(getInput().getType());
3249 ArrayRef<int64_t> inputShape = inputType.getShape();
3250 int64_t inputH = inputShape[getInputHDim()];
3251 int64_t inputW = inputShape[getInputWDim()];
3252 WinogradConv2DFmr fmr = getFmr();
3253 int64_t m, r;
3254 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3255 int64_t tileSize = m + r - 1;
3256
3257 auto outputType = cast<ShapedType>(getOutput().getType());
3258 ArrayRef<int64_t> outputShape = outputType.getShape();
3259 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3260 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3261
3262 SmallVector<int64_t> expectedOutputShape(6, inputH);
3263 if (ShapedType::isDynamic(inputH)) {
3264 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3265 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3266 } else {
3267 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3268 expectedOutputShape[getOutputTileHDim()] =
3269 leftTransform ? (inputH - (r - 1)) / m : inputH;
3270 }
3271 if (ShapedType::isDynamic(inputW)) {
3272 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3273 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3274 } else {
3275 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3276 expectedOutputShape[getOutputTileWDim()] =
3277 rightTransform ? (inputW - (r - 1)) / m : inputW;
3278 }
3279 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3280 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3281
3282 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3283 return emitOpError("unexpected output shape");
3284 }
3285 return success();
3286}
3287
3288SmallVector<Range>
3289WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3290 Location loc = getLoc();
3291 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3292 IntegerAttr oneAttr = builder.getIndexAttr(1);
3293 Value output = getOutput();
3294 int64_t outputRank = getOutputOperandRank();
3295 SmallVector<Range> loopBounds(outputRank);
3296 for (unsigned dim = 0; dim < outputRank; ++dim) {
3297 loopBounds[dim].offset = zeroAttr;
3298 // alphaH, alphaW, tileH, tileW, N, C
3299 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3300 loopBounds[dim].stride = oneAttr;
3301 }
3302 return loopBounds;
3303}
3304
3305SmallVector<utils::IteratorType>
3306WinogradInputTransformOp::getLoopIteratorTypes() {
3307 int64_t outputRank = getOutputOperandRank();
3308 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3309 utils::IteratorType::parallel);
3310 return iteratorTypes;
3311}
3312
3313LogicalResult WinogradInputTransformOp::getResultTilePosition(
3314 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3315 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3316 SmallVector<OpFoldResult> &resultSizes) {
3317 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3318 ShapedType outputType = getOutputOperandType();
3319 ArrayRef<int64_t> outputShape = outputType.getShape();
3320 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3321 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3322
3323 WinogradConv2DFmr fmr = getFmr();
3324 int64_t m, r;
3325 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3326 int64_t alpha = m + r - 1;
3327 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3328 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3329
3330 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3331 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3332
3333 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3334 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3335 offsets[getOutputCDim()]});
3336 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3337 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3338 sizes[getOutputCDim()]});
3339
3340 return success();
3341}
3342
3343/// Implement tiling for winograd_input_transform.
3344/// The input of winograd_input_transform is (N, H, W, C).
3345/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3346/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3347/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3348/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
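///
/// The input region needed for one tile is derived with affine maps
/// (illustrative): along a transformed dimension, a tile offset `t` maps to
/// the input offset `t * m` via affine_map<(d0) -> (d0 * m)>, and a tile
/// count `s` maps to the input size `s * m + (r - 1)` via
/// affine_map<(d0) -> (d0 * m + r - 1)>; dimensions whose alpha extent is 1
/// keep the identity offset and unit size.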
3349FailureOr<TilingResult>
3350WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3351 ArrayRef<OpFoldResult> offsets,
3352 ArrayRef<OpFoldResult> sizes) {
3353 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3354 WinogradConv2DFmr fmr = getFmr();
3355 int64_t m, r;
3356 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3357
3358 ShapedType outputType = getOutputOperandType();
3359 ArrayRef<int64_t> outputShape = outputType.getShape();
3360 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3361 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3362
3363 Location loc = getLoc();
3364 MLIRContext *context = builder.getContext();
3365 auto identityAffineMap =
3366 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3367 auto offsetAffineMap =
3368 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3369 Value mappedOffsetH = affine::makeComposedAffineApply(
3370 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3371 offsets[getOutputTileHDim()]);
3372 Value mappedOffsetW = affine::makeComposedAffineApply(
3373 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3374 offsets[getOutputTileWDim()]);
3375 auto sizeAffineMap = AffineMap::get(
3376 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3377 Value mappedSizeH = affine::makeComposedAffineApply(
3378 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3379 Value mappedSizeW = affine::makeComposedAffineApply(
3380 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3381
3382 SmallVector<Value> tiledOperands;
3383 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3384
3385 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3386 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3387 sliceOffsets.append(
3388 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3389 OpFoldResult sizeH =
3390 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3391 OpFoldResult sizeW =
3392 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3393 sliceSizes.append(
3394 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3395 int64_t inputRank = getInputOperandRank();
3396 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3397 auto inputSlice = tensor::ExtractSliceOp::create(
3398 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3399 tiledOperands.emplace_back(inputSlice);
3400
3401 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3402 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3403 resultSizes)))
3404 return failure();
3405
3406 int64_t outputRank = getOutputOperandRank();
3407 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3408 auto outputSlice = tensor::ExtractSliceOp::create(
3409 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3410 tiledOperands.emplace_back(outputSlice);
3411
3412 SmallVector<Type> resultTypes;
3413 resultTypes.push_back(tiledOperands[1].getType());
3414 Operation *tiledOp =
3415 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3416
3417 return TilingResult{
3418 {tiledOp},
3419 SmallVector<Value>(tiledOp->getResults()),
3420 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3421}
3422
3423//===----------------------------------------------------------------------===//
3424// WinogradOutputTransformOp
3425//===----------------------------------------------------------------------===//
3426
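/// Verify shapes, e.g. (illustrative): with fmr = F_4_3 (m = 4, r = 3) and a
/// value of type tensor<6x6x3x3x2x64xf32> laid out as (alphaH, alphaW, tileH,
/// tileW, N, F), each spatial dimension reassembles to m * tile = 4 * 3 = 12,
/// so the expected output type is tensor<2x12x12x64xf32>, laid out as
/// (N, H, W, F).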
3427LogicalResult WinogradOutputTransformOp::verify() {
3428 auto valueType = cast<ShapedType>(getValue().getType());
3429 ArrayRef<int64_t> valueShape = valueType.getShape();
3430 int64_t valueH = valueShape[getValueAlphaHDim()];
3431 int64_t valueW = valueShape[getValueAlphaWDim()];
3432 int64_t valueTileH = valueShape[getValueTileHDim()];
3433 int64_t valueTileW = valueShape[getValueTileWDim()];
3434 WinogradConv2DFmr fmr = getFmr();
3435 int64_t m, r;
3436 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3437 bool leftTransform = valueH != 1;
3438 bool rightTransform = valueW != 1;
3439
3440 int64_t outputRank = getOutputOperandRank();
3441 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3442 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3443 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3444 } else {
3445 if (valueH != (leftTransform ? m + r - 1 : 1))
3446 return emitOpError("expected input height to equal the input tile size");
3447 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3448 }
3449 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3450 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3451 } else {
3452 if (valueW != (rightTransform ? m + r - 1 : 1))
3453 return emitOpError("expected input width to equal the input tile size");
3454 expectedOutputShape[getOutputWDim()] =
3455 (rightTransform ? m : 1) * valueTileW;
3456 }
3457 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3458 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3459
3460 auto outputType = cast<ShapedType>(getOutput().getType());
3461 ArrayRef<int64_t> outputShape = outputType.getShape();
3462 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3463 return emitOpError("unexpected output shape");
3464 }
3465 return success();
3466}
3467
3468SmallVector<Range>
3469WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3470 Location loc = getLoc();
3471 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3472 IntegerAttr oneAttr = builder.getIndexAttr(1);
3473 Value value = getValue();
3474 int64_t valueRank = getValueOperandRank();
3475 SmallVector<Range> loopBounds(valueRank);
3476 for (unsigned dim = 0; dim < valueRank; ++dim) {
3477 loopBounds[dim].offset = zeroAttr;
3478 // alphaH, alphaW, tileH, tileW, N, F
3479 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3480 loopBounds[dim].stride = oneAttr;
3481 }
3482 return loopBounds;
3483}
3484
3485SmallVector<utils::IteratorType>
3486WinogradOutputTransformOp::getLoopIteratorTypes() {
3487 int64_t valueRank = getValueOperandRank();
3488 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3489 utils::IteratorType::parallel);
3490 return iteratorTypes;
3491}
3492
3493LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3494 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3495 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3496 SmallVector<OpFoldResult> &resultSizes) {
3497 WinogradConv2DFmr fmr = getFmr();
3498 int64_t m, r;
3499 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3500
3501 Location loc = getLoc();
3502 MLIRContext *context = builder.getContext();
3503 auto identityAffineMap =
3504 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3505 auto affineMap =
3506 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3507
3508 ShapedType valueType = getValueOperandType();
3509 ArrayRef<int64_t> valueShape = valueType.getShape();
3510 int64_t valueH = valueShape[0];
3511 int64_t valueW = valueShape[1];
3512 Value mappedOffsetH = affine::makeComposedAffineApply(
3513 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3514 offsets[getValueTileHDim()]);
3515 Value mappedOffsetW = affine::makeComposedAffineApply(
3516 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3517 offsets[getValueTileWDim()]);
3518 Value mappedSizeH = affine::makeComposedAffineApply(
3519 builder, loc, affineMap, sizes[getValueTileHDim()]);
3520 Value mappedSizeW = affine::makeComposedAffineApply(
3521 builder, loc, affineMap, sizes[getValueTileWDim()]);
3522
3523 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3524 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3525 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3526 OpFoldResult sizeH =
3527 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3528 OpFoldResult sizeW =
3529 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3530
3531 resultOffsets.append(
3532 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3533 resultSizes.append(
3534 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3535 return success();
3536}
3537
3538/// Implement tiling for winograd_output_transform.
3539/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3540/// N, F). The output of winograd_output_transform is (N, H, W, F).
3541/// Users can specify the tile sizes of tileH, tileW, N, and F. `offsets` are
3542/// the values for the offsets of tileH, tileW, N, F for one tile. `sizes` are
3543/// the values for the sizes of tileH, tileW, N, F for one tile.
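///
/// As in the input transform, result positions are derived with affine maps
/// (illustrative): a tile offset `t` maps to the output offset `t * m` and a
/// tile count `s` to the output size `s * m`, with identity offsets and unit
/// sizes for dimensions whose alpha extent is 1.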
3544FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3545 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3546 ArrayRef<OpFoldResult> sizes) {
3547 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3548 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3549 Location loc = getLoc();
3550 SmallVector<Value> tiledOperands;
3551 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3552
3553 ShapedType valueType = getValueOperandType();
3554 ArrayRef<int64_t> valueShape = valueType.getShape();
3555 int64_t alphaH = valueShape[getValueAlphaHDim()];
3556 int64_t alphaW = valueShape[getValueAlphaWDim()];
3557 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3558 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3559
3560 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3561 offsets[getValueTileWDim()], offsets[getValueNDim()],
3562 offsets[getValueFDim()]});
3563 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3564 sizes[getValueTileWDim()], sizes[getValueNDim()],
3565 sizes[getValueFDim()]});
3566 int64_t valueRank = getValueOperandRank();
3567 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3568 auto valueSlice = tensor::ExtractSliceOp::create(
3569 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3570 tiledOperands.emplace_back(valueSlice);
3571
3572 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3573 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3574 resultSizes)))
3575 return failure();
3576
3577 int64_t outputRank = getOutputOperandRank();
3578 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3579 auto outputSlice = tensor::ExtractSliceOp::create(
3580 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3581 tiledOperands.emplace_back(outputSlice);
3582
3583 SmallVector<Type> resultTypes;
3584 resultTypes.push_back(tiledOperands[1].getType());
3585 Operation *tiledOp =
3586 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3587
3588 return TilingResult{
3589 {tiledOp},
3590 SmallVector<Value>(tiledOp->getResults()),
3591 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3592}
3593
3594//===----------------------------------------------------------------------===//
3595// LinalgDialect
3596// TODO: Merge with the LinalgDialect block at the bottom
3597//===----------------------------------------------------------------------===//
3598
3599// Returns true if the result expressions of `subMap` are a subset of `fullMap`.
3600static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3601 auto explicitRange = subMap.getResults();
3602 auto defaultRange = fullMap.getResults();
3603 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3604 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3605 llvm::set_union(explicitSet, defaultSet);
3606 return explicitSet == defaultSet;
3607}
3608
3609/// Check if the user-defined map is a valid broadcast map. Here, broadcast
3610/// indexing maps are defined in the context of the corresponding default
3611/// indexing maps for the given op. This keeps the check very simple: just
3612/// check the number of result dims.
3613/// Returns true if the explicitMap is broadcasted with respect to the
3614/// defaultMap.
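///
/// For example (matmul LHS, illustrative): against the default map
/// (d0, d1, d2) -> (d0, d2), the user-specified map (d0, d1, d2) -> (d2)
/// has fewer results and therefore counts as broadcasted.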
3615static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3616 return explicitMap.getNumResults() < defaultMap.getNumResults();
3617}
3618
3619/// Verifies the broadcast and transpose semantics specified by the explicit
3620/// indexing map for the MatmulOp \p op for each operand specified by \p
3621/// opIndex.
3622static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3623 unsigned opIndex) {
3624 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3625 SmallVector<AffineMap, 3> defaultIndexingMaps =
3626 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3627
3628 auto opIndexingMap = opIndexingMaps[opIndex];
3629 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3630 // Check general validity of indexing map results.
3631 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3632 return matmulOp->emitOpError()
3633 << "Unexpected dim expression in map result.";
3634
3635 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3636 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3637 return matmulOp->emitOpError()
3638 << "Invalid broadcast requested, should be (d2).";
3639 }
3640 return success();
3641 }
3642 return success();
3643}
3644
3645// Check general validity of input indexing map of
3646// BatchMatmulOp/BatchReduceMatmulOp.
3647template <typename OpTy>
3648static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3649 AffineMap opIndexingMap,
3650 AffineMap defaultIndexingMap, bool isLHS) {
3651 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3652 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3653 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3654 // Check the result dims are valid.
3655 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3656 return batchVariantMatmulOp->emitOpError()
3657 << "Unexpected result dim expression (outside the set of default "
3658 "result dims).";
3659
3660 // Check for valid number of result dims of input maps.
3661 if (opIndexingMap.getNumResults() > 3)
3662 return batchVariantMatmulOp->emitOpError()
3663 << "no. of result dim expressions exceeds 3.";
3664
3665 auto hasValidBatchDim = [](AffineMap map) {
3666 AffineExpr batchDim = map.getResult(0);
3667 return batchDim.isFunctionOfDim(0);
3668 };
3669
3670 // Check if the requested broadcast is valid.
3671 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3672 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3673 return batchVariantMatmulOp->emitOpError()
3674 << "Invalid broadcast requested.";
3675 } else if (!hasValidBatchDim(opIndexingMap)) {
3676 return batchVariantMatmulOp->emitOpError()
3677 << "Invalid batch dimension expression.";
3678 }
3679 return success();
3680}
3681
3682/// This function checks if the given AffineMap for the output of a
3683/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3684/// dimensions and if the output map result dimensions are valid.
3685template <typename OpTy>
3686static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3687 AffineMap opIndexingMap) {
3688 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3689 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3690 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3691 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3692 opIndexingMap.getNumResults() != 3) {
3693
3694 return batchVariantMatmulOp->emitOpError()
3695 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3696 << ").";
3697 }
3698 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3699 opIndexingMap.getNumResults() != 2) {
3700 return batchVariantMatmulOp->emitOpError()
3701 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3702 << ").";
3703 }
3704
3705 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3706 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3707 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3708 outputMap.getResult(1).isFunctionOfDim(1) &&
3709 outputMap.getResult(2).isFunctionOfDim(2)
3710 : outputMap.getResult(0).isFunctionOfDim(1) &&
3711 outputMap.getResult(1).isFunctionOfDim(2);
3712 };
3713
3714 if (!areValidOutputResultDim(opIndexingMap)) {
3715 return batchVariantMatmulOp->emitOpError()
3716 << "Invalid output map result dimension.";
3717 }
3718
3719 return success();
3720}
3721
3722/// Verifies the broadcast and transpose semantic specified by the explicit
3723/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3724/// specified by opIndex.
3725template <typename OpTy>
3726static LogicalResult
3727verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3728 unsigned opIndex) {
3729 SmallVector<AffineMap, 3> opIndexingMaps =
3730 batchVariantMatmulOp.getIndexingMapsArray();
3731 SmallVector<AffineMap, 3> defaultIndexingMaps =
3732 batchVariantMatmulOp.getDefaultIndexingMaps(
3733 batchVariantMatmulOp->getContext());
3734
3735 if (opIndexingMaps.size() != 3)
3736 return batchVariantMatmulOp->emitOpError()
3737 << "indexing_maps attribute must have 3 affine maps.";
3738
3739 auto opIndexingMap = opIndexingMaps[opIndex];
3740 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3741
3742 if (opIndex == 2 &&
3743 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3744 return failure();
3745
3746 if (opIndex != 2 &&
3747 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3748 defaultIndexingMap, opIndex == 0)))
3749 return failure();
3750
3751 return success();
3752}
3753
3754namespace mlir {
3755namespace linalg {
3756
3757std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3758 if (m == 2 && r == 3)
3759 return WinogradConv2DFmr::F_2_3;
3760 if (m == 4 && r == 3)
3761 return WinogradConv2DFmr::F_4_3;
3762 if (m == 2 && r == 5)
3763 return WinogradConv2DFmr::F_2_5;
3764 return std::nullopt;
3765}
3766
3767std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3768 switch (fmr) {
3769 case WinogradConv2DFmr::F_2_3:
3770 return {2, 3};
3771 case WinogradConv2DFmr::F_4_3:
3772 return {4, 3};
3773 case WinogradConv2DFmr::F_2_5:
3774 return {2, 5};
3775 }
3776}
3777
3778//===----------------------------------------------------------------------===//
3780// MatmulOp
3780//===----------------------------------------------------------------------===//
3781
3782static FailureOr<SmallVector<SmallVector<int64_t>>>
3783getAffineResultPositions(ArrayAttr maps) {
3784 SmallVector<SmallVector<int64_t>> positions;
3785 for (auto map : maps) {
3786 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3787 if (!attr)
3788 return failure();
3789 SmallVector<int64_t> pos;
3790 for (auto result : attr.getAffineMap().getResults()) {
3791 auto dim = dyn_cast<AffineDimExpr>(result);
3792 if (!dim)
3793 return failure();
3794 pos.push_back(dim.getPosition());
3795 }
3796 positions.push_back(pos);
3797 }
3798 return positions;
3799}
3800
3801/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
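/// For plain matmul these are:
///   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS
///   affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS
///   affine_map<(d0, d1, d2) -> (d0, d1)>  // output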
3802SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3803 AffineExpr d0, d1, d2;
3804 SmallVector<AffineMap> indexingMaps;
3805 bindDims(context, d0, d1, d2);
3806 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3807 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3808 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3809 return indexingMaps;
3810}
3811
3812bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3813 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3814 if (!maps)
3815 return false;
3816 if (maps.size() != 3)
3817 return false;
3818 auto positions = getAffineResultPositions(maps);
3819 if (failed(positions))
3820 return false;
3821 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3822 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3823 (*positions)[2] == SmallVector<int64_t>{0, 1};
3824}
3825
3826SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3827 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3828 utils::IteratorType::parallel,
3829 utils::IteratorType::reduction};
3830}
3831
3832unsigned MatmulOp::getNumRegionArgs() { return 3; }
3833
3834std::string MatmulOp::getLibraryCallName() {
3835 return generateLibraryCallName(getOperation());
3836}
3837
3838bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3839
3840/// Check if the op has broadcast and/or transpose semantics. Returns true if
3841/// the user-defined indexing maps are not equal to the default maps.
3842bool MatmulOp::hasUserDefinedMaps() {
3843 SmallVector<AffineMap, 3> defaultMaps =
3844 getDefaultIndexingMaps(this->getContext());
3845 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3846 return defaultMaps != explicitMaps;
3847}
3848
3849/// Implements the block region builder for the MatmulOp. This is called by
3850/// 'fillStructuredOpRegion'.
3851void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3852 ArrayRef<NamedAttribute> attrs,
3853 function_ref<InFlightDiagnostic()> emitError) {
3854 if (emitError && block.getNumArguments() != 3) {
3855 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3856 << block.getNumArguments();
3857 return;
3858 }
3859 assert(block.getNumArguments() == 3 &&
3860 "MatmulOp regionBuilder expects 3 args");
3861 RegionBuilderHelper helper(b, block);
3862 SmallVector<Value> yields;
3863
3864 TypeFn castVal = TypeFn::cast_signed;
3865 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3866 return attr.getName() == "cast";
3867 });
3868 if (castIter != attrs.end()) {
3869 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3870 castVal = attr.getValue();
3871 }
3872
3873 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3874 block.getArgument(0));
3875 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3876 block.getArgument(1));
3877 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3878 if (!value3)
3879 return;
3880 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3881 value3, emitError);
3882 if (!value4)
3883 return;
3884 yields.push_back(value4);
3885 helper.yieldOutputs(yields);
3886}
3887
3888/// Returns true if the given bcastMap is a valid broadcast map. A valid
3889/// broadcast map must include the K dimension.
3890/// TODO: Strict inclusion of K dimension in the broadcast map is not
3891/// necessary for both input matrices simultaneously. We can relax this
3892/// condition to have K dimension for one input matrix map and infer the K
3893/// dimension for other input matrix map from the one already having K
3894/// dimension.
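///
/// For example, affine_map<(d0, d1, d2) -> (d2)> is a valid broadcast map
/// here, while affine_map<(d0, d1, d2) -> (d0)> is not, since it drops the
/// K dimension.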
3895bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3896 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3897 AffineExpr expr = bcastMap.getResult(0);
3898 // Invalid map if the common (K) dimension of the matmul is not found.
3899 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3900}
3901
3902static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3903 if (parser.parseOptionalKeyword("indexing_maps"))
3904 return ArrayAttr{
3905 nullptr}; // Success in case indexing_maps was not provided.
3906
3907 ArrayAttr arrayAttr;
3908 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3909 return failure();
3910
3911 if (llvm::any_of(arrayAttr,
3912 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3913 return parser.emitError(parser.getCurrentLocation())
3914 << "element of indexing_maps array is not an affine_map";
3915
3916 return arrayAttr;
3917}
3918
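// The custom assembly accepts an optional leading `indexing_maps` attribute;
// e.g., a transposed-A matmul can be written as (illustrative):
//
//   linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
//                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
//                                  affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%a, %b : tensor<4x8xf32>, tensor<4x16xf32>)
//     outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>
//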
3919ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3920 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3921 if (failed(indexingMapsAttr))
3922 return failure();
3923
3924 if (*indexingMapsAttr == nullptr) {
3925 auto indexingMapAttrs = llvm::map_to_vector(
3926 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3927 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3928 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3929 }
3930
3931 result.addAttribute("indexing_maps", *indexingMapsAttr);
3932 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3933 MatmulOp::getRegionBuilder());
3934}
3935
3936void MatmulOp::print(OpAsmPrinter &p) {
3937 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3938 MatmulOp::getDefaultIndexingMaps(getContext()),
3939 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3940 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3941 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3942
3943 std::array<StringRef, 3> elidedAttrs = {
3944 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3945 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3946 elidedAttrs);
3947}
3948
3949/// Verify the user defined indexing maps.
3950LogicalResult MatmulOp::verify() {
3951 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3952 if (!hasUserDefinedMaps())
3953 return success();
3954
3955 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3956 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3957 return failure();
3958 }
3959 return success();
3960}
3961
3962LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3963 return memref::foldMemRefCast(*this);
3964}
3965
3966void MatmulOp::getEffects(
3967 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3968 &effects) {
3969 if (hasPureTensorSemantics())
3970 return;
3971 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3972}
3973
3974Speculation::Speculatability MatmulOp::getSpeculatability() {
3975 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3976}
3977
3978SmallVector<AffineMap>
3979MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3980 AffineExpr d0, d1, d2;
3981 MLIRContext *context = builder.getContext();
3982 bindDims(context, d0, d1, d2);
3983 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3984 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3985 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3986 return {mapLHS, mapRHS, mapOut};
3987}
3988
3989bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
3990 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3991 if (!maps)
3992 return false;
3993 if (maps.size() != 3)
3994 return false;
3995 auto positions = getAffineResultPositions(maps);
3996 if (failed(positions))
3997 return false;
3998 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
3999 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
4000 (*positions)[2] == SmallVector<int64_t>{0, 1};
4001}
4002
4003void MatmulTransposeAOp::build(OpBuilder &builder,
4004 OperationState &result,
4005 ValueRange inputs, ValueRange outputs,
4006 ArrayRef<NamedAttribute> attributes) {
4007 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4008 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4009}
4010
4013 ValueRange inputs, ValueRange outputs,
4014 ArrayRef<NamedAttribute> attributes) {
4015 OperationState state(location, getOperationName());
4016 build(builder, state, inputs, outputs, attributes);
4017 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4018 assert(res && "builder didn't return the right type");
4019 return res;
4020}
4021
4022void MatmulTransposeAOp::build(OpBuilder &builder,
4023 OperationState &result,
4024 TypeRange resultTensorTypes,
4025 ValueRange inputs, ValueRange outputs,
4026 ArrayRef<NamedAttribute> attributes) {
4027 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4028 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4029}
4030
4031MatmulTransposeAOp
4032MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4033 TypeRange resultTensorTypes, ValueRange inputs,
4034 ValueRange outputs,
4035 ArrayRef<NamedAttribute> attributes) {
4036 OperationState state(location, getOperationName());
4037 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4038 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4039 assert(res && "builder didn't return the right type");
4040 return res;
4041}
4042
4043void MatmulTransposeAOp::build(OpBuilder &builder,
4044 OperationState &result,
4045 TypeRange resultTensorTypes,
4046 ValueRange inputs, ValueRange outputs,
4047 Attribute cast,
4048 ArrayRef<NamedAttribute> attributes) {
4049 result.addAttribute("cast", cast);
4050 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4051 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4052}
4053
4054MatmulTransposeAOp
4055MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4056 TypeRange resultTensorTypes, ValueRange inputs,
4057 ValueRange outputs, Attribute cast,
4058 ArrayRef<NamedAttribute> attributes) {
4059 OperationState state(location, getOperationName());
4060 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4061 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4062 assert(res && "builder didn't return the right type");
4063 return res;
4064}
4065
4066bool MatmulTransposeAOp::classof(Operation *op) {
4067 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4068 MatmulTransposeAOp::isDefaultIndexingMaps(
4069 op->getAttr("indexing_maps"));
4070}
4071
4072SmallVector<AffineMap>
4073MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4074 AffineExpr d0, d1, d2;
4075 MLIRContext *context = builder.getContext();
4076 bindDims(context, d0, d1, d2);
4077 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4078 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4079 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4080 return {mapLHS, mapRHS, mapOut};
4081}
4082
4083bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4084 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4085 if (!maps)
4086 return false;
4087 if (maps.size() != 3)
4088 return false;
4089 auto positions = getAffineResultPositions(maps);
4090 if (failed(positions))
4091 return false;
4092 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4093 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4094 (*positions)[2] == SmallVector<int64_t>{0, 1};
4095}
4096
4097void MatmulTransposeBOp::build(OpBuilder &builder,
4098 OperationState &result,
4099 ValueRange inputs, ValueRange outputs,
4100 ArrayRef<NamedAttribute> attributes) {
4101 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4102 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4103}
4104
4105MatmulTransposeBOp
4106MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4107 ValueRange inputs, ValueRange outputs,
4108 ArrayRef<NamedAttribute> attributes) {
4109 OperationState state(location, getOperationName());
4110 build(builder, state, inputs, outputs, attributes);
4111 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4112 assert(res && "builder didn't return the right type");
4113 return res;
4114}
4115
4116void MatmulTransposeBOp::build(OpBuilder &builder,
4117 OperationState &result,
4118 TypeRange resultTensorTypes,
4119 ValueRange inputs, ValueRange outputs,
4120 ArrayRef<NamedAttribute> attributes) {
4121 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4122 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4123}
4124
4125MatmulTransposeBOp
4126MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4127 TypeRange resultTensorTypes, ValueRange inputs,
4128 ValueRange outputs,
4129 ArrayRef<NamedAttribute> attributes) {
4130 OperationState state(location, getOperationName());
4131 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4132 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4133 assert(res && "builder didn't return the right type");
4134 return res;
4135}
4136
4137void MatmulTransposeBOp::build(OpBuilder &builder,
4138 OperationState &result,
4139 TypeRange resultTensorTypes,
4140 ValueRange inputs, ValueRange outputs,
4141 Attribute cast,
4142 ArrayRef<NamedAttribute> attributes) {
4143 result.addAttribute("cast", cast);
4144 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4145 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4146}
4147
4148MatmulTransposeBOp
4149MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4150 TypeRange resultTensorTypes, ValueRange inputs,
4151 ValueRange outputs, Attribute cast,
4152 ArrayRef<NamedAttribute> attributes) {
4153 OperationState state(location, getOperationName());
4154 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4155 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4156 assert(res && "builder didn't return the right type");
4157 return res;
4158}
4159
4160bool MatmulTransposeBOp::classof(Operation *op) {
4161 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4162 MatmulTransposeBOp::isDefaultIndexingMaps(
4163 op->getAttr("indexing_maps"));
4164}
4165
4166SmallVector<AffineMap>
4167BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4168 AffineExpr d0, d1, d2, d3;
4169 MLIRContext *context = builder.getContext();
4170 bindDims(context, d0, d1, d2, d3);
4171 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4172 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4173 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4174 return {mapLHS, mapRHS, mapOut};
4175}
4176
4177bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
4178 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4179 if (!maps)
4180 return false;
4181 if (maps.size() != 3)
4182 return false;
4183 auto positions = getAffineResultPositions(maps);
4184 if (failed(positions))
4185 return false;
4186 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4187 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4188 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4189}
4190
4191void BatchMatmulTransposeAOp::build(
4192 OpBuilder &builder, OperationState &result, ValueRange inputs,
4193 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4194 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4195 BatchMatmulOp::getRegionBuilder(),
4196 getDefaultIndexingMaps(builder));
4197}
4198
4199BatchMatmulTransposeAOp
4200BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4201 ValueRange inputs, ValueRange outputs,
4202 ArrayRef<NamedAttribute> attributes) {
4203 OperationState state(location, getOperationName());
4204 build(builder, state, inputs, outputs, attributes);
4205 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4206 assert(res && "builder didn't return the right type");
4207 return res;
4208}
4209
4210void BatchMatmulTransposeAOp::build(
4211 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4212 ValueRange inputs, ValueRange outputs,
4213 ArrayRef<NamedAttribute> attributes) {
4214 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4215 BatchMatmulOp::getRegionBuilder(),
4216 getDefaultIndexingMaps(builder));
4217}
4218
4219BatchMatmulTransposeAOp
4220BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4221 TypeRange resultTensorTypes, ValueRange inputs,
4222 ValueRange outputs,
4223 ArrayRef<NamedAttribute> attributes) {
4224 OperationState state(location, getOperationName());
4225 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4226 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4227 assert(res && "builder didn't return the right type");
4228 return res;
4229}
4230
4231void BatchMatmulTransposeAOp::build(
4232 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4233 ValueRange inputs, ValueRange outputs, Attribute cast,
4234 ArrayRef<NamedAttribute> attributes) {
4235 result.addAttribute("cast", cast);
4236 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4237 BatchMatmulOp::getRegionBuilder(),
4238 getDefaultIndexingMaps(builder));
4239}
4240
4241BatchMatmulTransposeAOp
4242BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4243 TypeRange resultTensorTypes, ValueRange inputs,
4244 ValueRange outputs, Attribute cast,
4245 ArrayRef<NamedAttribute> attributes) {
4246 OperationState state(location, getOperationName());
4247 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4248 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4249 assert(res && "builder didn't return the right type");
4250 return res;
4251}
4252
4253bool BatchMatmulTransposeAOp::classof(Operation *op) {
4254 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4255 BatchMatmulTransposeAOp::isDefaultIndexingMaps(
4256 op->getAttr("indexing_maps"));
4257}
4258
4259SmallVector<AffineMap>
4260BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4261 AffineExpr d0, d1, d2, d3;
4262 MLIRContext *context = builder.getContext();
4263 bindDims(context, d0, d1, d2, d3);
4264 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4265 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4266 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4267 return {mapLHS, mapRHS, mapOut};
4268}
4269
4270bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4271 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4272 if (!maps)
4273 return false;
4274 if (maps.size() != 3)
4275 return false;
4276 auto positions = getAffineResultPositions(maps);
4277 if (failed(positions))
4278 return false;
4279 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4280 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4281 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4282}
4283
4284void BatchMatmulTransposeBOp::build(
4285 OpBuilder &builder, OperationState &result, ValueRange inputs,
4286 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4287 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4288 BatchMatmulOp::getRegionBuilder(),
4289 getDefaultIndexingMaps(builder));
4290}
4291
4292BatchMatmulTransposeBOp
4293BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4294 ValueRange inputs, ValueRange outputs,
4295 ArrayRef<NamedAttribute> attributes) {
4296 OperationState state(location, getOperationName());
4297 build(builder, state, inputs, outputs, attributes);
4298 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4299 assert(res && "builder didn't return the right type");
4300 return res;
4301}
4302
4303void BatchMatmulTransposeBOp::build(
4304 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4305 ValueRange inputs, ValueRange outputs,
4306 ArrayRef<NamedAttribute> attributes) {
4307 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4308 BatchMatmulOp::getRegionBuilder(),
4309 getDefaultIndexingMaps(builder));
4310}
4311
4311BatchMatmulTransposeBOp
4312BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4313 TypeRange resultTensorTypes, ValueRange inputs,
4315 ValueRange outputs,
4316 ArrayRef<NamedAttribute> attributes) {
4317 OperationState state(location, getOperationName());
4318 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4319 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4320 assert(res && "builder didn't return the right type");
4321 return res;
4322}
4323
4324void BatchMatmulTransposeBOp::build(
4325 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4326 ValueRange inputs, ValueRange outputs, Attribute cast,
4327 ArrayRef<NamedAttribute> attributes) {
4328 result.addAttribute("cast", cast);
4329 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4330 BatchMatmulOp::getRegionBuilder(),
4331 getDefaultIndexingMaps(builder));
4332}
4333
4334BatchMatmulTransposeBOp
4335BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4336 TypeRange resultTensorTypes, ValueRange inputs,
4337 ValueRange outputs, Attribute cast,
4338 ArrayRef<NamedAttribute> attributes) {
4339 OperationState state(location, getOperationName());
4340 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4341 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4342 assert(res && "builder didn't return the right type");
4343 return res;
4344}
4345
4346bool BatchMatmulTransposeBOp::classof(Operation *op) {
4347 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4348 BatchMatmulTransposeBOp::isDefaultIndexingMaps(
4349 op->getAttr("indexing_maps"));
4350}
4351
4352//===----------------------------------------------------------------------===//
4353// ContractOp
4354//===----------------------------------------------------------------------===//
4355
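// Example (illustrative): a plain matmul expressed as linalg.contract, where
// d2 is the contracting dimension (it appears in both inputs but not in the
// output):
//
//   linalg.contract indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                                    affine_map<(d0, d1, d2) -> (d2, d1)>,
//                                    affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%a, %b : tensor<8x4xf32>, tensor<4x16xf32>)
//     outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>
//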
4356SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4357 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4358 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4359 // domains are all the same, and each implements a projected permutation.
4360 // Each iteration space dim must occur for at least one operand and either
4361 // takes part in a contraction/reduction or else has parallel iteration type.
4362 // We have that a dim is a contraction/reduction dim if and only if the dim
4363 // does not occur for the output operand. We use this fact for fast inference:
4364 // NB: In case we allow dims to occur solely for one input, the above still
4365 // holds: per the einsum semantics, these are reduction dims as well.
4366 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4367 for (auto result : outAffineMap.getResults()) {
4368 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4369 assert(dimExpr && "affine_map is a projected permutation");
4370 dimsInOutput[dimExpr.getPosition()] = true;
4371 }
4372
4373 SmallVector<utils::IteratorType> iteratorTypes;
4374 for (auto dimOccursInOutput : dimsInOutput)
4375 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4376 : utils::IteratorType::reduction);
4377
4378 return iteratorTypes;
4379}
4380
4381unsigned ContractOp::getNumRegionArgs() { return 3; }
4382
4383/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4384void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4385 ArrayRef<NamedAttribute> attrs,
4386 function_ref<InFlightDiagnostic()> emitError) {
4387 if (emitError && block.getNumArguments() != 3) {
4388 emitError() << "ContractOp regionBuilder expects 3 args, got "
4389 << block.getNumArguments();
4390 return;
4391 }
4392 assert(block.getNumArguments() == 3 &&
4393 "ContractOp regionBuilder expects 3 args");
4394 RegionBuilderHelper helper(b, block);
4395
4396 TypeFn castSignedness = TypeFn::cast_signed;
4397 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4398 return attr.getName() == "cast";
4399 });
4400 if (castIter != attrs.end()) {
4401 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4402 castSignedness = attr.getValue();
4403 }
4404
4405 // TODO: Support fields with operators besides mult & add.
4406 Type outType = block.getArgument(2).getType();
4407 Value lhsAtOutType =
4408 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4409 Value rhsAtOutType =
4410 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4411 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4412 rhsAtOutType, emitError);
4413 if (!productAtOutType)
4414 return;
4415 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4416 productAtOutType, emitError);
4417 if (!result)
4418 return;
4419 helper.yieldOutputs({result});
4420}
4421
4422ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4423 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4424 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4425 return parser.emitError(parser.getCurrentLocation(),
4426 "expected 'indexing_maps' attribute");
4427 result.addAttribute("indexing_maps", *indexingMapsAttr);
4428
4429 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4430 regionBuilder);
4431}
4432
4433void ContractOp::print(OpAsmPrinter &p) {
4434 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4435 printNamedStructuredOp(
4436 p, getOperation(), getInputs(), getOutputs(),
4437 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4438}
4439
4440LogicalResult ContractOp::verify() {
4441 int iterationSpaceDims = -1;
4442 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4443 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4444 // access an input operand (so occurrence count can be at most 2) and
4445 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
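 // For example (illustrative), with the plain matmul maps (d0, d2), (d2, d1),
 // (d0, d1), inOccurrences ends up as [1, 1, 2] and outOccurrences as
 // [1, 1, 0]: d2 is the contracting dim.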
4446 SmallVector<size_t> inOccurrences;
4447 SmallVector<size_t> outOccurrences;
4448
4449 // A helper so that for each operand's affine_map and type we check that ...
4450 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4451 bool isInput) -> LogicalResult {
4452 // ... the affine_map is a projected permutation;
4453 if (!affineMap.isProjectedPermutation())
4454 return emitError("provided affine_map is not a projected permutation");
4455
4456 // ... the rank of the affine_map's results and corresponding type match;
4457 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4458 if (affineMap.getNumResults() != shapedType.getRank())
4459 return emitError("ranks of shaped operand and results of corresponding "
4460 "affine_map differ");
4461 } else if (affineMap.getNumResults() != 0) {
4462 return emitError("affine_map specifies shaped access while operand has "
4463 "non-shaped type");
4464 }
4465
4466 // ... the rank of the affine_map's domain is the same as those seen prior;
4467 if (iterationSpaceDims == -1) {
4468 iterationSpaceDims = affineMap.getNumDims();
4469 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4470 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4471 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4472 return emitError("iteration spaces of provided affine_maps differ");
4473 }
4474
4475 // ... update counts of dims used to access either an input or the output.
4476 for (AffineExpr affineExpr : affineMap.getResults()) {
4477 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4478 if (!affineDimExpr)
4479 llvm_unreachable("affine_map is a projected permutation");
4480
4481 if (isInput)
4482 inOccurrences[affineDimExpr.getPosition()] += 1;
4483 else
4484 outOccurrences[affineDimExpr.getPosition()] += 1;
4485 }
4486
4487 return success();
4488 };
4489
4490 for (auto &&[affineMap, operandType, isInput] :
4491 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4492 SmallVector<bool>{true, true, false})) {
4493 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4494 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4495 }
4496
4497 bool hasContractingDim = false;
4498 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4499 size_t inOccCount = inOccurrences[dimIndex];
4500 size_t outOccCount = outOccurrences[dimIndex];
4501
4502 // We have a contracting dim if and only if ...
4503 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4504
4505 if (inOccCount == 0 && outOccCount == 0)
4506 return emitError() << "iteration space dim at index " << dimIndex
4507 << " not used to access any operand";
4508
4509    // NB: We disallow a dim which occurs for only one input operand and not
4510    // for the output. In terms of einsum semantics such dims have a
4511    // sensible meaning - namely an additional reduction for each such dim.
4512    // By contrast, the ContractionOpInterface does not know about this
4513    // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4514    // while vector.contract's verifier accepts dims of this kind, many of
4515    // its lowerings give up on encountering these dims.
4516 // TODO: Remove following once we have comprehensive support for input-only
4517 // reduction dims, at both the linalg- and vector-dialect levels.
4518 if (inOccCount == 1 && outOccCount != 1)
4519 return emitError()
4520 << "iteration space dim at index " << dimIndex
4521 << " is neither a contracting dim nor of parallel iteration type";
4522 }
4523
4524 if (!hasContractingDim)
4525 return emitError("'indexing_maps' do not specify a contracting dimension");
4526
4527 return success();
4528}
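// For illustration (an editorial sketch, not part of the upstream file): a
// plain matmul expressed as linalg.contract passes the checks above, with d2
// as the contracting dim (it accesses both inputs but not the output):
//
//   %0 = linalg.contract
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>]
//          ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//          outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>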
4529
4530LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4531 return memref::foldMemRefCast(*this);
4532}
4533
4534void ContractOp::getEffects(
4535 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4536 &effects) {
4537 if (hasPureTensorSemantics())
4538 return;
4539 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4540}
4541
4542Speculation::Speculatability ContractOp::getSpeculatability() {
4543 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4544}
4545
4546//===----------------------------------------------------------------------===//
4547// Implementation of BatchMatmulOp
4548//===----------------------------------------------------------------------===//
4549SmallVector<AffineMap>
4550BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4551 AffineExpr d0, d1, d2, d3;
4552 SmallVector<AffineMap> indexingMaps;
4553 bindDims(context, d0, d1, d2, d3);
4554 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4555 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4556 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4557 return indexingMaps;
4558}
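// For reference (editorial note): in affine notation the maps built above are
//   lhs:  (d0, d1, d2, d3) -> (d0, d1, d3)
//   rhs:  (d0, d1, d2, d3) -> (d0, d3, d2)
//   init: (d0, d1, d2, d3) -> (d0, d1, d2)
// i.e. d0 = batch, d1 = m, d2 = n, d3 = k.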
4559
4560bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4561 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4562 if (!maps)
4563 return false;
4564 if (maps.size() != 3)
4565 return false;
4566 auto positions = getAffineResultPositions(maps);
4567 if (failed(positions))
4568 return false;
4569 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4570 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4571 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4572}
4573
4574SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4575 return SmallVector<utils::IteratorType>{
4576 utils::IteratorType::parallel, utils::IteratorType::parallel,
4577 utils::IteratorType::parallel, utils::IteratorType::reduction};
4578}
4579
4580unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4581
4582std::string BatchMatmulOp::getLibraryCallName() {
4583 return generateLibraryCallName(getOperation());
4584}
4585
4586/// Check if the op has broadcast and/or transpose semantics. Returns true if
4587/// the user-defined indexing maps differ from the default maps.
4588bool BatchMatmulOp::hasUserDefinedMaps() {
4589 SmallVector<AffineMap, 3> defaultMaps =
4590 getDefaultIndexingMaps(this->getContext());
4591 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4592 return defaultMaps != explicitMaps;
4593}
4594
4595/// Returns true if the given bcastMap is a valid broadcast map. A valid
4596/// broadcast map must include the K dimension.
4597/// TODO: Strict inclusion of the K dimension in the broadcast map is not
4598/// necessary for both input matrices simultaneously. We can relax this
4599/// condition to require the K dimension in only one input matrix map and
4600/// infer the K dimension for the other input matrix map from the one that
4601/// already has it.
4602bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4603 assert(bcastMap.getNumResults() < 3 &&
4604 "Expected less than 3 result dim expr.");
4605 bool isValid = false;
4606 enum Indices { batchPos, mPos, nPos, kPos };
4607 if (bcastMap.getNumResults() == 1) {
4608 AffineExpr expr = bcastMap.getResult(0);
4609 isValid = expr.isFunctionOfDim(kPos);
4610 } else if (bcastMap.getNumResults() == 2) {
4611 AffineExpr expr0 = bcastMap.getResult(0);
4612 AffineExpr expr1 = bcastMap.getResult(1);
4613 isValid =
4614 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4615 expr0.isFunctionOfDim(mPos)) &&
4616 expr1.isFunctionOfDim(kPos))
4617 : ((expr0.isFunctionOfDim(batchPos) &&
4618 expr1.isFunctionOfDim(kPos)) ||
4619 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4620 }
4621 return isValid;
4622}
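// Illustrative cases (editorial sketch): with dims (batch, m, n, k), an LHS
// map such as (batch, m, n, k) -> (k), or (batch, m, n, k) -> (m, k), is
// accepted as a broadcast map, while (batch, m, n, k) -> (m) is rejected
// because it lacks the K dimension.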
4623
4624void BatchMatmulOp::regionBuilder(
4625 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4626 function_ref<InFlightDiagnostic()> emitError) {
4627 if (emitError && block.getNumArguments() != 3) {
4628 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4629 << block.getNumArguments();
4630 return;
4631 }
4632 assert(block.getNumArguments() == 3 &&
4633 "BatchMatmulOp regionBuilder expects 3 args");
4634 RegionBuilderHelper helper(b, block);
4635 SmallVector<Value> yields;
4636
4637 TypeFn castVal = TypeFn::cast_signed;
4638 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4639 return attr.getName() == "cast";
4640 });
4641 if (castIter != attrs.end()) {
4642 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4643 castVal = attr.getValue();
4644 }
4645
4646 auto toType = block.getArgument(2).getType();
4647 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4648 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4649 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4650 Value addVal =
4651 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4652 yields.push_back(addVal);
4653 helper.yieldOutputs(yields);
4654}
4655
4656ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4657 SmallVector<Attribute, 3> indexingMapsAttr;
4658 Attribute mapAttr;
4659 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4660 if (parser.parseEqual())
4661 return failure();
4662
4663 if (parser.parseLSquare())
4664 return failure();
4665
4666 do {
4667 if (parser.parseAttribute(mapAttr))
4668 return failure();
4669 if (!isa<AffineMapAttr>(mapAttr)) {
4670 return parser.emitError(parser.getCurrentLocation(),
4671 "expected affine map attribute");
4672 }
4673 indexingMapsAttr.push_back(mapAttr);
4674
4675 if (parser.parseOptionalComma())
4676 break;
4677 } while (true);
4678
4679 if (parser.parseRSquare())
4680 return failure();
4681 }
4682 // Initialize indexingMaps, if not supplied explicitly.
4683 if (indexingMapsAttr.empty()) {
4684 indexingMapsAttr = llvm::map_to_vector(
4685 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4686 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4687 }
4688 result.addAttribute("indexing_maps",
4689 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4690
4691 return ::parseNamedStructuredOp(parser, result,
4692 BatchMatmulOp::getNumRegionArgs(),
4693 BatchMatmulOp::getRegionBuilder());
4694}
4695
4696void BatchMatmulOp::print(OpAsmPrinter &p) {
4697 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4698 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4699 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4700 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4701 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4702
4703 std::array<StringRef, 3> elidedAttrs = {
4704 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4705 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4706 elidedAttrs);
4707}
4708
4709/// Verify the user defined indexing maps.
4710LogicalResult BatchMatmulOp::verify() {
4711 // Verification of pure batch_matmul is handled by
4712 // verifyStructuredOpInterface().
4713 if (!hasUserDefinedMaps())
4714 return success();
4715
4716 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4717    if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4718      return failure();
4719 }
4720 return success();
4721}
4722
4723LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4724 SmallVectorImpl<OpFoldResult> &) {
4725 return memref::foldMemRefCast(*this);
4726}
4727
4728void BatchMatmulOp::getEffects(
4729 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4730 &effects) {
4731 if (hasPureTensorSemantics())
4732 return;
4733 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4734}
4735
4736Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4737 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4738}
4739
4740//===----------------------------------------------------------------------===//
4741// ElementwiseOp
4742//===----------------------------------------------------------------------===//
4743//
4744namespace {
4745struct ArityGroupAndKind {
4746 // The enum class {Unary, Binary, Ternary, ..}
4747 ElementwiseArityGroup arityGroup;
4748
4749 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4750 union Kind {
4751 UnaryFn unaryFn;
4752 BinaryFn binaryFn;
4753 TernaryFn ternaryFn;
4754 } kind;
4755};
4756
4757unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4758 return static_cast<unsigned>(arityGroup);
4759}
4760} // namespace
4761
4762static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4763 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4764 constexpr int lastBinary =
4765 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4766 constexpr int lastTernary =
4767 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4768
4769 int val = static_cast<int>(kind);
4770 ArityGroupAndKind result;
4771
4772 if (val < lastUnary) {
4773 result.arityGroup = ElementwiseArityGroup::Unary;
4774 result.kind.unaryFn = static_cast<UnaryFn>(val);
4775 return result;
4776 }
4777 if (val < lastBinary) {
4778 result.arityGroup = ElementwiseArityGroup::Binary;
4779 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4780 return result;
4781 }
4782 if (val >= lastTernary) {
4783 llvm_unreachable("unhandled ElementwiseFn");
4784 }
4785 result.arityGroup = ElementwiseArityGroup::Ternary;
4786 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4787 return result;
4788}
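// Worked example (editorial, with hypothetical enum values): if lastUnary
// were 16 and `kind` had the integer value 17, then val - lastUnary == 1
// would select the BinaryFn enumerator with underlying value 1.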
4789
4790SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4791 auto rank = getResultRank();
4792 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4793}
4794
4795SmallVector<AffineMap>
4796ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4797 MLIRContext *context) {
4798 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4799 return SmallVector<AffineMap>(numMaps, map);
4800}
4801
4802ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4803 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4804 Attribute attr;
4805 mlir::linalg::ElementwiseKind elemwiseKindVal;
4806 if (parser.parseKeyword("kind") || parser.parseEqual())
4807 return failure();
4808
4809 if (succeeded(parser.parseAttribute(attr))) {
4810 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4811 if (!elemwiseKindAttr)
4812 return parser.emitError(parser.getCurrentLocation(),
4813 "expected ElementwiseKind attribute");
4814 elemwiseKindVal = elemwiseKindAttr.getValue();
4815 } else {
4816 return parser.emitError(parser.getCurrentLocation(),
4817 "expected operation 'kind' attribute");
4818 }
4819 result.addAttribute(
4820 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4821
4822 // Parse optional `indexing_maps`
4823 SmallVector<Attribute, 3> indexingMapsAttr;
4824 Attribute mapAttr;
4825 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4826 if (parser.parseEqual())
4827 return failure();
4828 if (parser.parseLSquare())
4829 return failure();
4830 do {
4831 if (parser.parseAttribute(mapAttr))
4832 return failure();
4833 if (!isa<AffineMapAttr>(mapAttr))
4834 return parser.emitError(parser.getCurrentLocation(),
4835 "expected affine map attribute");
4836 indexingMapsAttr.push_back(mapAttr);
4837 if (parser.parseOptionalComma())
4838 break;
4839 } while (true);
4840 if (parser.parseRSquare())
4841 return failure();
4842 }
4843  // At this stage of parsing, the only way to infer the number of region
4844  // args is through the op kind, as the input/output tensors are not parsed yet.
4845 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4846 int numRegionArgs =
4847 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4848 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4849 ElementwiseOp::getRegionBuilder())) {
4850 return parser.emitError(parser.getCurrentLocation(),
4851 "unable to parse elemwise op");
4852 }
4853
4854 // Initialize indexingMaps, if not supplied explicitly.
4855 if (indexingMapsAttr.empty()) {
4856 // We need to infer the numDims of the indexing maps from the output
4857 // type which is already parsed by now.
4858 auto resultType = result.operands[result.operands.size() - 1].getType();
4859 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4860 if (!shapedType)
4861 return parser.emitError(parser.getCurrentLocation(),
4862 "return type needs to be shaped type");
4863 auto numDims = shapedType.getRank();
4864 indexingMapsAttr = llvm::map_to_vector(
4865 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4866 parser.getContext()),
4867 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4868 }
4869
4870 result.addAttribute("indexing_maps",
4871 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4872 return success();
4873}
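// For illustration (an editorial sketch of the syntax handled above): a
// binary op with the default identity maps round-trips as
//
//   %r = linalg.elementwise kind=#linalg.elemwise_kind<add>
//          ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
//          outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>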
4874
4875void ElementwiseOp::print(OpAsmPrinter &p) {
4876 p << " kind=";
4877 p.printAttribute(getKindAttr());
4878 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4879 "indexing_maps"};
4880 unsigned arity =
4881 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4882 unsigned numDims = getResultRank();
4883
4884 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4885 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4886 getContext()),
4887 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4888
4889 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4890 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4891
4892 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4893 elidedAttrs);
4894}
4895
4896/// Implements the block region builder for the ElementwiseOp. This is called by
4897/// 'fillStructuredOpRegion'.
4898void ElementwiseOp::regionBuilder(
4899 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4900 function_ref<InFlightDiagnostic()> emitError) {
4901 ElementwiseKind elemwiseKind;
4902 for (auto attr : attrs) {
4903 if (attr.getName() == b.getStringAttr("kind")) {
4904 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4905 assert(kindAttr && "op kind attribute incorrectly set");
4906 elemwiseKind = kindAttr.getValue();
4907 break;
4908 }
4909 }
4910
4911 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4912 auto arityGroup = groupAndKind.arityGroup;
4913 auto kind = groupAndKind.kind;
4914 if (emitError && block.getNumArguments() !=
4915 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4916 emitError() << "Elementwise regionBuilder expects "
4917 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4918 << block.getNumArguments();
4919 return;
4920 }
4921 assert(block.getNumArguments() ==
4922 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4923 && "Elementwise regionBuilder number of block args mismatch");
4924
4925 RegionBuilderHelper helper(b, block);
4926 SmallVector<Value> yields;
4927 Value result;
4928
4929 if (arityGroup == ElementwiseArityGroup::Unary) {
4930 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4931
4932 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4933 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4934 block.getArgument(1));
4935
4936 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4937 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4938 block.getArgument(1), block.getArgument(2));
4939
4940 } else {
4941 assert(false && "found unhandled category in elemwise");
4942 }
4943
4944 yields.push_back(result);
4945 helper.yieldOutputs(yields);
4946}
4947
4948LogicalResult ElementwiseOp::fold(FoldAdaptor,
4949 SmallVectorImpl<OpFoldResult> &) {
4950 return memref::foldMemRefCast(*this);
4951}
4952
4953void ElementwiseOp::getEffects(
4954 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4955 &effects) {
4956 if (hasPureTensorSemantics())
4957 return;
4958 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4959}
4960
4961Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4962 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4963}
4964
4965//===----------------------------------------------------------------------===//
4966// PackOp/UnPackOp Common
4967//===----------------------------------------------------------------------===//
4968
4969template <typename OpTy, typename>
4970SmallVector<int64_t>
4971getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4972  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4973 ? packOrUnPack.getDestType()
4974 : packOrUnPack.getSourceType();
4975 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4976 ? packOrUnPack.getSourceType()
4977 : packOrUnPack.getDestType();
4978  SmallVector<int64_t> result(
4979      packedType.getShape().take_front(unpackedType.getRank()));
4980 if (!packOrUnPack.getOuterDimsPerm().empty()) {
4981    applyPermutationToVector(
4982        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4983 }
4984 return result;
4985}
4986template SmallVector<int64_t>
4987    getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
4988template SmallVector<int64_t>
4989    getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
4990
4991// Given the (potentially) updated packed type, `newPackedTy`, generates an
4992// updated mixed-tile-sizes attribute. A tile size is updated only
4993// when:
4994// * a dim from newPackedTy is static, and
4995// * the corresponding size from mixedTiles is still dynamic.
4996// Otherwise, the original tile size is preserved.
4997// Note - packed-type-dim and mixed-tile-size should always match!
4998static SmallVector<OpFoldResult>
4999getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
5000                     SmallVector<OpFoldResult> mixedTiles) {
5001 SmallVector<OpFoldResult> newMixedTileSizes;
5002 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5003 .getShape()
5004 .take_back(mixedTiles.size()),
5005 mixedTiles)) {
5006 int64_t shape = std::get<0>(it);
5007 if (shape == ShapedType::kDynamic) {
5008 newMixedTileSizes.push_back(std::get<1>(it));
5009 continue;
5010 }
5011
5012 // If the current result dim is static, update the dynamic mixed-size
5013 // (provided the original value is dynamic).
5014 OpFoldResult tile = std::get<1>(it);
5015 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5016 // Already a constant
5017 newMixedTileSizes.push_back(tile);
5018 } else {
5019 assert(getConstantIntValue(tile).value() == shape &&
5020 "tile size and dim size don't match!");
5021 newMixedTileSizes.push_back(
5022 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
5023 }
5024 }
5025
5026 return newMixedTileSizes;
5027}
5028
5029template <typename OpTy>
5030static LogicalResult
5031reifyResultShapesImpl(OpTy op, OpBuilder &builder,
5032                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5033 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5034 "applies to only pack or unpack operations");
5035 int64_t destRank = op.getDestRank();
5036 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5037 reifiedReturnShapes[0] =
5038 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
5039 return success();
5040}
5041
5042template <typename OpTy>
5043static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
5044  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5045 "applies to only pack or unpack operations");
5046 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5047 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5048 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5049 assert(tiles.size() == dimsToTile.size() &&
5050 "tiles must match indices of dimension to block");
5051 // bind the dimension `i` with the tile factor.
5052 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5053 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5054 return dimAndTileMapping;
5055}
5056
5057template <typename OpTy>
5058static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
5059  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5060 "applies to only pack or unpack operations");
5061 Builder builder(op);
5062 SmallVector<OpFoldResult> mixedInnerTiles;
5063 unsigned dynamicValIndex = 0;
5064 for (int64_t staticTile : op.getStaticInnerTiles()) {
5065 if (ShapedType::isStatic(staticTile))
5066 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5067 else
5068 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5069 }
5070 return mixedInnerTiles;
5071}
5072
5073template <typename OpTy>
5074static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
5075  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5076 "applies to only pack or unpack operations");
5077 SmallVector<Value> dynamicTiles;
5078 SmallVector<int64_t> staticTiles;
5079 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5080 return staticTiles;
5081}
5082
5083/// Returns true if `dimsPos` is invalid. It is invalid when:
5084/// a) It contains duplicates.
5085/// b) At least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= dimPos < rank).
5086/// c) The number of elements in `dimsPos` is greater than `rank`.
5087static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
5088                                             size_t rank) {
5089 size_t dimsPosSize = dimsPos.size();
5090 if (dimsPosSize > rank)
5091 return true;
5092 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5093 if (dimsPosSize != uniqued.size())
5094 return true;
5095 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5096 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5097 });
5098}
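// Illustrative cases (editorial): with rank = 3, dimsPos = [0, 2] is valid;
// [1, 1] (duplicate), [0, 3] (out of bounds), and [0, 1, 2, 0] (more entries
// than rank) are all invalid.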
5099
5100template <typename OpTy>
5101static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5102 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5103 "applies to only pack or unpack operations");
5104 Operation *op = packOrUnPack.getOperation();
5105
5106 // Return true if we have a zero-value tile.
5107 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5108 return llvm::any_of(tiles, isZeroInteger);
5109 };
5110
5111 // Verify tiles. Do not allow zero tiles.
5112 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5113 if (hasZeros(mixedTiles))
5114 return op->emitError("invalid zero tile factor");
5115
5116 // Verify inner_dims_pos and outer_dims_perm.
5117 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
5118 ? packOrUnPack.getSourceType()
5119 : packOrUnPack.getDestType();
5120 size_t unpackedRank = unpackedType.getRank();
5121 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5122 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5123 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5124 return op->emitError("invalid inner_dims_pos vector");
5125 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5126 return op->emitError("invalid outer_dims_perm vector");
5127 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5128 return op->emitError("outer_dims_perm must be a permutation or empty");
5129
5130 // Tiling factors must be less than or equal to the input rank for pack (or
5131 // output rank for unpack), and must match the number of `inner_dims_pos`.
5132 if (mixedTiles.size() > unpackedRank) {
5133 return op->emitError("tiling factors must be less than or equal to the "
5134 "input rank for pack or output rank for unpack");
5135 }
5136 if (mixedTiles.size() != innerDimsPos.size()) {
5137 return op->emitError(
5138 "tiling factors must equal the number of dimensions to tile");
5139 }
5140
5141 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5142 ? packOrUnPack.getDestType()
5143 : packOrUnPack.getSourceType();
5144 size_t packedRank = packedType.getRank();
5145 // Require output rank to match input rank + number of blocking factors.
5146 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5147 if (expectedPackedRank != packedRank) {
5148 return op->emitError(
5149 "packed rank != (unpacked rank + num tiling factors), got ")
5150 << packedRank << " != " << expectedPackedRank;
5151 }
5152
5153 // Verify result shape is greater than the minimum expected
5154 // by the pack operation, and that the output shape
5155 // represents full tiles.
5156 RankedTensorType expectedPackedType = PackOp::inferPackedType(
5157 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
5158 if (!llvm::all_of(
5159 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5160 mixedTiles),
5161 [](std::tuple<int64_t, OpFoldResult> it) {
5162 int64_t shape = std::get<0>(it);
5163 if (Attribute attr =
5164 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5165 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5166 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5167 return shape == staticTileSize;
5168 }
5169 return ShapedType::isDynamic(shape);
5170 })) {
5171 return op->emitError("mismatch in inner tile sizes specified and shaped of "
5172 "tiled dimension in the packed type");
5173 }
5174 if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
5175 packedType.getShape()))) {
5176 return op->emitError("expected ")
5177 << expectedPackedType << " for the packed domain value, got "
5178 << packedType;
5179 }
5180 return success();
5181}
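// For illustration (an editorial sketch, not part of the upstream file): a
// pack that satisfies all checks above tiles dims 0 and 1 of a 128x256
// source by 8 and 32, giving outer dims 128/8 = 16 and 256/32 = 8:
//
//   %p = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>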
5182
5183namespace {
5184/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5185/// various permutations to the op.
5186// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5187// these. These may or may not become true foldings / canonicalizations
5188// depending on how aggressive we want to be in automatically folding
5189// transposes.
5190struct PackOrUnPackTransposeResult {
5191 SmallVector<int64_t> innerDimsPos;
5192 SmallVector<OpFoldResult> innerTiles;
5193 SmallVector<int64_t> outerDimsPerm;
5194};
5195} // namespace
5196
5197template <typename OpTy>
5198static PackOrUnPackTransposeResult
5199commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
5200                                   ArrayRef<int64_t> innerPermutation,
5201 ArrayRef<int64_t> outerPermutation) {
5202 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5203 "applies to only pack or unpack operations");
5204 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5205 "some permutation must be non-empty");
5206 PackOrUnPackTransposeResult metadata;
5207 metadata.innerDimsPos =
5208 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5209 metadata.innerTiles =
5210 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
5211 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5212 ? packOrUnPackOp.getSourceRank()
5213 : packOrUnPackOp.getDestRank();
5214 metadata.outerDimsPerm =
5215 packOrUnPackOp.getOuterDimsPerm().empty()
5216 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5217 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5218 if (!innerPermutation.empty()) {
5219 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5220 isPermutationVector(innerPermutation) &&
5221 "invalid inner permutation");
5222 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5223 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5224 }
5225 if (!outerPermutation.empty()) {
5226 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5227 isPermutationVector(outerPermutation) &&
5228 "invalid outer permutation");
5229 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5230 }
5231 return metadata;
5232}
5233
5234//===----------------------------------------------------------------------===//
5235// PackOp
5236//===----------------------------------------------------------------------===//
5237
5238void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5239 setNameFn(getResult(), "pack");
5240}
5241
5242void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5243 Value dest, ArrayRef<int64_t> innerDimsPos,
5244 ArrayRef<OpFoldResult> innerTiles,
5245 std::optional<Value> paddingValue,
5246 ArrayRef<int64_t> outerDimsPerm) {
5247 assert(innerDimsPos.size() == innerTiles.size() &&
5248 "number of tile sizes specified must match the specified number of "
5249 "original dimensions to be tiled");
5250 SmallVector<int64_t> staticTileSizes;
5251 SmallVector<Value> dynamicTileSizes;
5252 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5253 build(builder, state, dest.getType(), source, dest,
5254 paddingValue ? *paddingValue : nullptr,
5255 outerDimsPerm.empty() ? nullptr
5256 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5257 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5258 builder.getDenseI64ArrayAttr(staticTileSizes));
5259}
5260
5261LogicalResult
5262PackOp::reifyResultShapes(OpBuilder &builder,
5263 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5264 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5265}
5266
5267DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5268 return getDimAndTileMappingImpl(*this);
5269}
5270
5271SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5272 return getMixedTilesImpl(*this);
5273}
5274
5275SmallVector<int64_t> PackOp::getStaticTiles() {
5276 return getStaticTilesImpl(*this);
5277}
5278
5279ArrayRef<int64_t> PackOp::getAllOuterDims() {
5280 ShapedType inputType = getSourceType();
5281 int64_t inputRank = inputType.getRank();
5282 return getDestType().getShape().take_front(inputRank);
5283}
5284
5285SmallVector<int64_t> PackOp::getTiledOuterDims() {
5286 auto innerDimsPos = getInnerDimsPos();
5287 SmallVector<int64_t> outerDims(getAllOuterDims());
5288  SmallVector<int64_t> res;
5289
5290 // Recover the original order of the outer dims.
5291 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5292 invertPermutationVector(outerDimPermInv);
5293 if (!outerDimPermInv.empty())
5294 applyPermutationToVector(outerDims, outerDimPermInv);
5295
5296  // Collect the outer dims corresponding to the tiled inner dims.
5297 for (auto index : innerDimsPos)
5298 res.push_back(outerDims[index]);
5299
5300 return res;
5301}
5302
5303bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5304 ArrayRef<int64_t> innerDimsPos,
5305 ArrayRef<int64_t> outputShape,
5306 ArrayRef<int64_t> outerDimsPerm,
5307 ArrayRef<OpFoldResult> innerTiles) {
5308 SmallVector<int64_t> outputTileSizes(
5309 outputShape.take_front(inputShape.size()));
5310 if (!outerDimsPerm.empty()) {
5311 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5312 "expected output and outer_dims_perm to have same size");
5313 applyPermutationToVector(outputTileSizes,
5314 invertPermutationVector(outerDimsPerm));
5315 }
5316 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5317 if (ShapedType::isDynamic(inputShape[pos]))
5318 continue;
5319 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5320
5321 if (!constantTile) {
5322 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5323 (inputShape[pos] % outputTileSizes[pos] != 0))
5324 return true;
5325 } else if (inputShape[pos] % (*constantTile) != 0) {
5326 return true;
5327 }
5328 }
5329 return false;
5330}
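// Worked example (editorial): packing a tensor<13x15xf32> with
// innerDimsPos = [0, 1] and innerTiles = [8, 2] requires a padding value,
// since 13 % 8 != 0; with innerTiles = [13, 5] it does not, as both tiles
// divide their dimensions evenly.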
5331
5332bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5333 ArrayRef<int64_t> innerDimsPos,
5334 ArrayRef<int64_t> outputShape,
5335 ArrayRef<int64_t> outerDimsPerm,
5336 ArrayRef<OpFoldResult> innerTiles) {
5337 SmallVector<int64_t> outputTileSizes(
5338 outputShape.take_front(inputShape.size()));
5339 if (!outerDimsPerm.empty()) {
5340 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5341 "expected output and outer_dims_perm to have same size");
5342 applyPermutationToVector(outputTileSizes,
5343 invertPermutationVector(outerDimsPerm));
5344 }
5345 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5346 if (ShapedType::isDynamic(inputShape[pos]) ||
5347 ShapedType::isDynamic(outputTileSizes[pos]))
5348 return true;
5349 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5350 if (!constantTile)
5351 return true;
5352 if (inputShape[pos] % (*constantTile) != 0)
5353 return true;
5354 }
5355 return false;
5356}
5357
5358LogicalResult PackOp::verify() {
5359  if (failed(commonVerifierPackAndUnPackOp(*this)))
5360    return failure();
5361
5362 // Verify padding value, and bail out if the tile does not divide the
5363 // dimension fully. In the case of dynamic tile factors or dimensions, having
5364 // a partial tile is undefined behavior.
5365 auto paddingValue = getPaddingValue();
5366 if (paddingValue &&
5367 paddingValue.getType() != getSourceType().getElementType()) {
5368 return emitOpError("expected padding_value has ")
5369 << getSourceType().getElementType()
5370 << " but got: " << paddingValue.getType();
5371 }
5372
5373 if (!paddingValue &&
5374 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5375 getDestType().getShape(), getOuterDimsPerm(),
5376 getMixedTiles())) {
5377 return emitOpError(
5378 "invalid tile factor or output size provided. Only full tiles are "
5379 "supported when padding_value is not set");
5380 }
5381 return success();
5382}
5383
5384/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5385/// Value's to kDynamic, even if they are arith.constant values.
5386static SmallVector<int64_t>
5387asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
5388  SmallVector<int64_t> result;
5389  for (auto o : ofrs) {
5390 // Have to do this first, as getConstantIntValue special-cases constants.
5391 if (llvm::dyn_cast_if_present<Value>(o))
5392 result.push_back(ShapedType::kDynamic);
5393 else
5394 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5395 }
5396 return result;
5397}
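// Illustrative behavior (editorial): for ofrs = [IntegerAttr(8), %v, %c],
// where %c is produced by an arith.constant 4, the result is
// [8, kDynamic, kDynamic] - any Value maps to kDynamic, even a constant.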
5398
5399/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
5400/// the packed type. Having a shared helper helps implement these two methods in
5401/// a way that ensures that they agree on which dimensions are dynamic.
5402static SmallVector<int64_t> getPackOpResultTypeShape(
5403    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
5404 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5405 SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
5406 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5407 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5408 continue;
5409 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5410 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5411 continue;
5412 }
5413 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5414 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5415 }
5416
5417 // Swap tile loops if outer_dims_perm is available.
5418 if (!outerDimsPerm.empty())
5419 applyPermutationToVector(resultShape, outerDimsPerm);
5420
5421 // Append the inner tile dimensions.
5422 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5423 return resultShape;
5424}
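// Worked example (editorial): sourceShape = [128, 256], innerTileSizes =
// [8, 32], innerDimsPos = [0, 1], outerDimsPerm = [1, 0] yields
// [8, 16, 8, 32]: the outer dims become ceil(128/8) = 16 and
// ceil(256/32) = 8, are then swapped by the permutation, and finally the
// tile sizes are appended.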
5425
5426SmallVector<OpFoldResult> PackOp::getResultShape(
5427 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5428 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5429 ArrayRef<int64_t> outerDimsPerm) {
5430 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5431
5432 AffineExpr s0, s1;
5433 bindSymbols(builder.getContext(), s0, s1);
5434 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5435 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5436 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5437 builder, loc, ceilDivExpr,
5438 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5439 }
5440 if (!outerDimsPerm.empty())
5441 applyPermutationToVector(resultDims, outerDimsPerm);
5442 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5443
5444 SmallVector<int64_t> resultTypeShape =
5445      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
5446                               asShapeWithAnyValueAsDynamic(innerTileSizes),
5447 innerDimsPos, outerDimsPerm);
5448
5449 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5450 // result type shape says it's a dynamic dim. This is needed as callers may
5451 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5452 // dynamic dims returned by that.
5453 for (unsigned i = 0; i < resultDims.size(); ++i) {
5454 if (ShapedType::isStatic(resultTypeShape[i]))
5455 continue;
5456 resultDims[i] =
5457 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5458 }
5459
5460 return resultDims;
5461}
5462
5463/// Get the expected packed type based on source type, tile factors, position of
5464/// the inner tiles and permutation of the outer tiled loop.
5465RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
5466 ArrayRef<int64_t> innerTileSizes,
5467 ArrayRef<int64_t> innerDimsPos,
5468 ArrayRef<int64_t> outerDimsPerm) {
5469  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
5470      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5471 return RankedTensorType::get(resultShape, sourceType.getElementType());
5472}
5473
5474Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5475 ArrayRef<OpFoldResult> innerTileSizes,
5476 ArrayRef<int64_t> innerDimsPos,
5477 ArrayRef<int64_t> outerDimsPerm) {
5478 AffineExpr dim0, dim1;
5479 bindDims(b.getContext(), dim0, dim1);
5480 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5481 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5482 {v1, v2});
5483 };
5484
5485 SmallVector<OpFoldResult> mixedSizes;
5486 for (auto [index, value] : llvm::enumerate(
5487 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5488 if (ShapedType::isDynamic(value))
5489 mixedSizes.push_back(
5490 tensor::DimOp::create(b, loc, source, index).getResult());
5491 else
5492 mixedSizes.push_back(b.getIndexAttr(value));
5493 }
5494 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5495 int64_t dimPos = std::get<0>(it);
5496 OpFoldResult tileSize = std::get<1>(it);
5497 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5498 }
5499 if (!outerDimsPerm.empty())
5500 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5501
5502 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5503 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5504 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5505}
5506
5507PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5508 ArrayRef<int64_t> innerPermutation,
5509 ArrayRef<int64_t> outerPermutation) {
5510 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5511 *this, innerPermutation, outerPermutation);
5512 Value transposedDest =
5513 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5514 metadata.innerDimsPos, metadata.outerDimsPerm);
5515 return PackOp::create(b, loc, getSource(), transposedDest,
5516 metadata.innerDimsPos, metadata.innerTiles,
5517 getPaddingValue(), metadata.outerDimsPerm);
5518}
5519
5520/// Returns true if the tiles and the tiled dims are constant.
5521template <typename OpTy>
5522static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5523  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5524 "applies to only pack or unpack operations");
5525 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5526 ? op.getDestType()
5527 : op.getSourceType();
5528 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5529 for (auto [dimDest, tile] : llvm::zip(
5530 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5531 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5532 if (!constTileSize || ShapedType::isDynamic(dimDest))
5533 return false;
5534 }
5535 return true;
5536}
5537
5538Speculation::Speculatability PackOp::getSpeculatability() {
5539  if (getPaddingValue())
5540    return Speculation::Speculatable;
5541
5542  // The verifier already rejects operations where we can statically prove
5543  // that the tile sizes do not divide the dimension evenly; thus, only check
5544  // that the tiles and tiled inner dimensions are constant.
5545  if (!areTilesAndTiledDimsAllConstant(*this))
5546    return Speculation::NotSpeculatable;
5547
5548  return Speculation::Speculatable;
5549}
5550
5551// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5552// dimensions for pack and unpack.
5553static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5554 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5555 return false;
5556 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5557 return true;
5558 // Outer dims permutation is optional.
5559  // To compare an unbalanced pack-unpack pair, treat no permutation as equal
5560  // to the identity permutation.
5561 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5562 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5563}
5564
5565// Return true if pack and unpack have the same tiles.
5566// Same SSA values or same integer constants.
5567static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5568 auto packTiles = packOp.getMixedTiles();
5569 auto unPackTiles = unPackOp.getMixedTiles();
5570 if (packTiles.size() != unPackTiles.size())
5571 return false;
5572 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5573 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5574 return false;
5575 }
5576 return true;
5577}
5578
5579/// Returns true if the pack op does not need a padding value.
5580static bool paddingIsNotNeeded(PackOp op) {
5581 auto srcType = op.getSourceType();
5582 if (llvm::any_of(op.getInnerDimsPos(),
5583 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5584 return false;
5585 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5586 return false;
5587 return !PackOp::requirePaddingValue(
5588 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5589 op.getOuterDimsPerm(), op.getMixedTiles());
5590}
5591
5592/// Returns true if the `srcShape` or `destShape` is different from the one in
5593/// `packOp` and populates each with the inferred static shape.
5594static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5595 SmallVectorImpl<int64_t> &destShape) {
5596 bool changeNeeded = false;
5597 srcShape.assign(packOp.getSourceType().getShape().begin(),
5598 packOp.getSourceType().getShape().end());
5599 destShape.assign(packOp.getDestType().getShape().begin(),
5600 packOp.getDestType().getShape().end());
5601 llvm::SmallSetVector<int64_t, 4> innerDims;
5602 innerDims.insert_range(packOp.getInnerDimsPos());
5603 SmallVector<int64_t> inverseOuterDimsPerm;
5604 if (!packOp.getOuterDimsPerm().empty())
5605 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5606 int srcRank = packOp.getSourceRank();
5607 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5608 if (innerDims.contains(i))
5609 continue;
5610 int64_t srcPos = i;
5611 int64_t destPos = i;
5612 if (!inverseOuterDimsPerm.empty())
5613 destPos = inverseOuterDimsPerm[srcPos];
5614 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5615 ShapedType::isDynamic(destShape[destPos])) {
5616 continue;
5617 }
5618 int64_t size = srcShape[srcPos];
5619 if (ShapedType::isDynamic(size))
5620 size = destShape[destPos];
5621 srcShape[srcPos] = size;
5622 destShape[destPos] = size;
5623 changeNeeded = true;
5624 }
5625 return changeNeeded;
5626}
5627
5628LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5629  // Fold a pack(unpack(x)) to x.
5630 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5631 if (unPackOp.getSourceType() == packOp.getDestType() &&
5632 !packOp.getPaddingValue() &&
5633 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5634 haveSameTiles(packOp, unPackOp)) {
5635 rewriter.replaceOp(packOp, unPackOp.getSource());
5636 return success();
5637 }
5638 }
5639
5640 // Fold optional PaddingValue operand away if padding is not needed.
5641 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5642 rewriter.startOpModification(packOp);
5643 packOp.getPaddingValueMutable().clear();
5644 rewriter.finalizeOpModification(packOp);
5645 return success();
5646 }
5647
5648  // Insert tensor.cast ops if static shape inference is available.
5649 SmallVector<int64_t> srcShape, destShape;
5650 if (inferStaticShape(packOp, srcShape, destShape)) {
5651 Location loc = packOp.getLoc();
5652 Value source = packOp.getSource();
5653 if (srcShape != packOp.getSourceType().getShape()) {
5654 auto newSrcType = packOp.getSourceType().clone(srcShape);
5655 source =
5656 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5657 }
5658 Value dest = packOp.getDest();
5659 RankedTensorType originalResultType = packOp.getDestType();
5660 bool needUpdateDestType = (destShape != originalResultType.getShape());
5661 if (needUpdateDestType) {
5662 auto newDestType = packOp.getDestType().clone(destShape);
5663 dest =
5664 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5665 }
5666 rewriter.modifyOpInPlace(packOp, [&] {
5667 packOp.getSourceMutable().assign(source);
5668 packOp.getDestMutable().assign(dest);
5669 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5670 });
5671 // Insert a cast if needed
5672 if (needUpdateDestType) {
5673 rewriter.setInsertionPointAfter(packOp);
5674 auto castOp =
5675 tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5676 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5677 }
5678 return success();
5679 }
5680
5681 return failure();
5682}
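// For illustration (an editorial sketch): the pack(unpack(x)) fold above
// rewrites
//
//   %u = linalg.unpack %x ... : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
//   %p = linalg.pack %u ...   : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//
// to just %x, provided both ops agree on inner_dims_pos, outer_dims_perm,
// and tiles, and the pack has no padding value.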
5683
5684template <typename PackOrUnpackOp>
5685static bool isLikePadUnPad(PackOrUnpackOp packOp,
5686 RankedTensorType packedTensorType) {
5687 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5688 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5689 "Function meant for pack/unpack");
5690 // This is a pad if packing only adds ones and we don't transpose dimensions.
5691
5692 // Check that we are not transposing any dimensions.
5693 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5694 int64_t numPackedDims = innerDimsPos.size();
5695 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5696 if (orderedDims != innerDimsPos) {
5697    // Dimensions are not in order.
5698 return false;
5699 }
5700
5701 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5702 int64_t packedRank = packedTensorType.getRank();
5703  // At this point we know that we are taking numPackedDims outer
5704  // dimensions and pushing them all the way as the innermost dimensions.
5705  // What's left on the outermost dimensions is, in this order:
5706  // - the factor of the packed dimensions, then
5707  // - the untouched dimensions.
5708  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5709  // if all the dimensions that bubble outward are ones.
5710  // Therefore check that all the dimensions but the numPackedDims innermost
5711  // ones are ones.
5712 return llvm::all_of(
5713 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5714 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5715}
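// Illustrative case (editorial): packing tensor<?x?xf32> into
// tensor<1x1x?x?xf32> with inner_dims_pos = [0, 1] is like a pad: the inner
// dims stay in order and every remaining outer dimension is 1.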
5716
5717bool PackOp::isLikePad() {
5718 auto packedTensorType =
5719 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5720 return isLikePadUnPad(*this, packedTensorType);
5721}
5722
5723OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5724 std::optional<Attribute> paddingValue;
5725 if (auto pad = adaptor.getPaddingValue())
5726 paddingValue = pad;
5727 if (OpFoldResult reshapedSource = reshapeConstantSource(
5728 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5729 getDestType(), paddingValue))
5730 return reshapedSource;
5731 return {};
5732}
5733
5734/// Folds a tensor.cast op into a consuming PackOp op if the
5735/// `tensor.cast` has a source that is more static than the consuming op.
5736///
5737/// Example:
5738/// ```mlir
5739/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5740/// %2 = linalg.pack %1 ... : tensor<?x?xf32> ...
5741/// ```
5742///
5743/// folds into:
5744///
5745/// ```mlir
5746/// %2 = linalg.pack %0 ... : tensor<8x16xf32> ...
5747/// ```
5748struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5749  using OpRewritePattern<PackOp>::OpRewritePattern;
5750
5751 LogicalResult matchAndRewrite(PackOp op,
5752 PatternRewriter &rewriter) const override {
5753    if (!tensor::hasFoldableTensorCastOperand(op))
5754      return failure();
5755
5756 SmallVector<Type> newResultTypes(op->getResultTypes());
5757 SmallVector<Value> newOperands =
5758        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5759
5760 // Get the updated mixed-tile-sizes attribute.
5761 SmallVector<OpFoldResult> newMixedTileSizes =
5762 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5763
5764 // Clone op.
5765 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5766 // this point. However, in practice, we use them for things that we'd like
5767 // to preserve. Implement a better abstraction.
5768 PackOp newOp =
5769 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5770 op.getInnerDimsPos(), newMixedTileSizes,
5771 op.getPaddingValue(), op.getOuterDimsPerm());
5772 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5773
5774 // Replace op.
5775 Value oldResult = op.getResult();
5776 Value newResult = newOp.getResult();
5777    Value replacement =
5778        (newResult.getType() != oldResult.getType())
5779 ? tensor::CastOp::create(rewriter, op->getLoc(),
5780 oldResult.getType(), newResult)
5781 : newResult;
5782
5783 rewriter.replaceOp(op, {replacement});
5784
5785 return success();
5786 }
5787};
5788
5789//===----------------------------------------------------------------------===//
5790// UnPackOp
5791//===----------------------------------------------------------------------===//
5792
5793void UnPackOp::getAsmResultNames(
5794 function_ref<void(Value, StringRef)> setNameFn) {
5795 setNameFn(getResult(), "unpack");
5796}
5797
5798LogicalResult
5799UnPackOp::reifyResultShapes(OpBuilder &builder,
5800 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5801 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5802}
5803
5804DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5805 return getDimAndTileMappingImpl(*this);
5806}
5807
5808SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5809 return getMixedTilesImpl(*this);
5810}
5811
5812SmallVector<int64_t> UnPackOp::getStaticTiles() {
5813 return getStaticTilesImpl(*this);
5814}
5815
5816ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5817 ShapedType destType = getDestType();
5818 int64_t destRank = destType.getRank();
5819 return getSourceType().getShape().take_front(destRank);
5820}
5821
5822SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5823 auto innerDimsPos = getInnerDimsPos();
5824 SmallVector<int64_t> outerDims(getAllOuterDims());
5825  SmallVector<int64_t> res;
5826
5827 // Recover the original order of the outer dims.
5828 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5829 invertPermutationVector(outerDimPermInv);
5830 if (!outerDimPermInv.empty())
5831 applyPermutationToVector(outerDims, outerDimPermInv);
5832
5833  // Collect the outer dims corresponding to the tiled inner dims.
5834 for (auto index : innerDimsPos)
5835 res.push_back(outerDims[index]);
5836
5837 return res;
5838}
5839
5840LogicalResult UnPackOp::verify() {
5841 return commonVerifierPackAndUnPackOp(*this);
5842}
5843
5844Speculation::Speculatability UnPackOp::getSpeculatability() {
5845 // See PackOp::getSpeculatability.
5846  if (!areTilesAndTiledDimsAllConstant(*this))
5847    return Speculation::NotSpeculatable;
5848
5849  return Speculation::Speculatable;
5850}
5851
5852void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5853 Value dest, ArrayRef<int64_t> innerDimsPos,
5854 ArrayRef<OpFoldResult> innerTiles,
5855 ArrayRef<int64_t> outerDimsPerm) {
5856 assert(innerDimsPos.size() == innerTiles.size() &&
5857 "number of tile sizes specified must match the specified number of "
5858 "original dimensions to be tiled");
5859 SmallVector<int64_t> staticTileSizes;
5860 SmallVector<Value> dynamicTileSizes;
5861 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5862 build(builder, state, dest.getType(), source, dest,
5863 outerDimsPerm.empty() ? nullptr
5864 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5865 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5866 builder.getDenseI64ArrayAttr(staticTileSizes));
5867}
5868
5869Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5870 Value source,
5871 ArrayRef<OpFoldResult> innerTileSizes,
5872 ArrayRef<int64_t> innerDimsPos,
5873 ArrayRef<int64_t> outerDimsPerm) {
5874 AffineExpr sym0, sym1;
5875 bindSymbols(b.getContext(), sym0, sym1);
5876 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5877 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5878 };
5879
5880 SmallVector<OpFoldResult> mixedSizes;
5881 auto srcType = llvm::cast<RankedTensorType>(source.getType());
5882 for (auto i :
5883 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5884 if (srcType.isDynamicDim(i))
5885 mixedSizes.push_back(
5886 tensor::DimOp::create(b, loc, source, i).getResult());
5887 else
5888 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5889 }
5890 if (!outerDimsPerm.empty()) {
5891    applyPermutationToVector<OpFoldResult>(
5892        mixedSizes, invertPermutationVector(outerDimsPerm));
5893 }
5894
5895 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5896 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5897
5898 auto elemType = srcType.getElementType();
5899 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5900}
5901
5902UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5903 Value transposedSource,
5904 ArrayRef<int64_t> innerPermutation,
5905 ArrayRef<int64_t> outerPermutation) {
5906 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5907 *this, innerPermutation, outerPermutation);
5908 return UnPackOp::create(b, loc, transposedSource, getDest(),
5909 metadata.innerDimsPos, metadata.innerTiles,
5910 metadata.outerDimsPerm);
5911}
5912
5913/// Returns true if the `srcShape` or `destShape` is different from the one in
5914/// `op` and populates each with the inferred static shape.
5915static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5916 SmallVectorImpl<int64_t> &destShape) {
5917 bool changeNeeded = false;
5918 srcShape.assign(op.getSourceType().getShape().begin(),
5919 op.getSourceType().getShape().end());
5920 destShape.assign(op.getDestType().getShape().begin(),
5921 op.getDestType().getShape().end());
5922 llvm::SmallSetVector<int64_t, 4> innerDims;
5923 innerDims.insert_range(op.getInnerDimsPos());
5924 SmallVector<int64_t> inverseOuterDimsPerm;
5925 if (!op.getOuterDimsPerm().empty())
5926 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5927 int destRank = op.getDestRank();
5928 for (auto i : llvm::seq<int64_t>(0, destRank)) {
5929 if (innerDims.contains(i))
5930 continue;
5931 int64_t srcPos = i;
5932 int64_t destPos = i;
5933 if (!inverseOuterDimsPerm.empty())
5934 srcPos = inverseOuterDimsPerm[destPos];
5935 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5936 ShapedType::isDynamic(destShape[destPos])) {
5937 continue;
5938 }
5939 int64_t size = srcShape[srcPos];
5940 if (ShapedType::isDynamic(size))
5941 size = destShape[destPos];
5942 srcShape[srcPos] = size;
5943 destShape[destPos] = size;
5944 changeNeeded = true;
5945 }
5946 return changeNeeded;
5947}
5948
5949LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5950 PatternRewriter &rewriter) {
5951 /// unpack(pack(x)) -> x
5952 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5953 if (packOp.getSourceType() != unPackOp.getDestType())
5954 return failure();
5955 if (packOp.getPaddingValue() ||
5956 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5957 !haveSameTiles(packOp, unPackOp))
5958 return failure();
5959 rewriter.replaceOp(unPackOp, packOp.getSource());
5960 return success();
5961 }
5962 /// unpack(destinationStyleOp(x)) -> unpack(x)
5963 if (auto dstStyleOp =
5964 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5965 auto destValue = cast<OpResult>(unPackOp.getDest());
5966 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5967 rewriter.modifyOpInPlace(unPackOp,
5968 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5969 return success();
5970 }
5971 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5972 if (unPackOp->hasOneUse()) {
5973 auto extractSliceUser =
5974 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5975 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5976 OpBuilder::InsertionGuard g(rewriter);
5977 rewriter.setInsertionPoint(unPackOp);
5978 auto newDest = tensor::ExtractSliceOp::create(
5979 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5980 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5981 extractSliceUser.getMixedStrides());
5982 rewriter.modifyOpInPlace(unPackOp, [&]() {
5983 unPackOp.setDpsInitOperand(0, newDest);
5984 unPackOp.getResult().setType(newDest.getType());
5985 });
5986 rewriter.replaceOp(extractSliceUser, unPackOp);
5987 return success();
5988 }
5989 }
5990
5991 // Insert tensor.cast ops if static shape inference is available.
5992 SmallVector<int64_t> srcShape, destShape;
5993 if (inferStaticShape(unPackOp, srcShape, destShape)) {
5994 Location loc = unPackOp.getLoc();
5995 Value source = unPackOp.getSource();
5996 if (srcShape != unPackOp.getSourceType().getShape()) {
5997 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5998 source = tensor::CastOp::create(rewriter, loc, newSrcType,
5999 unPackOp.getSource());
6000 }
6001 Value dest = unPackOp.getDest();
6002 if (destShape != unPackOp.getDestType().getShape()) {
6003 auto newDestType = unPackOp.getDestType().clone(destShape);
6004 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6005 unPackOp.getDest());
6006 }
6007 Value newOp = UnPackOp::create(
6008 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6009 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6010 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6011 unPackOp, unPackOp.getResult().getType(), newOp);
6012 return success();
6013 }
6014
6015 return failure();
6016}
6017
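/// A slice folds into the unpack only when it behaves like dropping padding:
/// zero offsets, unit strides, no rank reduction, and, per tiled dim, a
/// sliced-off amount strictly smaller than one tile. A hypothetical example:
/// unpacking tensor<2x8xf32> (inner tile 8) into tensor<16xf32> and then
/// extracting tensor<13xf32> folds, since 2 * 8 - 13 = 3 < 8.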
6018bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6019 // Rank-reduced folding is not supported.
6020 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6021 return false;
6022 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6023 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6024 return false;
6025 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6026 SmallVector<int64_t> outerShapeWithoutTranspose =
6027 getPackedOuterShapeWithoutTransposition(*this);
6028 for (auto [pos, tileSize] :
6029 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6030 if (unpackedTypeAfterFold.isDynamicDim(pos))
6031 return false;
6032 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6033 return false;
6034 if (ShapedType::isDynamic(tileSize))
6035 return false;
6036 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6037 unpackedTypeAfterFold.getDimSize(pos);
6038 if (paddingSize >= tileSize)
6039 return false;
6040 }
6041 return true;
6042}
6043
6044bool UnPackOp::isLikeUnPad() {
6045 RankedTensorType packedTensorType = getSourceType();
6046 return isLikePadUnPad(*this, packedTensorType);
6047}
6048
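// Note: reshapeConstantSource is expected to fire only for a splat constant
// source with a fully static result type, resizing the splat to the
// unpacked type.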
6049OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
6050 if (OpFoldResult reshapedSource = reshapeConstantSource(
6051 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6052 getResult().getType()))
6053 return reshapedSource;
6054 return {};
6055}
6056
6057/// Folds a tensor.cast op into a consuming UnPackOp if the `tensor.cast`
6058/// has a source that is more static than the consuming op.
6059///
6060/// Example:
6061/// ```mlir
6062/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6063/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6064/// ```
6065///
6066/// folds into:
6067///
6068/// ```mlir
6069/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6070/// ```
6071struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6072 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6073
6074 LogicalResult matchAndRewrite(UnPackOp op,
6075 PatternRewriter &rewriter) const override {
6076 if (!tensor::hasFoldableTensorCastOperand(op))
6077 return failure();
6078
6079 SmallVector<Type> newResultTypes(op->getResultTypes());
6080 SmallVector<Value> newOperands =
6081 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
6082 Value sourceTensor = newOperands[0];
6083
6084 // Get the updated mixed-tile-sizes attribute.
6085 SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
6086 rewriter, sourceTensor.getType(), op.getMixedTiles());
6087
6088 // Clone op.
6089 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6090 // this point. However, in practice, we use them for things that we'd like
6091 // to preserve. Implement a better abstraction.
6092 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6093 newOperands[1], op.getInnerDimsPos(),
6094 newMixedTileSizes, op.getOuterDimsPerm());
6095 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6096
6097 // Replace op.
6098 Value oldResult = op.getResult();
6099 Value newResult = newOp.getResult();
6100 Value replacement =
6101 (newResult.getType() != oldResult.getType())
6102 ? tensor::CastOp::create(rewriter, op->getLoc(),
6103 oldResult.getType(), newResult)
6104 : newResult;
6105
6106 rewriter.replaceOp(op, {replacement});
6107
6108 return success();
6109 }
6110};
6111
6112//===----------------------------------------------------------------------===//
6113// BatchReduceMatmulOp
6114//===----------------------------------------------------------------------===//
6115SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6116 return SmallVector<utils::IteratorType>{
6117 utils::IteratorType::reduction, utils::IteratorType::parallel,
6118 utils::IteratorType::parallel, utils::IteratorType::reduction};
6119}
6120
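// The default maps built below correspond to:
//   LHS: (d0, d1, d2, d3) -> (d0, d1, d3)   (batch, m, k)
//   RHS: (d0, d1, d2, d3) -> (d0, d3, d2)   (batch, k, n)
//   Out: (d0, d1, d2, d3) -> (d1, d2)       (m, n)
// i.e. both the batch (d0) and contraction (d3) dims reduce into the output.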
6121SmallVector<AffineMap>
6122BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6123 AffineExpr d0, d1, d2, d3;
6124 SmallVector<AffineMap> indexingMaps;
6125 bindDims(context, d0, d1, d2, d3);
6126 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6127 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6128 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6129 return indexingMaps;
6130}
6131
6132bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6133 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6134 if (!maps)
6135 return false;
6136 if (maps.size() != 3)
6137 return false;
6138 auto positions = getAffineResultPositions(maps);
6139 if (failed(positions))
6140 return false;
6141 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6142 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6143 (*positions)[2] == SmallVector<int64_t>{1, 2};
6144}
6145unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6146
6147std::string BatchReduceMatmulOp::getLibraryCallName() {
6148 return generateLibraryCallName(getOperation());
6149}
6150
6151/// Check if the op has broadcast and/or transpose semantics. Returns true if
6152/// the user-defined indexing maps are not equal to the default maps.
6153bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6154 SmallVector<AffineMap, 3> defaultMaps =
6155 getDefaultIndexingMaps(this->getContext());
6156 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6157 return defaultMaps != explicitMaps;
6158}
6159
6160/// Returns true if the given bcastMap is a valid broadcast map. A valid
6161/// broadcast map must include the K dimension.
6162/// TODO: Requiring the K dimension in both input maps simultaneously is
6163/// stricter than necessary. The condition can be relaxed to require K in
6164/// only one input map and to infer it for the other map from the one that
6165/// already carries it.
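/// A hypothetical example: an LHS map (d0, d1, d2, d3) -> (d3) is accepted,
/// since it keeps K (= d3) and broadcasts along batch and M, whereas
/// (d0, d1, d2, d3) -> (d1) is rejected because K is missing.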
6167bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6168 bool isLHS) {
6169 assert(bcastMap.getNumResults() < 3 &&
6170 "Expected less than 3 result dim expr.");
6171 bool isValid = false;
6172 enum Indices { batchPos, mPos, nPos, kPos };
6173 if (bcastMap.getNumResults() == 1) {
6174 AffineExpr expr = bcastMap.getResult(0);
6175 isValid = expr.isFunctionOfDim(kPos);
6176 } else if (bcastMap.getNumResults() == 2) {
6177 AffineExpr expr0 = bcastMap.getResult(0);
6178 AffineExpr expr1 = bcastMap.getResult(1);
6179 isValid =
6180 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6181 expr0.isFunctionOfDim(mPos)) &&
6182 expr1.isFunctionOfDim(kPos))
6183 : ((expr0.isFunctionOfDim(batchPos) &&
6184 expr1.isFunctionOfDim(kPos)) ||
6185 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6186 }
6187 return isValid;
6188}
6189
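// The payload emitted below computes, in scalar form,
//   acc = acc + cast_signed(a) * cast_signed(b)
// which together with the default indexing maps amounts to
//   C[m, n] += A[b, m, k] * B[b, k, n], reduced over b and k.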
6190void BatchReduceMatmulOp::regionBuilder(
6191 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6192 function_ref<InFlightDiagnostic()> emitError) {
6193 if (emitError && block.getNumArguments() != 3) {
6194 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6195 << block.getNumArguments();
6196 return;
6197 }
6198 assert(block.getNumArguments() == 3 &&
6199 "BatchReduceMatmulOp regionBuilder expects 3 args");
6200 RegionBuilderHelper helper(b, block);
6201 SmallVector<Value> yields;
6202
6203 auto toType = block.getArgument(2).getType();
6204 Value castValA =
6205 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6206 Value castValB =
6207 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6208 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6209 Value addVal =
6210 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6211 yields.push_back(addVal);
6212 helper.yieldOutputs(yields);
6213}
6214
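// A hypothetical round-trip of the custom assembly (the maps shown are the
// defaults, which the printer elides):
//
//   linalg.batch_reduce_matmul
//       indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
//                        affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
//                        affine_map<(d0, d1, d2, d3) -> (d1, d2)>]
//       ins(%A, %B : tensor<2x3x5xf32>, tensor<2x5x7xf32>)
//       outs(%C : tensor<3x7xf32>) -> tensor<3x7xf32>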
6215ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6216 OperationState &result) {
6217 SmallVector<Attribute, 3> indexingMapsAttr;
6218 Attribute mapAttr;
6219 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6220 if (parser.parseEqual())
6221 return failure();
6222 if (parser.parseLSquare())
6223 return failure();
6224
6225 do {
6226 if (parser.parseAttribute(mapAttr))
6227 return failure();
6228 if (!isa<AffineMapAttr>(mapAttr)) {
6229 return parser.emitError(parser.getCurrentLocation(),
6230 "expected affine map attribute");
6231 }
6232 indexingMapsAttr.push_back(mapAttr);
6233
6234 if (parser.parseOptionalComma())
6235 break;
6236 } while (true);
6237
6238 if (parser.parseRSquare())
6239 return failure();
6240 }
6241 // Initialize indexingMaps if they were not supplied explicitly.
6242 if (indexingMapsAttr.empty()) {
6243 indexingMapsAttr = llvm::map_to_vector(
6244 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6245 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6246 }
6247 result.addAttribute("indexing_maps",
6248 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6249 return ::parseNamedStructuredOp(parser, result,
6250 BatchReduceMatmulOp::getNumRegionArgs(),
6251 BatchReduceMatmulOp::getRegionBuilder());
6252}
6253
6254void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6255 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6256 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6257 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6258
6259 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6260 p << " indexing_maps = [";
6261 llvm::interleaveComma(getIndexingMaps(), p,
6262 [&](Attribute attr) { p.printAttribute(attr); });
6263 p << "]";
6264 }
6265
6266 SmallVector<StringRef, 3> elidedAttrs = {
6267 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6268 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6269 elidedAttrs);
6270}
6271
6272/// Verify the user-defined indexing maps.
6273LogicalResult BatchReduceMatmulOp::verify() {
6274 // Verification of pure batch_reduce_matmul is handled by
6275 // verifyStructuredOpInterface().
6276 if (!hasUserDefinedMaps())
6277 return success();
6278
6279 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6280 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
6281 return failure();
6282 }
6283 return success();
6284}
6285LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6286 SmallVectorImpl<OpFoldResult> &) {
6287 return memref::foldMemRefCast(*this);
6288}
6289void BatchReduceMatmulOp::getEffects(
6290 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6291 &effects) {
6292 if (hasPureTensorSemantics())
6293 return;
6294 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6295}
6296
6297Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6298 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6299}
6300
6301} // namespace linalg
6302} // namespace mlir
6303
6304//===----------------------------------------------------------------------===//
6305// LinalgDialect
6306//===----------------------------------------------------------------------===//
6307
6308void LinalgDialect::getCanonicalizationPatterns(
6309 RewritePatternSet &results) const {
6310 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6311 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6312}
6313
6314Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6315 Attribute value, Type type,
6316 Location loc) {
6317 return arith::ConstantOp::materialize(builder, value, type, loc);
6318}