MLIR 22.0.0git
LinalgOps.cpp
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}
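
// Illustrative example (added for exposition, not part of the upstream file):
// for a value `%v` of type tensor<4x?xf32>, querying dimension 0 folds to the
// constant index attribute 4, while dimension 1 materializes IR such as
//   %d1 = tensor.dim %v, %c1 : tensor<4x?xf32>
// and returns %d1 wrapped as an OpFoldResult.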

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
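
// Illustrative example (added for exposition, not part of the upstream file):
// with a ranked tensor `source`, this produces
//   %s = tensor.extract_slice %source[%off][%sz][1]
//        : tensor<?xf32> to tensor<?xf32>
// whereas a memref source yields the analogous memref.subview. Any other type
// hits the Default case and returns nullptr.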

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}
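
// Illustrative sketch (added for exposition, not part of the upstream file):
// an ods-generated builder for a named op is essentially a thin wrapper over
// the helper above, roughly
//   void SomeNamedOp::build(OpBuilder &b, OperationState &state, ...) {
//     buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
//                       attributes, SomeNamedOp::getRegionBuilder());
//   }
// where `SomeNamedOp` stands in for any yaml-generated structured op.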

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible with
    // operation syntax that mixes the inherent attributes and the discardable
    // ones in the same dictionary. If the properties are used, we append the
    // operandSegmentSizes there directly. Otherwise we append it to the
    // discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
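
// Illustrative example of the textual form consumed here (added for
// exposition, not part of the upstream file):
//   linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
//                 outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
// This function parses the optional properties, the attribute dictionary, and
// the `ins`/`outs` clauses; the trailing `-> tensor<4x16xf32>` result list is
// handled separately by parseNamedStructuredOpResults below.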

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers that build the unary, binary, and type conversion functions defined
// by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that
// uses this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }
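
  // Illustrative example (added for exposition, not part of the upstream
  // file): for an f32 block argument %arg,
  //   buildUnaryFn(UnaryFn::exp, %arg)        emits math.exp %arg : f32
  //   buildUnaryFn(UnaryFn::reciprocal, %arg) emits an arith.constant 1.0
  //                                           followed by arith.divf.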

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is not
  // supported and a nullptr will be returned; otherwise an assertion will be
  // raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }
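
  // Illustrative example (added for exposition, not part of the upstream
  // file): BinaryFn::add dispatches on the operand types:
  //   complex : complex.add %a, %b : complex<f32>
  //   float   : arith.addf  %a, %b : f32
  //   i1      : arith.ori   %a, %b : i1
  //   integer : arith.addi  %a, %b : i32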

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
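
// Illustrative example (added for exposition, not part of the upstream file):
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// copies a buffer onto itself, so the pattern erases it outright. With tensor
// semantics, the op's results are instead replaced by its inputs.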

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
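
// Illustrative example (added for exposition, not part of the upstream file):
//   %f = linalg.fill ins(%c0 : f32) outs(%t : tensor<4xf32>) -> tensor<4xf32>
//   %p = tensor.pad %f low[2] high[2] { ... tensor.yield %c0 : f32 }
//        : tensor<4xf32> to tensor<8xf32>
// rewrites to a single linalg.fill of %c0 into a tensor.empty of the padded
// 8-element shape, since padding with the fill value is itself a fill.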

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
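
// Illustrative example (added for exposition, not part of the upstream file):
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<4x4xf32>) -> tensor<4x4xf32>
//   %e = tensor.extract %f[%i, %j] : tensor<4x4xf32>
// folds so that all uses of %e are replaced directly by %cst.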

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
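
// Illustrative example (added for exposition, not part of the upstream file):
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<8xf32>) -> tensor<8xf32>
//   %c = linalg.copy ins(%f : tensor<8xf32>) outs(%d : tensor<8xf32>) -> tensor<8xf32>
// becomes a fill of %cst directly into %d, eliding the copy. In the symmetric
// case, copying into a filled destination bypasses the fill and copies into
// the fill's init instead.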

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
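
// Illustrative example (added for exposition, not part of the upstream file):
// concatenating two fills of the same %cst
//   %a = linalg.fill ins(%cst : f32) outs(%t0 : tensor<4xf32>) -> tensor<4xf32>
//   %b = linalg.fill ins(%cst : f32) outs(%t1 : tensor<4xf32>) -> tensor<4xf32>
//   %c = tensor.concat dim(0) %a, %b
//        : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
// becomes a single fill of %cst over the concatenation of %t0 and %t1.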

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
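
// Illustrative example of the printed form (added for exposition, not part of
// the upstream file):
//   linalg.generic {indexing_maps = [#map, #map],
//                   iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
// The leading dictionary carries the core linalg traits; an `attrs = {...}`
// clause would follow the outs clause if extra attributes were present.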

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits, which must be collected into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert an array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::Speculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
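
// Illustrative example (added for exposition, not part of the upstream file):
// an all-parallel generic with identical identity maps whose body is only
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
// is an identity on its input, so the op's result is replaced by the input
// (via tensor.cast or sparse_tensor.convert when the result type differs).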

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{init}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false, bool mapInit = true) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    if (mapInit)
      payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end() - int(!mapInit)};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  YieldOp::create(b, result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs, ArrayRef(result.operands), false,
                           false);
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1575
1576static bool canUseShortForm(Block *body, bool initFirst = false,
1577 bool mapInit = true) {
1578 // `initFirst == true` implies that the init arg is mapped.
1579 if (initFirst && !mapInit)
1580 return false;
1581 // Check if the body can be printed in short form. The following 4 conditions
1582 // must be satisfied:
1583
1584 // 1) The body must contain exactly 2 operations: the payload op and a yield.
1585 if (body->getOperations().size() != 2)
1586 return false;
1587 Operation &payload = body->getOperations().front();
1588
1589 // 2) The payload op must have the same number of operands as the number of
1590 // block arguments.
1591 if (payload.getNumOperands() == 0 ||
1592 payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
1593 return false;
1594
1595 // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1596 // argument must be the first operand of the payload op; otherwise, the
1597 // operands must match the block arguments in order.
1598 if (initFirst) {
1599 // check init
1600 if (payload.getOperands().back() != body->getArgument(0))
1601 return false;
1602 // check rest
1603 for (const auto &[operand, bbArg] :
1604 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1605 if (bbArg != operand)
1606 return false;
1607 }
1608 } else {
1609 for (const auto &[operand, bbArg] :
1610 llvm::zip(payload.getOperands(),
1611 body->getArguments().drop_back(int(!mapInit)))) {
1612 if (bbArg != operand)
1613 return false;
1614 }
1615 }
1616
1617 // 4) The `yield` operand must be the result of the payload op.
1618 auto yieldOp = cast<YieldOp>(body->getTerminator());
1619 return yieldOp.getNumOperands() == 1 &&
1620 yieldOp.getOperand(0).getDefiningOp() &&
1621 yieldOp.getOperand(0).getDefiningOp() == &payload;
1622}
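// Illustrative example (added for exposition, not part of the upstream
// source): when all four conditions hold, a map such as
//
//   %mapped = linalg.map { arith.addf } ins(%lhs, %rhs : tensor<8xf32>,
//       tensor<8xf32>) outs(%init : tensor<8xf32>)
//
// round-trips through the short form; otherwise the explicit region with
// block arguments and a terminating linalg.yield is printed instead.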
1623
1624static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1625 SmallVector<StringRef> elidedAttrs;
1626 std::string attrToElide;
1627 p << " { " << payloadOp->getName().getStringRef();
1628 for (const auto &attr : payloadOp->getAttrs()) {
1629 auto fastAttr =
1630 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1631 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1632 attrToElide = attr.getName().str();
1633 elidedAttrs.push_back(attrToElide);
1634 break;
1635 }
1636 }
1637 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1638 p << " }";
1639}
1640
1641void MapOp::print(OpAsmPrinter &p) {
1642 Block *mapper = getBody();
1643 bool useShortForm =
1644 canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
1645 if (useShortForm) {
1646 printShortForm(p, &mapper->getOperations().front());
1647 }
1648
1649 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1650 p.printOptionalAttrDict((*this)->getAttrs());
1651
1652 if (!useShortForm) {
1653 // Print region if the payload op was not detected.
1654 p.increaseIndent();
1655 p.printNewline();
1656 p << "(";
1657 llvm::interleaveComma(mapper->getArguments(), p,
1658 [&](auto arg) { p.printRegionArgument(arg); });
1659 p << ") ";
1660
1661 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1662 p.decreaseIndent();
1663 }
1664}
1665
1666LogicalResult MapOp::verify() {
1667 auto *bodyBlock = getBody();
1668 auto blockArgs = bodyBlock->getArguments();
1669
1670 // Check that the number of `inputs` + `init` matches the arity of the
1671 // `mapper` region.
1672 if (getInputs().size() + 1 != blockArgs.size())
1673 return emitOpError() << "expects number of operands to match the arity of "
1674 "mapper, but got: "
1675 << getInputs().size() + 1 << " and "
1676 << blockArgs.size();
1677
1678 // The parameters of mapper should all match the element type of inputs.
1679 for (const auto &[bbArgType, inputArg] :
1680 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1681 auto inputElemType =
1682 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1683 if (bbArgType != inputElemType) {
1684 return emitOpError() << "expected element type of input " << inputElemType
1685 << " to match bbArg type " << bbArgType;
1686 }
1687 }
1688
1689 // The shape of each input must match the shape of the output.
1690 auto outputShape = getInit().getType().getShape();
1691 for (Type inputArgType : TypeRange{getInputs()}) {
1692 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1693 if (inputElemShape != outputShape) {
1694 return emitOpError() << "expected shape of input (" << inputElemShape
1695 << ") to match shape of output (" << outputShape
1696 << ")";
1697 }
1698 }
1699
1700 return success();
1701}
1702
1703SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1704 int64_t rank = getInit().getType().getRank();
1705 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1706}
1707
1708ArrayAttr MapOp::getIndexingMaps() {
1709 Builder builder(getContext());
1710 int64_t rank = getInit().getType().getRank();
1711 int64_t numIndexingMaps = getOperands().size();
1712 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1713 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1714}
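// Illustrative example (added for exposition): a rank-2 map with two inputs
// and one init yields three copies of the identity map
// (d0, d1) -> (d0, d1), since every operand is accessed elementwise.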
1715
1716void MapOp::getEffects(
1717 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1718 &effects) {
1719 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1720}
1721
1722Speculation::Speculatability MapOp::getSpeculatability() {
1723 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1724}
1725
1726//===----------------------------------------------------------------------===//
1727// ReduceOp
1728//===----------------------------------------------------------------------===//
1729
1730void ReduceOp::getAsmBlockArgumentNames(Region &region,
1731 OpAsmSetValueNameFn setNameFn) {
1732 for (Value v : getRegionInputArgs())
1733 setNameFn(v, "in");
1734 for (Value v : getRegionOutputArgs())
1735 setNameFn(v, "init");
1736}
1737
1738void ReduceOp::getAsmResultNames(
1739 function_ref<void(Value, StringRef)> setNameFn) {
1740 if (!getResults().empty())
1741 setNameFn(getResults().front(), "reduced");
1742}
1743
1744void ReduceOp::build(
1745 OpBuilder &builder, OperationState &result, ValueRange inputs,
1746 ValueRange inits, ArrayRef<int64_t> dimensions,
1747 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1748 ArrayRef<NamedAttribute> attributes) {
1749 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1750 result.addAttributes(attributes);
1751
1752 // Add output types for `RankedTensorType` output arguments.
1753 for (Value init : inits) {
1754 Type initType = init.getType();
1755 if (llvm::isa<RankedTensorType>(initType))
1756 result.addTypes(initType);
1757 }
1758
1759 if (bodyBuild)
1760 buildGenericRegion(builder, result.location, *result.regions.front(),
1761 inputs, inits, bodyBuild);
1762}
1763
1764SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1765 int64_t inputRank =
1766 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1767 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1768 utils::IteratorType::parallel);
1769 for (int64_t reductionDim : getDimensions())
1770 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1771 return iteratorTypes;
1772}
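// Illustrative example (added for exposition): a rank-3 input reduced over
// dimensions = [1] yields the iterator types
// [parallel, reduction, parallel].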
1773
1774ArrayAttr ReduceOp::getIndexingMaps() {
1775 int64_t inputRank =
1776 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1777 SmallVector<AffineMap> affineMaps(
1778 getNumDpsInputs(),
1779 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1780 AffineMap resultMap =
1781 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1782 .dropResults(getDimensions());
1783 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1784 affineMaps.push_back(resultMap);
1785 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1786}
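// Illustrative example (added for exposition): for a rank-3 input with
// dimensions = [1], each input keeps the identity map
// (d0, d1, d2) -> (d0, d1, d2) while each init gets the reduced map
// (d0, d1, d2) -> (d0, d2).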
1787
1788void ReduceOp::getEffects(
1789 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1790 &effects) {
1791 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1792}
1793
1794Speculation::Speculatability ReduceOp::getSpeculatability() {
1795 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1796}
1797
1798static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1799 NamedAttrList &attributes,
1800 StringRef attributeName) {
1801 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1802 return failure();
1803
1804 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1805 return success();
1806}
1807
1808ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1809 std::optional<OperationName> payloadOpName;
1810 NamedAttrList payloadOpAttrs;
1811 if (succeeded(parser.parseOptionalLBrace())) {
1812 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1813 if (failed(operationName))
1814 return failure();
1815 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1816 return failure();
1817 payloadOpName = operationName.value();
1818 if (parser.parseRBrace())
1819 return failure();
1820 }
1821
1822 if (parseDstStyleOp(
1823 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1824 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1825 }))
1826 return failure();
1827
1828 if (payloadOpName.has_value()) {
1829 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1830 ArrayRef(result.operands), /*initFirst=*/true);
1831 } else {
1832 SmallVector<OpAsmParser::Argument> regionArgs;
1833 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1834 /*allowType=*/true, /*allowAttrs=*/true)) {
1835 return failure();
1836 }
1837
1838 Region *body = result.addRegion();
1839 if (parser.parseRegion(*body, regionArgs))
1840 return failure();
1841 }
1842
1843 return success();
1844}
1845
1846static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1847 ArrayRef<int64_t> attributeValue) {
1848 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1849}
1850
1851void ReduceOp::print(OpAsmPrinter &p) {
1852 Block *mapper = getBody();
1853 bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1854 if (useShortForm) {
1855 printShortForm(p, &mapper->getOperations().front());
1856 }
1857
1858 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1859 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1860 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1861 if (!useShortForm) {
1862 // Print region if the payload op was not detected.
1863 p.increaseIndent();
1864 p.printNewline();
1865 p << "(";
1866 llvm::interleaveComma(mapper->getArguments(), p,
1867 [&](auto arg) { p.printRegionArgument(arg); });
1868 p << ") ";
1869
1870 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1871 p.decreaseIndent();
1872 }
1873}
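// Illustrative example (added for exposition): a reduction whose body is a
// single arith.addf prints in the short form, e.g.
//
//   %reduced = linalg.reduce { arith.addf } ins(%in : tensor<16x32x64xf32>)
//       outs(%init : tensor<16x64xf32>) dimensions = [1]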
1874
1875LogicalResult ReduceOp::verify() {
1876 ArrayRef<int64_t> dimensionsRef = getDimensions();
1877
1878 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1879 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1880 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1881 return emitOpError() << "expects all inputs to have the same shapes. "
1882 "Shape at input-index "
1883 << i
1884 << " is not equal to the shape at input-index 0.";
1885 }
1886 }
1887 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1888 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1889 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1890 return emitOpError() << "expects all outputs to have the same shapes. "
1891 "Shape at output-index "
1892 << i
1893 << " is not equal to the shape at output-index 0.";
1894 }
1895 }
1896 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1897 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1898
1899 DenseSet<int64_t> dimensionsToReduce;
1900 for (int64_t dimension : dimensionsRef) {
1901 if (dimension < 0 || dimension >= inputType.getRank()) {
1902 return emitOpError()
1903 << "dimensions for reduction should be in the range [0, "
1904 << inputType.getRank() - 1 << "].";
1905 }
1906 dimensionsToReduce.insert(dimension);
1907 }
1908
1909 auto inputDims = inputType.getShape();
1910 auto initDims = initType.getShape();
1911
1912 // Input dimensions that will be left after the reduction.
1913 SmallVector<int64_t> reducedInputDims;
1914 for (const auto &en : llvm::enumerate(inputDims)) {
1915 if (!dimensionsToReduce.count(en.index()))
1916 reducedInputDims.push_back(en.value());
1917 }
1918
1919 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1920 return emitOpError() << "number of dimensions after reduction "
1921 << reducedInputDims.size()
1922 << " doesn't match the init rank "
1923 << initType.getRank();
1924 }
1925
1926 if (reducedInputDims != initDims)
1927 return emitOpError() << "init dimensions [" << initDims
1928 << "] doesn't match input dimensions after reduction ["
1929 << reducedInputDims << "]";
1930
1931 Block *block = getBody();
1932 if (block->getNumArguments() != this->getNumOperands())
1933 return emitOpError()
1934 << "mismatching number of operands and block arguments";
1935
1936 // Check that the first block arguments match the element type of the inputs.
1937 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1938 Type inputElementType =
1939 llvm::cast<ShapedType>(input.getType()).getElementType();
1940 if (inputElementType != bbArg.getType())
1941 return emitOpError()
1942 << "input element type " << inputElementType
1943 << " does not match corresponding block argument type "
1944 << bbArg.getType();
1945 }
1946
1947 // Check that the last block arguments match the element type of the outputs.
1948 for (auto [output, bbArg] : llvm::zip(
1949 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1950 auto outputElementType =
1951 llvm::cast<ShapedType>(output.getType()).getElementType();
1952 if (outputElementType != bbArg.getType())
1953 return emitOpError()
1954 << "output element type " << outputElementType
1955 << " does not match corresponding block argument type "
1956 << bbArg.getType();
1957 }
1958 return success();
1959}
1960
1961//===----------------------------------------------------------------------===//
1962// TransposeOp
1963//===----------------------------------------------------------------------===//
1964
1965static void buildIdentityRegion(OpBuilder &builder, Location loc,
1966 Region &region, ValueRange inputs,
1967 ValueRange outputs) {
1968 buildGenericRegion(builder, loc, region, inputs, outputs,
1969 [](OpBuilder &b, Location loc, ValueRange args) {
1970 if (!args.empty())
1971 linalg::YieldOp::create(b, loc, args[0]);
1972 });
1973}
1974
1975void TransposeOp::build(::mlir::OpBuilder &builder,
1976 ::mlir::OperationState &result, Value input, Value init,
1977 DenseI64ArrayAttr permutation,
1978 ArrayRef<NamedAttribute> attributes) {
1979 result.addOperands(input);
1980 result.addOperands(init);
1981 result.addAttribute(getPermutationAttrName(result.name), permutation);
1982 result.addAttributes(attributes);
1983
1984 // Add output types for `RankedTensorType` output arguments.
1985 Type initType = init.getType();
1986 if (llvm::isa<RankedTensorType>(initType))
1987 result.addTypes(initType);
1988
1989 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1990 init);
1991}
1992
1993void TransposeOp::build(::mlir::OpBuilder &builder,
1994 ::mlir::OperationState &result, Value input, Value init,
1995 ArrayRef<int64_t> permutation,
1996 ArrayRef<NamedAttribute> attributes) {
1997 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1998 attributes);
1999}
2000
2001ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
2002 if (failed(parseDstStyleOp(
2003 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2004 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
2005 })))
2006 return failure();
2007
2008 OpBuilder builder(parser.getContext());
2009 buildIdentityRegion(builder, result.location, *result.addRegion(),
2010 /*inputs=*/result.operands,
2011 /*outputs=*/{});
2012 return success();
2013}
2014
2015void TransposeOp::getAsmResultNames(
2016 function_ref<void(Value, StringRef)> setNameFn) {
2017 if (!getResults().empty())
2018 setNameFn(getResults().front(), "transposed");
2019}
2020
2021void TransposeOp::print(OpAsmPrinter &p) {
2022 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2023 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2024 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2025}
2026
2027LogicalResult TransposeOp::verify() {
2028 ArrayRef<int64_t> permutationRef = getPermutation();
2029
2030 if (!isPermutationVector(permutationRef))
2031 return emitOpError("permutation is not valid");
2032
2033 auto inputType = getInput().getType();
2034 auto initType = getInit().getType();
2035
2036 int64_t rank = inputType.getRank();
2037
2038 if (rank != initType.getRank())
2039 return emitOpError() << "input rank " << rank
2040 << " does not match init rank " << initType.getRank();
2041
2042 if (rank != static_cast<int64_t>(permutationRef.size()))
2043 return emitOpError() << "size of permutation " << permutationRef.size()
2044 << " does not match the argument rank " << rank;
2045
2046 auto inputDims = inputType.getShape();
2047 auto initDims = initType.getShape();
2048
2049 for (int64_t i = 0; i < rank; ++i) {
2050 int64_t inputDim = inputDims[permutationRef[i]];
2051 int64_t initDim = initDims[i];
2052
2053 if (inputDim != initDim) {
2054 return emitOpError() << "dim(result, " << i << ") = " << initDim
2055 << " doesn't match dim(input, permutation[" << i
2056 << "]) = " << inputDim;
2057 }
2058 }
2059
2060 return success();
2061}
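// Illustrative example (added for exposition): with permutation = [1, 0, 2],
// an input of type tensor<2x4x8xf32> requires an init of type
// tensor<4x2x8xf32>, since dim(init, i) must equal dim(input, permutation[i]).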
2062
2063SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2064 int64_t rank = getInit().getType().getRank();
2065 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2066}
2067
2068ArrayAttr TransposeOp::getIndexingMaps() {
2069 Builder builder(getContext());
2070 int64_t rank = getInit().getType().getRank();
2071 return builder.getAffineMapArrayAttr(
2072 {inversePermutation(AffineMap::getPermutationMap(
2073 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2074 builder.getMultiDimIdentityMap(rank)});
2075}
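// Illustrative example (added for exposition): with permutation = [1, 0] the
// input map is the inverted permutation map (d0, d1) -> (d1, d0) and the
// init map is the identity (d0, d1) -> (d0, d1).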
2076
2077void TransposeOp::getEffects(
2078 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2079 &effects) {
2080 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2081}
2082
2083Speculation::Speculatability TransposeOp::getSpeculatability() {
2084 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2085}
2086
2087LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2088 SmallVectorImpl<OpFoldResult> &result) {
2089 // Only the tensor type is supported.
2090 if (!isa<TensorType>(getInput().getType()))
2091 return failure();
2092
2093 // Rank-0 transpose (empty permutation).
2094 if (getPermutation().size() == 0) {
2095 result.push_back(getInput());
2096 return success();
2097 }
2098 // Identity permutation.
2099 if (isIdentityPermutation(getPermutation())) {
2100 result.push_back(getInput());
2101 return success();
2102 }
2103
2104 return failure();
2105}
2106
2107/// Fold transpose with transpose.
2108struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2109 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2110
2111 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2112 PatternRewriter &rewriter) const override {
2113 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2114 if (!defTransposeOp)
2115 return failure();
2116 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2117 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2118 SmallVector<int64_t> foldedPerms;
2119 foldedPerms.reserve(perms.size());
2120 for (int64_t perm : perms)
2121 foldedPerms.push_back(defPerms[perm]);
2122
2123 rewriter.replaceOpWithNewOp<TransposeOp>(
2124 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2125 foldedPerms);
2126 return success();
2127 }
2128};
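// Illustrative example (added for exposition): a transpose with
// perms = [2, 0, 1] fed by a transpose with defPerms = [1, 2, 0] composes to
// foldedPerms = [defPerms[2], defPerms[0], defPerms[1]] = [0, 1, 2], an
// identity permutation that TransposeOp::fold then erases entirely.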
2129
2130/// This pattern canonicalizes a transpose by swapping the order of
2131/// broadcast and transpose:
2132/// transpose(broadcast(input)) -> broadcast(transpose(input))
2133struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2134 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2135
2136 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2137 PatternRewriter &rewriter) const override {
2138 Value input = transposeOp.getInput();
2139 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2140 if (!input.hasOneUse() || !broadcastOp)
2141 return failure();
2142
2143 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2144 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2145
2146 // Get new perms and new dimensions.
2147 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2148 SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2149 SmallVector<int64_t> resultDimensions;
2150 unsigned dimensionSize = dimensions.size();
2151 for (unsigned i = 0; i < dimensionSize; ++i)
2152 resultDimensions.push_back(invertPerm[dimensions[i]]);
2153
2154 // Create transpose result.
2155 Value broadcastInput = broadcastOp.getInput();
2156 Location loc = transposeOp.getLoc();
2157 MLIRContext *ctx = transposeOp.getContext();
2158 SmallVector<OpFoldResult> dims;
2159 auto broadcastInputTy =
2160 mlir::cast<RankedTensorType>(broadcastInput.getType());
2161 unsigned inputRank = broadcastInputTy.getRank();
2162 for (unsigned i = 0; i < inputRank; ++i) {
2163 if (broadcastInputTy.isDynamicDim(i)) {
2164 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2165 ->getResult(0));
2166 } else {
2167 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2168 broadcastInputTy.getDimSize(i)));
2169 }
2170 }
2171 SmallVector<OpFoldResult> transposeResultShapes =
2172 applyPermutation(dims, resultPerms);
2173 Value transposeInit = tensor::EmptyOp::create(
2174 rewriter, transposeOp.getLoc(), transposeResultShapes,
2175 broadcastInputTy.getElementType());
2176
2177 // Create broadcast(transpose(input)).
2178 Value transposeResult =
2179 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2180 transposeInit, resultPerms)
2181 ->getResult(0);
2182 rewriter.replaceOpWithNewOp<BroadcastOp>(
2183 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2184 return success();
2185 }
2186};
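// Illustrative example (added for exposition): broadcasting tensor<16xf32>
// to tensor<16x32xf32> with dimensions = [1] and then transposing with
// permutation = [1, 0] is rewritten into a rank-1 transpose of the original
// input (permutation [0]) followed by a broadcast straight to
// tensor<32x16xf32> with dimensions = [0].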
2187
2188void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2189 MLIRContext *context) {
2190 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2191}
2192
2193//===----------------------------------------------------------------------===//
2194// BroadcastOp
2195//===----------------------------------------------------------------------===//
2196
2197void BroadcastOp::build(::mlir::OpBuilder &builder,
2198 ::mlir::OperationState &result, Value input, Value init,
2199 DenseI64ArrayAttr dimensions,
2200 ArrayRef<NamedAttribute> attributes) {
2201 result.addOperands(input);
2202 result.addOperands(init);
2203 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2204 result.addAttributes(attributes);
2205
2206 // Add output types for `RankedTensorType` output arguments.
2207 Type initType = init.getType();
2208 if (llvm::isa<RankedTensorType>(initType))
2209 result.addTypes(initType);
2210
2211 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2212 init);
2213}
2214
2215void BroadcastOp::build(::mlir::OpBuilder &builder,
2216 ::mlir::OperationState &result, Value input, Value init,
2217 ArrayRef<int64_t> dimensions,
2218 ArrayRef<NamedAttribute> attributes) {
2219 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2220 attributes);
2221}
2222
2223ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2224 if (failed(parseDstStyleOp(
2225 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2226 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2227 })))
2228 return failure();
2229
2230 OpBuilder builder(parser.getContext());
2231 buildIdentityRegion(builder, result.location, *result.addRegion(),
2232 /*inputs=*/result.operands,
2233 /*outputs=*/{});
2234 return success();
2235}
2236
2237void BroadcastOp::getAsmResultNames(
2238 function_ref<void(Value, StringRef)> setNameFn) {
2239 if (!getResults().empty())
2240 setNameFn(getResults().front(), "broadcasted");
2241}
2242
2243void BroadcastOp::print(OpAsmPrinter &p) {
2244 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2245 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2246 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2247}
2248
2249LogicalResult BroadcastOp::verify() {
2250 ArrayRef<int64_t> dimensionsRef = getDimensions();
2251
2252 auto inputType = getInput().getType();
2253 auto initType = getInit().getType();
2254
2255 int64_t inputRank = inputType.getRank();
2256 int64_t initRank = initType.getRank();
2257
2258 auto inputShape = inputType.getShape();
2259 auto initShape = initType.getShape();
2260
2261 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2262 return emitOpError() << "input rank plus added dimensions does not "
2263 "match init rank. input rank: "
2264 << inputRank
2265 << ", dimensions size: " << dimensionsRef.size()
2266 << ", init rank: " << initRank;
2267
2268 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2269 if (dim < 0 || dim >= initRank)
2270 return emitOpError() << "dimension " << idx
2271 << " is out of range. expected range: [0, "
2272 << initRank - 1 << "], got: " << dim;
2273 }
2274
2275 // Mapping from input dims to init dims.
2276 SmallVector<int64_t> dimMap;
2277 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2278 if (!llvm::is_contained(dimensionsRef, dim))
2279 dimMap.push_back(dim);
2280 }
2281
2282 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2283 // This dimension is mapped from the input. Init and input dims should
2284 // match.
2285 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2286 return emitOpError() << "input dim " << inputDimIdx
2287 << " should match init dim " << initDimIdx
2288 << ". input: " << inputShape[inputDimIdx]
2289 << ", init: " << initShape[initDimIdx];
2290 }
2291
2292 return success();
2293}
2294
2295SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2296 int64_t rank = getInit().getType().getRank();
2297 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2298}
2299
2300ArrayAttr BroadcastOp::getIndexingMaps() {
2301 Builder builder(getContext());
2302 int64_t rank = getInit().getType().getRank();
2303 return builder.getAffineMapArrayAttr(
2304 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2305 builder.getMultiDimIdentityMap(rank)});
2306}
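// Illustrative example (added for exposition): broadcasting a rank-1 input
// into a rank-2 init with dimensions = [1] yields the maps
// [(d0, d1) -> (d0), (d0, d1) -> (d0, d1)] for the input and the init.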
2307
2308void BroadcastOp::getEffects(
2309 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2310 &effects) {
2311 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2312}
2313
2314Speculation::Speculatability BroadcastOp::getSpeculatability() {
2315 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2316}
2317
2318/// Fold back-to-back broadcasts together.
2319struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2320 using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2321
2322 LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2323 PatternRewriter &rewriter) const override {
2324 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2325 if (!defBroadcastOp)
2326 return failure();
2327 ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2328 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2329 SmallVector<int64_t> foldedDims(dimensions);
2330 Value init = broadcastOp.getInit();
2331 int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2332 // Mapping from input dims to init dims.
2333 SmallVector<int64_t> dimMap;
2334 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2335 if (!llvm::is_contained(dimensions, dim))
2336 dimMap.push_back(dim);
2337 }
2338 for (auto dim : defDimensions)
2339 foldedDims.push_back(dimMap[dim]);
2340
2341 llvm::sort(foldedDims);
2342 rewriter.replaceOpWithNewOp<BroadcastOp>(
2343 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2344 return success();
2345 }
2346};
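// Illustrative example (added for exposition): broadcasting tensor<16xf32>
// to tensor<16x32xf32> with dimensions = [1], and that result to
// tensor<8x16x32xf32> with dimensions = [0], folds into a single broadcast
// of the original input with dimensions = [0, 2].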
2347
2348void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2349 MLIRContext *context) {
2350 results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2351}
2352
2353//===----------------------------------------------------------------------===//
2354// YieldOp
2355//===----------------------------------------------------------------------===//
2356
2357void linalg::YieldOp::print(OpAsmPrinter &p) {
2358 if (getNumOperands() > 0)
2359 p << ' ' << getOperands();
2360 p.printOptionalAttrDict((*this)->getAttrs());
2361 if (getNumOperands() > 0)
2362 p << " : " << getOperandTypes();
2363}
2364
2365ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2366 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2367 SmallVector<Type, 2> types;
2368 SMLoc loc = parser.getCurrentLocation();
2369 return failure(parser.parseOperandList(opInfo) ||
2370 parser.parseOptionalAttrDict(result.attributes) ||
2371 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2372 parser.resolveOperands(opInfo, types, loc, result.operands));
2373}
2374
2375// Check that the operand number and types match the element types of the
2376// LinalgOp interface's shaped operands.
2377static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2378 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2379 return op.emitOpError("expected number of yield values (")
2380 << op.getNumOperands()
2381 << ") to match the number of inits / outs operands of the enclosing "
2382 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2383
2384 for (OpOperand &opOperand : op->getOpOperands()) {
2385 OpOperand *outputOperand =
2386 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2387 Type elementType = outputOperand->get().getType();
2388 if (isa<MemRefType, RankedTensorType>(elementType))
2389 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2390 if (opOperand.get().getType() != elementType)
2391 return op.emitOpError("type of yield operand ")
2392 << (opOperand.getOperandNumber() + 1) << " ("
2393 << opOperand.get().getType() << ") doesn't match "
2394 << "the element type of the enclosing linalg.generic op ("
2395 << elementType << ")";
2396 }
2397 return success();
2398}
2399
2400LogicalResult linalg::YieldOp::verify() {
2401 auto *parentOp = (*this)->getParentOp();
2402 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2403 return emitOpError("expected single non-empty parent region");
2404
2405 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2406 return verifyYield(*this, linalgOp);
2407
2408 return emitOpError("expected parent op with LinalgOp interface");
2409}
2410
2411//===----------------------------------------------------------------------===//
2412// IndexOp
2413//===----------------------------------------------------------------------===//
2414
2415LogicalResult IndexOp::verify() {
2416 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2417 if (!linalgOp)
2418 return emitOpError("expected parent op with LinalgOp interface");
2419 if (linalgOp.getNumLoops() <= getDim())
2420 return emitOpError("expected dim (")
2421 << getDim() << ") to be lower than the number of loops ("
2422 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2423 return success();
2424}
2425
2426OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2427 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2428 // Bail out if `linalg.index` does not have a proper parent yet at this
2429 // point, e.g., when calling `createOrFold` during IR construction in
2430 // `GenericOp::build`.
2431 if (!linalgOp)
2432 return OpFoldResult{};
2433
2434 // The index of a unit dimension is always 0.
2435 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2436 uint64_t dim = getDim();
2437 assert(dim < loopBounds.size() && "Dim is out of bounds");
2438 if (loopBounds[dim] == 1)
2439 return IntegerAttr::get(IndexType::get(getContext()), 0);
2440
2441 return OpFoldResult{};
2442}
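// Illustrative example (added for exposition): inside a linalg op whose
// static loop range at `dim` is 1, `linalg.index` can only ever produce 0,
// so it folds to the constant index 0.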
2443
2444/////// Operations corresponding to library calls defined with Tablegen ////////
2445
2446#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2447
2448#define GET_OP_CLASSES
2449#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2450
2451#define GET_OP_CLASSES
2452#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2453#define GET_OP_CLASSES
2454#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2455
2456AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2457 unsigned rank,
2458 MLIRContext *context) {
2459 if (maybeMap)
2460 return *maybeMap;
2461 if (rank == 0)
2462 return AffineMap::get(context);
2463 return AffineMap::getMultiDimIdentityMap(rank, context);
2464}
2465
2466SmallVector<AffineExpr, 4>
2467 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2468 MLIRContext *context) {
2469 SmallVector<AffineExpr, 4> res;
2470 res.reserve(num);
2471 for (unsigned i = 0; i < num; ++i)
2472 res.push_back(getAffineDimExpr(startIdx++, context));
2473 return res;
2474}
2475
2476SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2477 ArrayRef<AffineExpr> b) {
2478 auto rangeA = llvm::make_range(a.begin(), a.end());
2479 auto rangeB = llvm::make_range(b.begin(), b.end());
2480 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2481 return llvm::to_vector<4>(concatRanges);
2482}
2483
2484static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2485 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2486 ss << "view";
2487 for (auto size : memref.getShape())
2488 if (size < 0)
2489 ss << "sx";
2490 else
2491 ss << size << "x";
2492 if (failed(appendMangledType(ss, memref.getElementType())))
2493 return failure();
2494 if (auto as = memref.getMemorySpace()) {
2495 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2496 ss << "as" << attr.getInt();
2497 else
2498 return failure();
2499 }
2500 return success();
2501 }
2502 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2503 ss << "vector";
2504 llvm::interleave(
2505 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2506 if (failed(appendMangledType(ss, vec.getElementType())))
2507 return failure();
2508 return success();
2509 }
2510 if (t.isSignlessIntOrIndexOrFloat()) {
2511 ss << t;
2512 return success();
2513 }
2514 return failure();
2515}
2516
2517std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2518 assert(isa<LinalgOp>(op));
2519 std::string name(op->getName().getStringRef().str());
2520 std::string fun = "";
2521 for (NamedAttribute kv : op->getAttrs()) {
2522 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2523 fun = stringifyEnum(ufa.getValue()).str() + "_";
2524 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2525 fun = stringifyEnum(bfa.getValue()).str() + "_";
2526 }
2527 }
2528 name.reserve(128);
2529 llvm::replace(name, '.', '_');
2530 llvm::raw_string_ostream ss(name);
2531 ss << "_" << fun;
2532 for (Type t : op->getOperandTypes()) {
2533 if (failed(appendMangledType(ss, t)))
2534 return std::string();
2535 ss << "_";
2536 }
2537 name.pop_back();
2538 return name;
2539}
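// Illustrative example (added for exposition, assumed rather than taken from
// a test): a "linalg.copy" on two memref<?xf32> operands would mangle to
// "linalg_copy_viewsxf32_viewsxf32", where "s" marks a dynamic size.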
2540
2541//===----------------------------------------------------------------------===//
2542// Canonicalizers and Folders.
2543//===----------------------------------------------------------------------===//
2544
2545namespace {
2546struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2547 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2548
2549 LogicalResult matchAndRewrite(LinalgOp op,
2550 PatternRewriter &rewriter) const override {
2551 for (OpOperand &opOperand : op->getOpOperands()) {
2552 // Linalg "inputs" may be either tensor or memref type.
2553 // tensor<0xelt_type> is a convention that may not always mean
2554 // "0 iterations". Only erase in cases where we see memref<...x0x...>.
2555 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2556 if (!mt)
2557 continue;
2558 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2559 rewriter.eraseOp(op);
2560 return success();
2561 }
2562 }
2563 return failure();
2564 }
2565};
2566
2567/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2568/// result that is more static than the linalg op.
2569struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2570 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2571
2572 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2573 PatternRewriter &rewriter) const override {
2574 if (!tensor::canFoldIntoProducerOp(castOp))
2575 return failure();
2576
2577 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2578 if (!linalgOp)
2579 return failure();
2580
2581 // The cast can be in a conditionally reachable region, in which case
2582 // folding will generate invalid code. Only conservatively fold ops in the
2583 // same block for now.
2584 if (castOp->getBlock() != linalgOp->getBlock())
2585 return failure();
2586
2587 OpBuilder::InsertionGuard guard(rewriter);
2588 rewriter.setInsertionPoint(linalgOp);
2589
2590 Location loc = linalgOp.getLoc();
2591 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2592 unsigned resultNumber = resultValue.getResultNumber();
2593 auto resultType =
2594 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2595 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2596 // going from a more dynamic shape to a less dynamic shape. If the producer
2597 // for this cast, i.e. producer of the out operand, is also an operation
2598 // that folds with tensor.cast consumer (like this pattern), the cast will
2599 // continue to propagate as far up the stack as it can go.
2600 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2601 Value newOperand =
2602 tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2603 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2604 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2605 linalgOp.getDpsInits().end());
2606 outputOperands[resultNumber] = newOperand;
2607 newOperands.append(outputOperands.begin(), outputOperands.end());
2608
2609 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2610 linalgOp->result_type_end());
2611 resultTypes[resultNumber] = resultType;
2612 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2613
2614 // Create a tensor.cast operation back to the original type.
2615 Value castBack = tensor::CastOp::create(
2616 rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2617
2618 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2619 results[resultNumber] = castBack;
2620 rewriter.replaceOp(linalgOp, results);
2621 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2622 return success();
2623 }
2624};
2625
2626/// For each operand in `operands`, this function maps the static sizes of its
2627/// dimensions to their affine dim expressions.
2628static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2629 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2630 for (OpOperand &opOperand : operands) {
2631 if (linalgOp.isScalar(&opOperand))
2632 continue;
2633 Value src = opOperand.get();
2634 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2635 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2636
2637 // Get the `sourceShape` of the `sourceType`. If the operand is a result
2638 // of a `tensor.cast` operation and the source of the cast has a static
2639 // shape, then assign that shape to `sourceShape`.
2640 auto *parentOp = src.getDefiningOp();
2641 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2642 if (parentOp) {
2643 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2644 Value castSource = castOp.getSource();
2645 auto castSourceType =
2646 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2647 if (castSourceType && castSourceType.hasStaticShape())
2648 sourceShape = castSourceType.getShape();
2649 }
2650 }
2651
2652 // If the source shape's dimension has a static shape, map the affine dim
2653 // expression to the known static size.
2654 for (unsigned i = 0; i < sourceShape.size(); i++) {
2655 if (sourceType.isDynamicDim(i))
2656 continue;
2657 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2658 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2659 }
2660 }
2661}
2662
2663/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2664/// sizes mapped in `affineExprToSize`. New operands are appended to
2665/// `newOperands` and their result types are stored in `resultTypes`. If
2666/// `opOperand` requires no change, `changeNeeded` is left untouched and the
2667/// same operand is added to the `newOperands` list.
2668static void createNewOperandWithStaticSizes(
2669 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2670 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2671 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2672 bool &changeNeeded) {
2673 Value src = opOperand->get();
2674 newOperands.push_back(src);
2675 if (linalgOp.isScalar(opOperand))
2676 return;
2677 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2678 Type resultType = sourceType;
2679 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2680 resultTypes.push_back(resultType);
2681 return;
2682 }
2683 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2684 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2685 SmallVector<int64_t> newShape;
2686 // If operand is updated with new shape, `newOperandNeeded` will be
2687 // true.
2688 bool newOperandNeeded = false;
2689 for (unsigned i = 0; i < sourceShape.size(); i++) {
2690 int64_t dimShape = sourceShape[i];
2691 AffineExpr dimExpr = sourceMap.getResult(i);
2692 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2693 newShape.push_back(dimShape);
2694 continue;
2695 }
2696 // Dimension has a dynamic shape and corresponding affine dim
2697 // expression is present in the map. So assign the size for the
2698 // given affine dim expression to the dimension.
2699 newShape.push_back(affineExprToSize[dimExpr]);
2700 newOperandNeeded = true;
2701 }
2702 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2703 sourceType.getEncoding());
2704 if (newOperandNeeded) {
2705 changeNeeded = true;
2706 // Get the new operand value given its size and element type by
2707 // casting it.
2708 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2709 unsigned index = opOperand->getOperandNumber();
2710 newOperands[index] = newOperand;
2711 }
2712 if (linalgOp.isDpsInit(opOperand))
2713 resultTypes.push_back(resultType);
2714}
2715
2716/// Static shapes for the operands can be inferred if any one of the operands
2717/// have a static shape. This can be done by referring to the affine dim
2718/// expressions for the operand.
2719struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2720 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2721
2722 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2723 PatternRewriter &rewriter) const override {
2724 if (!linalgOp.hasPureTensorSemantics())
2725 return failure();
2726
2727 // Maps must be projected permutations.
2728 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2729 return !map.isProjectedPermutation();
2730 }))
2731 return failure();
2732
2733 // Maps affine dim expressions to the static size of that dimension.
2734 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2735 Location loc = linalgOp.getLoc();
2736
2737 // For each affine dim expression, check if the size is known. If known,
2738 // add it to the map.
2739 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2740
2741 SmallVector<Value> newOperands;
2742 SmallVector<Type> resultTypes;
2743
2744 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2745 // change in their types.
2746 bool changeNeeded = false;
2747 newOperands.reserve(linalgOp->getNumOperands());
2748 resultTypes.reserve(linalgOp.getNumDpsInits());
2749
2750 // Iterate over all the operands and update the static sizes.
2751 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2752 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2753 affineExprToSize, linalgOp, newOperands,
2754 resultTypes, changeNeeded);
2755 }
2756
2757 // If the generic op has all the required static information, no
2758 // canonicalization needed.
2759 if (!changeNeeded)
2760 return failure();
2761
2762 // Clone op.
2763 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2764 SmallVector<Value> replacements;
2765 replacements.reserve(newOp->getNumResults());
2766 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2767 Value newResult = std::get<1>(it);
2768 Value oldResult = std::get<0>(it);
2769 Type newType = newResult.getType();
2770 Type oldType = oldResult.getType();
2771 replacements.push_back(
2772 (newType != oldType)
2773 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2774 : newResult);
2775 }
2776 rewriter.replaceOp(linalgOp, replacements);
2777 return success();
2778 }
2779};
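// Illustrative example (added for exposition): if one operand of a generic
// op has the static type tensor<8x16xf32> while another operand sharing the
// same loop dimensions is dynamic, the pattern above tensor.cast's the
// dynamic operand to its inferred static type, clones the op, and casts the
// results back to their original types.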
2780
2781} // namespace
2782
2783// All named ops canonicalizers and folders are auto-generated in the
2784// .cpp.inc.
2785
2786//===----------------------------------------------------------------------===//
2787// SoftmaxOp
2788//===----------------------------------------------------------------------===//
2789
2790LogicalResult SoftmaxOp::verify() {
2791 ShapedType inputType = getInputOperandType();
2792 ShapedType outputType = getOutputOperandType();
2793
2794 ArrayRef<int64_t> inputShape = inputType.getShape();
2795 ArrayRef<int64_t> outputShape = outputType.getShape();
2796 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2797 return emitOpError("incompatible output shape");
2798
2799 int64_t inputRank = getInputOperandRank();
2800 int64_t dimension = getDimension();
2801 if ((dimension < 0) || (dimension >= inputRank))
2802 return emitOpError("incorrect dimension specified");
2803
2804 return success();
2805}
2806
2807SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2808 int64_t operandRank = getInputOperandRank();
2809 SmallVector<Range> loopBounds(operandRank);
2810 Location loc = getLoc();
2811 Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2812 Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2813 Value source = getInput();
2814 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2815 loopBounds[dim].offset = zero;
2816 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2817 loopBounds[dim].stride = one;
2818 }
2819 return loopBounds;
2820}
2821
2822SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2823 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2824 utils::IteratorType::parallel);
2825 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2826 return iteratorTypes;
2827}
2828
2829FailureOr<TilingResult>
2830SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2831 ArrayRef<OpFoldResult> offsets,
2832 ArrayRef<OpFoldResult> sizes) {
2833 int64_t rank = getInputOperandRank();
2834 auto oneAttr = builder.getI64IntegerAttr(1);
2835 SmallVector<OpFoldResult> strides(rank, oneAttr);
2836 SmallVector<Value> tiledOperands;
2837 Operation *inputSlice =
2838 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2839 if (!inputSlice) {
2840 return emitOpError("failed to compute input slice");
2841 }
2842 tiledOperands.emplace_back(inputSlice->getResult(0));
2843 Operation *outputSlice =
2844 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2845 if (!outputSlice) {
2846 return emitOpError("failed to compute output slice");
2847 }
2848 tiledOperands.emplace_back(outputSlice->getResult(0));
2849
2850 SmallVector<Type, 4> resultTypes;
2851 if (hasPureTensorSemantics())
2852 resultTypes.push_back(tiledOperands[1].getType());
2853 Operation *tiledOp =
2854 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2855
2856 return TilingResult{
2857 {tiledOp},
2858 SmallVector<Value>(tiledOp->getResults()),
2859 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2860}
2861
2862LogicalResult SoftmaxOp::getResultTilePosition(
2863 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2864 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2865 SmallVector<OpFoldResult> &resultSizes) {
2866 if (resultNumber == 0) {
2867 resultOffsets.assign(offsets.begin(), offsets.end());
2868 resultSizes.assign(sizes.begin(), sizes.end());
2869 return success();
2870 }
2871 return failure();
2872}
2873
2874// cast(dynamic) -> static.
2875LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2876 return memref::foldMemRefCast(*this);
2877}
2878
2879LogicalResult
2880SoftmaxOp::reifyResultShapes(OpBuilder &b,
2881 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2882 SmallVector<OpFoldResult> shapes;
2883 Location loc = getOperation()->getLoc();
2884 IRRewriter rewriter(b);
2885 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2886 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2887 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2888 if (!outputShapedType.isDynamicDim(dim)) {
2889 // Static dim: Return IntegerAttr.
2890 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2891 } else {
2892 // Dynamic dim: Return Value.
2893 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2894 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2895 }
2896 }
2897 reifiedReturnShapes.emplace_back(std::move(shapes));
2898 return success();
2899}
2900
2901void SoftmaxOp::getEffects(
2902 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2903 &effects) {
2904 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2905 if (!llvm::isa<MemRefType>(operand.getType()))
2906 continue;
2907 effects.emplace_back(MemoryEffects::Read::get(),
2908 &getOperation()->getOpOperand(index), /*stage=*/0,
2909 /*effectOnFullRegion=*/true,
2910 SideEffects::DefaultResource::get());
2911 }
2912
2913 for (OpOperand &operand : getDpsInitsMutable()) {
2914 if (!llvm::isa<MemRefType>(operand.get().getType()))
2915 continue;
2916 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2917 /*effectOnFullRegion=*/true,
2918 SideEffects::DefaultResource::get());
2919 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2920 /*effectOnFullRegion=*/true,
2921 SideEffects::DefaultResource::get());
2922 }
2923}
2924
2925// Helper functions for softmax decomposition.
2926// @{
2927
2928// Helper function to produce the iterator types (reduction or parallel) and
2929// affine maps for the iterators used in the decomposition of softmax.
2930// This method creates:
2931// If allParallel == true:
2932// - iterator type: {parallel, ..., parallel}
2933// - affine maps:
2934// -- identity with inputRank dimensions.
2935// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2936// where N == inputRank.
2937//
2938// If allParallel == false:
2939// - iterator type at dim(i) == parallel for i != \p dim and
2940// dim(dim) == reduction.
2941// - affine map:
2942// -- identity with inputRank dimensions.
2943// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2944// where N == inputRank.
2945static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2946computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2947 int64_t dim, bool allParallel = false) {
2948 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2949 utils::IteratorType::parallel);
2950 if (!allParallel)
2951 iteratorTypes[dim] = utils::IteratorType::reduction;
2952 MLIRContext *ctxt = builder.getContext();
2953 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2954 SmallVector<AffineExpr, 2> affineExprs;
2955 for (int i = 0; i < inputRank; i++) {
2956 if (i != dim)
2957 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2958 }
2959 auto reductionMap =
2960 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2961 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2962 return std::make_tuple(iteratorTypes, indexingMaps);
2963}
2964
2965// Helper function to produce a linalg.generic that computes a reduction on
2966// dimension \p dim with the operation type \p T.
2967template <typename T>
2968static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2969 int64_t dim) {
2970 auto inputType = cast<ShapedType>(input.getType());
2971 ArrayRef<int64_t> inputShape = inputType.getShape();
2972 int64_t inputRank = inputShape.size();
2973 auto [iteratorTypes, indexingMaps] =
2974 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2975 assert(indexingMaps.size() == 2 &&
2976 "We should have two maps: 1 for the input, 1 for the output");
2977 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2978
2979 auto genericOp = linalg::GenericOp::create(
2980 builder, loc, output.getType(), input, output, indexingMaps,
2981 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2982 Value result = T::create(b, loc, args[0], args[1]);
2983 linalg::YieldOp::create(b, loc, result);
2984 });
2985 return genericOp.getResult(0);
2986}
2987
2988/// Produce a linalg generic that computes the second step of the softmax
2989/// decomposition: res = exp(input - max), where \p max is the max of \p input
2990/// on dimension \p dim.
2991static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2992 Value max, Value output, int64_t dim) {
2993 auto inputType = cast<ShapedType>(input.getType());
2994 ArrayRef<int64_t> inputShape = inputType.getShape();
2995 int64_t inputRank = inputShape.size();
2996 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2997 builder, inputRank, dim, /*allParallel=*/true);
2998 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2999 assert(indexingMaps[0].isIdentity() && "input map should be identity");
3000 // Add the affine map for the output argument.
3001 indexingMaps.push_back(indexingMaps[0]);
3002 auto genericOp = linalg::GenericOp::create(
3003 builder, loc, input.getType(), ValueRange{input, max}, output,
3004 indexingMaps, iteratorTypes,
3005 [&](OpBuilder &b, Location loc, ValueRange args) {
3006 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3007 Value result = math::ExpOp::create(b, loc, diff);
3008 linalg::YieldOp::create(b, loc, result);
3009 });
3010 return genericOp.getResult(0);
3011}
3012
3013/// Produce a linalg generic that computes the final step of the softmax
3014/// decomposition.
3015/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3016/// yield n / d
3017/// }
3018static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3019 Value denominator, Value output, int64_t dim) {
3020 auto inputType = cast<ShapedType>(numerator.getType());
3021 ArrayRef<int64_t> inputShape = inputType.getShape();
3022 int64_t inputRank = inputShape.size();
3023 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3024 builder, inputRank, dim, /*allParallel=*/true);
3025 assert(indexingMaps.size() == 2 &&
3026 "We should have one map for each input (2)");
3027 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3028 // Add the affine map for the output tensor.
3029 indexingMaps.push_back(indexingMaps[0]);
3030 auto genericOp = linalg::GenericOp::create(
3031 builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3032 output, indexingMaps, iteratorTypes,
3033 [&](OpBuilder &b, Location loc, ValueRange args) {
3034 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3035 linalg::YieldOp::create(b, loc, result);
3036 });
3037 return genericOp.getResult(0);
3038}
3039// @} End helper functions for softmax decomposition.
3040
3041/// Given an N-dimensional tensor x, this method converts
3042/// softmax(x) to the following sequence of operations:
3043///
3044/// 1. Compute the max of x along dimension d. This results
3045/// in an N-1 dimensional tensor m.
3046/// m = max(x, dim = d)
3047///
3048/// 2. Subtract a broadcasted m from x and exponentiate. This results in
3049/// an N dimensional tensor z.
3050/// z = exp(x - m)
3051///
3052/// 3. Compute the sum of z along dimension d. This results in
3053/// an N-1 dimensional tensor l.
3054/// l = sum(z, dim = d)
3055///
3056/// 4. Divide z by l. This gives the N-dimensional softmax.
3057/// softmax = z / l
3058///
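/// For illustration only (not part of the implementation), a 1-D example
/// with x = [1.0, 2.0] and d = 0:
///   m = 2.0
///   z = exp([1.0 - 2.0, 2.0 - 2.0]) = [0.3679, 1.0]
///   l = 1.3679
///   softmax = z / l = [0.2689, 0.7311]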
3059FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3060 OpBuilder::InsertionGuard guard(b);
3061 b.setInsertionPoint(*this);
3062 Location loc = getLoc();
3063 Value input = getInput();
3064 ShapedType inputType = getInputOperandType();
3065 Type elementType = inputType.getElementType();
3066 int64_t reductionDim = getDimension();
3067 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3068 Value output = getOutput();
3069 dims.erase(dims.begin() + reductionDim);
3070 // Step 1: Compute max along dim.
3071 Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3072 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3073 elementType, b, loc,
3074 /*useOnlyFiniteValue=*/true);
3075 Value neutralForMaxFInit =
3076 linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3077 .result();
3078 Value max =
3079 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3080
3081 // Step 2: Subtract max from input and exponentiate.
3082 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3083
3084 // Step 3: Compute sum along dim.
3085 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3086 b, loc, /*useOnlyFiniteValue=*/true);
3087 Value zeroInit =
3088 linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3089 Value denominator =
3090 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3091
3092 // Step 4: Compute softmax.
3093 Value result =
3094 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3095 return SmallVector<Value>{result};
3096}
3097
3098//===----------------------------------------------------------------------===//
3099// WinogradFilterTransformOp
3100//===----------------------------------------------------------------------===//
3101
3102LogicalResult WinogradFilterTransformOp::verify() {
3103 auto filterType = cast<ShapedType>(getFilter().getType());
3104 ArrayRef<int64_t> filterShape = filterType.getShape();
3105 int64_t filterH = filterShape[getFilterHDim()];
3106 int64_t filterW = filterShape[getFilterWDim()];
3107 WinogradConv2DFmr fmr = getFmr();
3108 int64_t m, r;
3109 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3110
3111 if (filterH != r && filterH != 1)
3112 return emitOpError("expect filter height either equals to r or 1");
3113 if (filterW != r && filterW != 1)
3114 return emitOpError("expect filter width either equals to r or 1");
3115 if (filterH == 1 && filterW == 1)
3116 return emitOpError("expect either filter height or width equals to r");
3117
3118 SmallVector<int64_t> expectedOutputShape;
3119 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3120 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3121 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3122 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3123
3124 auto outputType = cast<ShapedType>(getOutput().getType());
3125 ArrayRef<int64_t> outputShape = outputType.getShape();
3126 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3127 return emitOpError("the output shape is not expected");
3128 }
3129 return success();
3130}
3131
3132SmallVector<Range>
3133WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3134 Location loc = getLoc();
3135 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3136 IntegerAttr oneAttr = builder.getIndexAttr(1);
3137 Value filter = getFilter();
3138 int64_t filterRank = getFilterOperandRank();
3139 SmallVector<Range> loopBounds(filterRank);
3140 for (unsigned dim = 0; dim < filterRank; ++dim) {
3141 loopBounds[dim].offset = zeroAttr;
3142 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3143 loopBounds[dim].stride = oneAttr;
3144 }
3145 return loopBounds;
3146}
3147
3148SmallVector<utils::IteratorType>
3149WinogradFilterTransformOp::getLoopIteratorTypes() {
3150 int64_t filterRank = getFilterOperandRank();
3151 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3152 utils::IteratorType::parallel);
3153 return iteratorTypes;
3154}
3155
3156LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3157 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3158 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3159 SmallVector<OpFoldResult> &resultSizes) {
3160 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3161 ShapedType filterType = getFilterOperandType();
3162 ArrayRef<int64_t> filterShape = filterType.getShape();
3163 int64_t filterH = filterShape[getFilterHDim()];
3164 int64_t filterW = filterShape[getFilterWDim()];
3165 WinogradConv2DFmr fmr = getFmr();
3166 int64_t m, r;
3167 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3168 int64_t alpha = m + r - 1;
3169 int64_t alphaH = filterH != 1 ? alpha : 1;
3170 int64_t alphaW = filterW != 1 ? alpha : 1;
3171 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3172 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3173
3174 resultOffsets.append(
3175 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3176 resultSizes.append(
3177 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3178
3179 return success();
3180}
3181
3182/// Implement tiling for winograd_filter_transform.
3183/// The input of winograd_filter_transform is (F, KH, KW, C).
3184/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3185/// Users can specify the tile sizes of F and C.
3186/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3187/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
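/// For example (illustrative values only): with F(m = 4, r = 3), a filter of
/// shape (F = 64, KH = 3, KW = 3, C = 32) gives alpha = m + r - 1 = 6, so
/// the transformed output has shape (6, 6, 32, 64).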
3188FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3189 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3190 ArrayRef<OpFoldResult> sizes) {
3191 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3192 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3193 ShapedType filterType = getFilterOperandType();
3194 ArrayRef<int64_t> filterShape = filterType.getShape();
3195 int64_t filterH = filterShape[getFilterHDim()];
3196 int64_t filterW = filterShape[getFilterWDim()];
3197 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3198 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3199 SmallVector<Value> tiledOperands;
3200 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3201
3202 sliceOffsets.append(
3203 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3204 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3205 sizes[getFilterCDim()]});
3206 int64_t filterRank = getFilterOperandRank();
3207 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3208 Location loc = getLoc();
3209 auto filterSlice = tensor::ExtractSliceOp::create(
3210 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3211 tiledOperands.emplace_back(filterSlice);
3212
3213 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3214 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3215 resultSizes)))
3216 return failure();
3217
3218 int64_t outputRank = getOutputOperandRank();
3219 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3220 auto outputSlice = tensor::ExtractSliceOp::create(
3221 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3222 tiledOperands.emplace_back(outputSlice);
3223
3224 SmallVector<Type> resultTypes;
3225 resultTypes.push_back(tiledOperands[1].getType());
3226 Operation *tiledOp =
3227 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3228
3229 return TilingResult{
3230 {tiledOp},
3231 SmallVector<Value>(tiledOp->getResults()),
3232 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3233}
3234
3235//===----------------------------------------------------------------------===//
3236// WinogradInputTransformOp
3237//===----------------------------------------------------------------------===//
3238
3239LogicalResult WinogradInputTransformOp::verify() {
3240 auto inputType = cast<ShapedType>(getInput().getType());
3241 ArrayRef<int64_t> inputShape = inputType.getShape();
3242 int64_t inputH = inputShape[getInputHDim()];
3243 int64_t inputW = inputShape[getInputWDim()];
3244 WinogradConv2DFmr fmr = getFmr();
3245 int64_t m, r;
3246 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3247 int64_t tileSize = m + r - 1;
3248
3249 auto outputType = cast<ShapedType>(getOutput().getType());
3250 ArrayRef<int64_t> outputShape = outputType.getShape();
3251 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3252 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3253
3254 SmallVector<int64_t> expectedOutputShape(6, inputH);
3255 if (ShapedType::isDynamic(inputH)) {
3256 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3257 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3258 } else {
3259 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3260 expectedOutputShape[getOutputTileHDim()] =
3261 leftTransform ? (inputH - (r - 1)) / m : inputH;
3262 }
3263 if (ShapedType::isDynamic(inputW)) {
3264 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3265 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3266 } else {
3267 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3268 expectedOutputShape[getOutputTileWDim()] =
3269 rightTransform ? (inputW - (r - 1)) / m : inputW;
3270 }
3271 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3272 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3273
3274 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3275 return emitOpError("the output shape is not expected");
3276 }
3277 return success();
3278}
3279
3280SmallVector<Range>
3281WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3282 Location loc = getLoc();
3283 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3284 IntegerAttr oneAttr = builder.getIndexAttr(1);
3285 Value output = getOutput();
3286 int64_t outputRank = getOutputOperandRank();
3287 SmallVector<Range> loopBounds(outputRank);
3288 for (unsigned dim = 0; dim < outputRank; ++dim) {
3289 loopBounds[dim].offset = zeroAttr;
3290 // alphaH, alphaW, tileH, tileW, N, C
3291 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3292 loopBounds[dim].stride = oneAttr;
3293 }
3294 return loopBounds;
3295}
3296
3297SmallVector<utils::IteratorType>
3298WinogradInputTransformOp::getLoopIteratorTypes() {
3299 int64_t outputRank = getOutputOperandRank();
3300 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3301 utils::IteratorType::parallel);
3302 return iteratorTypes;
3303}
3304
3305LogicalResult WinogradInputTransformOp::getResultTilePosition(
3306 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3307 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3308 SmallVector<OpFoldResult> &resultSizes) {
3309 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3310 ShapedType outputType = getOutputOperandType();
3311 ArrayRef<int64_t> outputShape = outputType.getShape();
3312 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3313 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3314
3315 WinogradConv2DFmr fmr = getFmr();
3316 int64_t m, r;
3317 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3318 int64_t alpha = m + r - 1;
3319 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3320 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3321
3322 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3323 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3324
3325 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3326 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3327 offsets[getOutputCDim()]});
3328 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3329 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3330 sizes[getOutputCDim()]});
3331
3332 return success();
3333}
3334
3335/// Implement tiling for winograd_input_transform.
3336/// The input of winograd_input_transform is (N, H, W, C).
3337/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3338/// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets`
3339/// are the values for the offsets of tileH, tileW, N, C for one tile. `sizes`
3340/// are the values for the sizes of tileH, tileW, N, C for one tile.
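/// For example (illustrative values only): with F(m = 4, r = 3) and an input
/// of shape (N = 1, H = 14, W = 14, C = 32), alpha = 6 and
/// tileH = tileW = (14 - (3 - 1)) / 4 = 3, so the output has shape
/// (6, 6, 3, 3, 1, 32).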
3341FailureOr<TilingResult>
3342WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3343 ArrayRef<OpFoldResult> offsets,
3344 ArrayRef<OpFoldResult> sizes) {
3345 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3346 WinogradConv2DFmr fmr = getFmr();
3347 int64_t m, r;
3348 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3349
3350 ShapedType outputType = getOutputOperandType();
3351 ArrayRef<int64_t> outputShape = outputType.getShape();
3352 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3353 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3354
3355 Location loc = getLoc();
3356 MLIRContext *context = builder.getContext();
3357 auto identityAffineMap =
3358 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3359 auto offsetAffineMap =
3360 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3361 Value mappedOffsetH = affine::makeComposedAffineApply(
3362 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3363 offsets[getOutputTileHDim()]);
3364 Value mappedOffsetW = affine::makeComposedAffineApply(
3365 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3366 offsets[getOutputTileWDim()]);
3367 auto sizeAffineMap = AffineMap::get(
3368 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3369 Value mappedSizeH = affine::makeComposedAffineApply(
3370 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3371 Value mappedSizeW = affine::makeComposedAffineApply(
3372 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3373
3374 SmallVector<Value> tiledOperands;
3375 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3376
3377 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3378 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3379 sliceOffsets.append(
3380 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3381 OpFoldResult sizeH =
3382 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3383 OpFoldResult sizeW =
3384 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3385 sliceSizes.append(
3386 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3387 int64_t inputRank = getInputOperandRank();
3388 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3389 auto inputSlice = tensor::ExtractSliceOp::create(
3390 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3391 tiledOperands.emplace_back(inputSlice);
3392
3393 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3394 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3395 resultSizes)))
3396 return failure();
3397
3398 int64_t outputRank = getOutputOperandRank();
3399 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3400 auto outputSlice = tensor::ExtractSliceOp::create(
3401 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3402 tiledOperands.emplace_back(outputSlice);
3403
3404 SmallVector<Type> resultTypes;
3405 resultTypes.push_back(tiledOperands[1].getType());
3406 Operation *tiledOp =
3407 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3408
3409 return TilingResult{
3410 {tiledOp},
3411 SmallVector<Value>(tiledOp->getResults()),
3412 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3413}
3414
3415//===----------------------------------------------------------------------===//
3416// WinogradOutputTransformOp
3417//===----------------------------------------------------------------------===//
3418
3419LogicalResult WinogradOutputTransformOp::verify() {
3420 auto valueType = cast<ShapedType>(getValue().getType());
3421 ArrayRef<int64_t> valueShape = valueType.getShape();
3422 int64_t valueH = valueShape[getValueAlphaHDim()];
3423 int64_t valueW = valueShape[getValueAlphaWDim()];
3424 int64_t valueTileH = valueShape[getValueTileHDim()];
3425 int64_t valueTileW = valueShape[getValueTileWDim()];
3426 WinogradConv2DFmr fmr = getFmr();
3427 int64_t m, r;
3428 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3429 bool leftTransform = valueH != 1;
3430 bool rightTransform = valueW != 1;
3431
3432 int64_t outputRank = getOutputOperandRank();
3433 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3434 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3435 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3436 } else {
3437 if (valueH != (leftTransform ? m + r - 1 : 1))
3438 return emitOpError("expect input height equals to input tile size");
3439 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3440 }
3441 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3442 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3443 } else {
3444 if (valueW != (rightTransform ? m + r - 1 : 1))
3445 return emitOpError("expect input width equals to input tile size");
3446 expectedOutputShape[getOutputWDim()] =
3447 (rightTransform ? m : 1) * valueTileW;
3448 }
3449 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3450 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3451
3452 auto outputType = cast<ShapedType>(getOutput().getType());
3453 ArrayRef<int64_t> outputShape = outputType.getShape();
3454 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3455 return emitOpError("the output shape is not expected");
3456 }
3457 return success();
3458}
3459
3460SmallVector<Range>
3461WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3462 Location loc = getLoc();
3463 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3464 IntegerAttr oneAttr = builder.getIndexAttr(1);
3465 Value value = getValue();
3466 int64_t valueRank = getValueOperandRank();
3467 SmallVector<Range> loopBounds(valueRank);
3468 for (unsigned dim = 0; dim < valueRank; ++dim) {
3469 loopBounds[dim].offset = zeroAttr;
3470 // alphaH, alphaW, tileH, tileW, N, F
3471 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3472 loopBounds[dim].stride = oneAttr;
3473 }
3474 return loopBounds;
3475}
3476
3477SmallVector<utils::IteratorType>
3478WinogradOutputTransformOp::getLoopIteratorTypes() {
3479 int64_t valueRank = getValueOperandRank();
3480 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3481 utils::IteratorType::parallel);
3482 return iteratorTypes;
3483}
3484
3485LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3486 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3487 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3488 SmallVector<OpFoldResult> &resultSizes) {
3489 WinogradConv2DFmr fmr = getFmr();
3490 int64_t m, r;
3491 std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3492
3493 Location loc = getLoc();
3494 MLIRContext *context = builder.getContext();
3495 auto identityAffineMap =
3496 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3497 auto affineMap =
3498 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3499
3500 ShapedType valueType = getValueOperandType();
3501 ArrayRef<int64_t> valueShape = valueType.getShape();
3502 int64_t valueH = valueShape[0];
3503 int64_t valueW = valueShape[1];
3504 Value mappedOffsetH = affine::makeComposedAffineApply(
3505 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3506 offsets[getValueTileHDim()]);
3507 Value mappedOffsetW = affine::makeComposedAffineApply(
3508 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3509 offsets[getValueTileWDim()]);
3510 Value mappedSizeH = affine::makeComposedAffineApply(
3511 builder, loc, affineMap, sizes[getValueTileHDim()]);
3512 Value mappedSizeW = affine::makeComposedAffineApply(
3513 builder, loc, affineMap, sizes[getValueTileWDim()]);
3514
3515 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3516 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3517 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3518 OpFoldResult sizeH =
3519 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3520 OpFoldResult sizeW =
3521 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3522
3523 resultOffsets.append(
3524 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3525 resultSizes.append(
3526 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3527 return success();
3528}
3529
3530/// Implement tiling for winograd_output_transform.
3531/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3532/// F). The output of winograd_output_transform is (N, H, W, F). Users can
3533/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3534/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3535/// for the sizes of tileH, tileW, N, F for one tile.
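/// For example (illustrative values only): with F(m = 4, r = 3) and a value
/// of shape (alphaH = 6, alphaW = 6, tileH = 3, tileW = 3, N = 1, F = 64),
/// the reconstructed output has H = W = m * tileH = 12, i.e. shape
/// (1, 12, 12, 64).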
3536FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3537 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3538 ArrayRef<OpFoldResult> sizes) {
3539 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3540 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3541 Location loc = getLoc();
3542 SmallVector<Value> tiledOperands;
3543 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3544
3545 ShapedType valueType = getValueOperandType();
3546 ArrayRef<int64_t> valueShape = valueType.getShape();
3547 int64_t alphaH = valueShape[getValueAlphaHDim()];
3548 int64_t alphaW = valueShape[getValueAlphaWDim()];
3549 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3550 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3551
3552 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3553 offsets[getValueTileWDim()], offsets[getValueNDim()],
3554 offsets[getValueFDim()]});
3555 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3556 sizes[getValueTileWDim()], sizes[getValueNDim()],
3557 sizes[getValueFDim()]});
3558 int64_t valueRank = getValueOperandRank();
3559 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3560 auto valueSlice = tensor::ExtractSliceOp::create(
3561 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3562 tiledOperands.emplace_back(valueSlice);
3563
3564 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3565 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3566 resultSizes)))
3567 return failure();
3568
3569 int64_t outputRank = getOutputOperandRank();
3570 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3571 auto outputSlice = tensor::ExtractSliceOp::create(
3572 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3573 tiledOperands.emplace_back(outputSlice);
3574
3575 SmallVector<Type> resultTypes;
3576 resultTypes.push_back(tiledOperands[1].getType());
3577 Operation *tiledOp =
3578 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3579
3580 return TilingResult{
3581 {tiledOp},
3582 SmallVector<Value>(tiledOp->getResults()),
3583 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3584}
3585
3586//===----------------------------------------------------------------------===//
3587// LinalgDialect
3588// TODO: Merge with the LinalgDialect block at the bottom
3589//===----------------------------------------------------------------------===//
3590
3591// Returns true if the result expressions of `subMap` are a subset of `fullMap`.
3592static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3593 auto explicitRange = subMap.getResults();
3594 auto defaultRange = fullMap.getResults();
3595 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3596 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3597 llvm::set_union(explicitSet, defaultSet);
3598 return explicitSet == defaultSet;
3599}
3600
3601/// Check if the user-defined map is a valid broadcast map. Here broadcast
3602/// indexing maps are defined in the context of the corresponding default
3603/// indexing maps for the given op. This keeps the check simple: just compare
3604/// the number of result dims.
3605/// Returns true if `explicitMap` is broadcast with respect to the
3606/// `defaultMap`.
3607static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3608 return explicitMap.getNumResults() < defaultMap.getNumResults();
3609}
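
// For example (illustrative only): against the default matmul LHS map
// (d0, d1, d2) -> (d0, d2), a user map (d0, d1, d2) -> (d2) has fewer
// results and is therefore treated as a broadcast map.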
3610
3611/// Verifies the broadcast and transpose semantics specified by the explicit
3612/// indexing map of the MatmulOp \p matmulOp for the operand specified by \p
3613/// opIndex.
3614static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3615 unsigned opIndex) {
3616 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3617 SmallVector<AffineMap, 3> defaultIndexingMaps =
3618 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3619
3620 auto opIndexingMap = opIndexingMaps[opIndex];
3621 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3622 // Check general validity of indexing map results.
3623 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3624 return matmulOp->emitOpError()
3625 << "Unexpected dim expression in map result.";
3626
3627 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3628 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3629 return matmulOp->emitOpError()
3630 << "Invalid broadcast requested, should be (d2).";
3631 }
3632 return success();
3633 }
3634 return success();
3635}
3636
3637// Check general validity of input indexing map of
3638// BatchMatmulOp/BatchReduceMatmulOp.
3639template <typename OpTy>
3640static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3641 AffineMap opIndexingMap,
3642 AffineMap defaultIndexingMap, bool isLHS) {
3643 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3644 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3645 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3646 // Check the result dims are valid.
3647 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3648 return batchVariantMatmulOp->emitOpError()
3649 << "Unexpected result dim expression (outside the set of default "
3650 "result dims).";
3651
3652 // Check for valid number of result dims of input maps.
3653 if (opIndexingMap.getNumResults() > 3)
3654 return batchVariantMatmulOp->emitOpError()
3655 << "no. of result dim expressions exceeds 3.";
3656
3657 auto hasValidBatchDim = [](AffineMap map) {
3658 AffineExpr batchDim = map.getResult(0);
3659 return batchDim.isFunctionOfDim(0);
3660 };
3661
3662 // Check if the requested broadcast is valid.
3663 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3664 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3665 return batchVariantMatmulOp->emitOpError()
3666 << "Invalid broadcast requested.";
3667 } else if (!hasValidBatchDim(opIndexingMap)) {
3668 return batchVariantMatmulOp->emitOpError()
3669 << "Invalid batch dimension expression.";
3670 }
3671 return success();
3672}
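
// For example (illustrative only): over the 4-d batch iteration space
// (d0 = batch, d1 = M, d2 = N, d3 = K), the LHS map
// (d0, d1, d2, d3) -> (d0, d1, d3) keeps the batch dim d0 as its leading
// result and is accepted; a non-broadcast map whose first result is not a
// function of d0 is rejected.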
3673
3674/// This function checks if the given AffineMap for the output of a
3675/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3676/// dimensions and if the output map result dimensions are valid.
3677template <typename OpTy>
3678static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3679 AffineMap opIndexingMap) {
3680 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3681 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3682 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3683 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3684 opIndexingMap.getNumResults() != 3) {
3685
3686 return batchVariantMatmulOp->emitOpError()
3687 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3688 << ").";
3689 }
3690 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3691 opIndexingMap.getNumResults() != 2) {
3692 return batchVariantMatmulOp->emitOpError()
3693 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3694 << ").";
3695 }
3696
3697 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3698 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3699 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3700 outputMap.getResult(1).isFunctionOfDim(1) &&
3701 outputMap.getResult(2).isFunctionOfDim(2)
3702 : outputMap.getResult(0).isFunctionOfDim(1) &&
3703 outputMap.getResult(1).isFunctionOfDim(2);
3704 };
3705
3706 if (!areValidOutputResultDim(opIndexingMap)) {
3707 return batchVariantMatmulOp->emitOpError()
3708 << "Invalid output map result dimension.";
3709 }
3710
3711 return success();
3712}
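
// For example (illustrative only), over the same 4-d batch iteration space
// (d0 = batch, d1 = M, d2 = N, d3 = K):
//   BatchMatmulOp expects the output map (d0, d1, d2, d3) -> (d0, d1, d2);
//   BatchReduceMatmulOp expects (d0, d1, d2, d3) -> (d1, d2).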
3713
3714/// Verifies the broadcast and transpose semantic specified by the explicit
3715/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3716/// specified by opIndex.
3717template <typename OpTy>
3718static LogicalResult
3719verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3720 unsigned opIndex) {
3721 SmallVector<AffineMap, 3> opIndexingMaps =
3722 batchVariantMatmulOp.getIndexingMapsArray();
3723 SmallVector<AffineMap, 3> defaultIndexingMaps =
3724 batchVariantMatmulOp.getDefaultIndexingMaps(
3725 batchVariantMatmulOp->getContext());
3726
3727 if (opIndexingMaps.size() != 3)
3728 return batchVariantMatmulOp->emitOpError()
3729 << "Indexing_map attribute must have 3 affine maps.";
3730
3731 auto opIndexingMap = opIndexingMaps[opIndex];
3732 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3733
3734 if (opIndex == 2 &&
3735 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3736 return failure();
3737
3738 if (opIndex != 2 &&
3739 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3740 defaultIndexingMap, opIndex == 0)))
3741 return failure();
3742
3743 return success();
3744}
3745
3746namespace mlir {
3747namespace linalg {
3748
3749std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3750 if (m == 2 && r == 3)
3751 return WinogradConv2DFmr::F_2_3;
3752 if (m == 4 && r == 3)
3753 return WinogradConv2DFmr::F_4_3;
3754 if (m == 2 && r == 5)
3755 return WinogradConv2DFmr::F_2_5;
3756 return std::nullopt;
3757}
3758
3759std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3760 switch (fmr) {
3761 case WinogradConv2DFmr::F_2_3:
3762 return {2, 3};
3763 case WinogradConv2DFmr::F_4_3:
3764 return {4, 3};
3765 case WinogradConv2DFmr::F_2_5:
3766 return {2, 5};
3767 }
 llvm_unreachable("unhandled WinogradConv2DFmr");
3768}
3769
3770//===----------------------------------------------------------------------===//
3771// MatMulOp
3772//===----------------------------------------------------------------------===//
3773
3774static FailureOr<SmallVector<SmallVector<int64_t>>>
3777 for (auto map : maps) {
3778 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3779 if (!attr)
3780 return failure();
3781 SmallVector<int64_t> pos;
3782 for (auto result : attr.getAffineMap().getResults()) {
3783 auto dim = dyn_cast<AffineDimExpr>(result);
3784 if (!dim)
3785 return failure();
3786 pos.push_back(dim.getPosition());
3787 }
3788 positions.push_back(pos);
3789 }
3790 return positions;
3791}
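
// For example (illustrative only), for the default matmul maps
//   [affine_map<(d0, d1, d2) -> (d0, d2)>,
//    affine_map<(d0, d1, d2) -> (d2, d1)>,
//    affine_map<(d0, d1, d2) -> (d0, d1)>]
// this returns {{0, 2}, {2, 1}, {0, 1}}.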
3792
3793/// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3794SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3795 AffineExpr d0, d1, d2;
3796 SmallVector<AffineMap> indexingMaps;
3797 bindDims(context, d0, d1, d2);
3798 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3799 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3800 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3801 return indexingMaps;
3802}
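
// For reference, in affine_map syntax the three maps built above are:
//   LHS: affine_map<(d0, d1, d2) -> (d0, d2)>
//   RHS: affine_map<(d0, d1, d2) -> (d2, d1)>
//   out: affine_map<(d0, d1, d2) -> (d0, d1)>
// where d0 = M, d1 = N, and d2 = K of the (M, N, K) matmul iteration space.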
3803
3804bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3805 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3806 if (!maps)
3807 return false;
3808 if (maps.size() != 3)
3809 return false;
3810 auto positions = getAffineResultPositions(maps);
3811 if (failed(positions))
3812 return false;
3813 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3814 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3815 (*positions)[2] == SmallVector<int64_t>{0, 1};
3816}
3817
3818SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3819 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3820 utils::IteratorType::parallel,
3821 utils::IteratorType::reduction};
3822}
3823
3824unsigned MatmulOp::getNumRegionArgs() { return 3; }
3825
3826std::string MatmulOp::getLibraryCallName() {
3827 return generateLibraryCallName(getOperation());
3828}
3829
3830bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3831
3832/// Check if the op has broadcast and/or transpose semantics. Returns true if
3833/// the user-defined indexing maps are not equal to the default maps.
3834bool MatmulOp::hasUserDefinedMaps() {
3835 SmallVector<AffineMap, 3> defaultMaps =
3836 getDefaultIndexingMaps(this->getContext());
3837 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3838 return defaultMaps != explicitMaps;
3839}
3840
3841/// Implements the block region builder for the MatmulOp. This is called by
3842/// 'fillStructuredOpRegion'.
3843void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3844 ArrayRef<NamedAttribute> attrs,
3845 function_ref<InFlightDiagnostic()> emitError) {
3846 if (emitError && block.getNumArguments() != 3) {
3847 emitError() << "MatmulOp regionBuilder expects 3 args, got "
3848 << block.getNumArguments();
3849 return;
3850 }
3851 assert(block.getNumArguments() == 3 &&
3852 "MatmulOp regionBuilder expects 3 args");
3853 RegionBuilderHelper helper(b, block);
3854 SmallVector<Value> yields;
3855
3856 TypeFn castVal = TypeFn::cast_signed;
3857 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3858 return attr.getName() == "cast";
3859 });
3860 if (castIter != attrs.end()) {
3861 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3862 castVal = attr.getValue();
3863 }
3864
3865 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3866 block.getArgument(0));
3867 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3868 block.getArgument(1));
3869 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3870 if (!value3)
3871 return;
3872 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3873 value3, emitError);
3874 if (!value4)
3875 return;
3876 yields.push_back(value4);
3877 helper.yieldOutputs(yields);
3878}
3879
3880/// Returns true if the given `bcastMap` is a valid broadcast map. A valid
3881/// broadcast map must include the K dimension.
3882/// TODO: Strict inclusion of K dimension in the broadcast map is not
3883/// necessary for both input matrices simultaneously. We can relax this
3884/// condition to have K dimension for one input matrix map and infer the K
3885/// dimension for other input matrix map from the one already having K
3886/// dimension.
3887bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3888 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3889 AffineExpr expr = bcastMap.getResult(0);
3890 // Invalid map if the common dimension of the matmul is not found.
3891 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3892}
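
// For example (illustrative only): with a 3-d iteration space where K is the
// last dim (d2), the broadcast map (d0, d1, d2) -> (d2) is valid, while
// (d0, d1, d2) -> (d0) is rejected because it drops the K dimension.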
3893
3894static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3895 if (parser.parseOptionalKeyword("indexing_maps"))
3896 return ArrayAttr{
3897 nullptr}; // Success in case indexing_maps was not provided.
3898
3899 ArrayAttr arrayAttr;
3900 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3901 return failure();
3902
3903 if (llvm::any_of(arrayAttr,
3904 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3905 return parser.emitError(parser.getCurrentLocation())
3906 << "element of indexing_maps array is not an affine_map";
3907
3908 return arrayAttr;
3909}
3910
3911ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3912 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3913 if (failed(indexingMapsAttr))
3914 return failure();
3915
3916 if (*indexingMapsAttr == nullptr) {
3917 auto indexingMapAttrs = llvm::map_to_vector(
3918 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3919 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3920 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3921 }
3922
3923 result.addAttribute("indexing_maps", *indexingMapsAttr);
3924 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3925 MatmulOp::getRegionBuilder());
3926}
3927
3928void MatmulOp::print(OpAsmPrinter &p) {
3929 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3930 MatmulOp::getDefaultIndexingMaps(getContext()),
3931 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3932 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3933 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3934
3935 std::array<StringRef, 3> elidedAttrs = {
3936 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3937 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3938 elidedAttrs);
3939}
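
// A sketch of the printed form for non-default maps (names and shapes are
// illustrative, not taken from this file):
//
//   linalg.matmul indexing_maps = [
//       affine_map<(d0, d1, d2) -> (d2, d0)>,
//       affine_map<(d0, d1, d2) -> (d2, d1)>,
//       affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%lhs, %rhs : tensor<5x3xf32>, tensor<5x7xf32>)
//     outs(%acc : tensor<3x7xf32>) -> tensor<3x7xf32>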
3940
3941/// Verify the user-defined indexing maps.
3942LogicalResult MatmulOp::verify() {
3943 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3944 if (!hasUserDefinedMaps())
3945 return success();
3946
3947 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3948 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3949 return failure();
3950 }
3951 return success();
3952}
3953
3954LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3955 return memref::foldMemRefCast(*this);
3956}
3957
3958void MatmulOp::getEffects(
3959 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3960 &effects) {
3961 if (hasPureTensorSemantics())
3962 return;
3963 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3964}
3965
3966Speculation::Speculatability MatmulOp::getSpeculatability() {
3967 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3968}
3969
3970SmallVector<AffineMap>
3971MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3972 AffineExpr d0, d1, d2;
3973 MLIRContext *context = builder.getContext();
3974 bindDims(context, d0, d1, d2);
3975 AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3976 AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3977 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3978 return {mapLHS, mapRHS, mapOut};
3979}
3980
3981bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
3982 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3983 if (!maps)
3984 return false;
3985 if (maps.size() != 3)
3986 return false;
3987 auto positions = getAffineResultPositions(maps);
3988 if (failed(positions))
3989 return false;
3990 return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
3991 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3992 (*positions)[2] == SmallVector<int64_t>{0, 1};
3993}
3994
3995void MatmulTransposeAOp::build(OpBuilder &builder,
3996 OperationState &result,
3997 ValueRange inputs, ValueRange outputs,
3998 ArrayRef<NamedAttribute> attributes) {
3999 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4000 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4001}
4002
4003MatmulTransposeAOp
4004MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4005 ValueRange inputs, ValueRange outputs,
4006 ArrayRef<NamedAttribute> attributes) {
4007 OperationState state(location, getOperationName());
4008 build(builder, state, inputs, outputs, attributes);
4009 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4010 assert(res && "builder didn't return the right type");
4011 return res;
4012}
4013
4014void MatmulTransposeAOp::build(OpBuilder &builder,
4015 OperationState &result,
4016 TypeRange resultTensorTypes,
4017 ValueRange inputs, ValueRange outputs,
4018 ArrayRef<NamedAttribute> attributes) {
4019 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4020 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4021}
4022
4023MatmulTransposeAOp
4024MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4025 TypeRange resultTensorTypes, ValueRange inputs,
4026 ValueRange outputs,
4027 ArrayRef<NamedAttribute> attributes) {
4028 OperationState state(location, getOperationName());
4029 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4030 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4031 assert(res && "builder didn't return the right type");
4032 return res;
4033}
4034
4035void MatmulTransposeAOp::build(OpBuilder &builder,
4036 OperationState &result,
4037 TypeRange resultTensorTypes,
4038 ValueRange inputs, ValueRange outputs,
4039 Attribute cast,
4040 ArrayRef<NamedAttribute> attributes) {
4041 result.addAttribute("cast", cast);
4042 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4043 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4044}
4045
4046MatmulTransposeAOp
4047MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4048 TypeRange resultTensorTypes, ValueRange inputs,
4049 ValueRange outputs, Attribute cast,
4050 ArrayRef<NamedAttribute> attributes) {
4051 OperationState state(location, getOperationName());
4052 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4053 auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4054 assert(res && "builder didn't return the right type");
4055 return res;
4056}
4057
4058bool MatmulTransposeAOp::classof(Operation *op) {
4059 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4060 MatmulTransposeAOp::isDefaultIndexingMaps(
4061 op->getAttr("indexing_maps"));
4062}
4063
4064SmallVector<AffineMap>
4065MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4066 AffineExpr d0, d1, d2;
4067 MLIRContext *context = builder.getContext();
4068 bindDims(context, d0, d1, d2);
4069 AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4070 AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4071 AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4072 return {mapLHS, mapRHS, mapOut};
4073}
4074
4075bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4076 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4077 if (!maps)
4078 return false;
4079 if (maps.size() != 3)
4080 return false;
4081 auto positions = getAffineResultPositions(maps);
4082 if (failed(positions))
4083 return false;
4084 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4085 (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4086 (*positions)[2] == SmallVector<int64_t>{0, 1};
4087}
4088
4089void MatmulTransposeBOp::build(OpBuilder &builder,
4090 OperationState &result,
4091 ValueRange inputs, ValueRange outputs,
4092 ArrayRef<NamedAttribute> attributes) {
4093 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4094 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4095}
4096
4097MatmulTransposeBOp
4098MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4099 ValueRange inputs, ValueRange outputs,
4100 ArrayRef<NamedAttribute> attributes) {
4101 OperationState state(location, getOperationName());
4102 build(builder, state, inputs, outputs, attributes);
4103 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4104 assert(res && "builder didn't return the right type");
4105 return res;
4106}
4107
4108void MatmulTransposeBOp::build(OpBuilder &builder,
4109 OperationState &result,
4110 TypeRange resultTensorTypes,
4111 ValueRange inputs, ValueRange outputs,
4112 ArrayRef<NamedAttribute> attributes) {
4113 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4114 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4115}
4116
4117MatmulTransposeBOp
4118MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4119 TypeRange resultTensorTypes, ValueRange inputs,
4120 ValueRange outputs,
4121 ArrayRef<NamedAttribute> attributes) {
4122 OperationState state(location, getOperationName());
4123 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4124 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4125 assert(res && "builder didn't return the right type");
4126 return res;
4127}
4128
4129void MatmulTransposeBOp::build(OpBuilder &builder,
4130 OperationState &result,
4131 TypeRange resultTensorTypes,
4132 ValueRange inputs, ValueRange outputs,
4133 Attribute cast,
4134 ArrayRef<NamedAttribute> attributes) {
4135 result.addAttribute("cast", cast);
4136 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4137 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4138}
4139
4140MatmulTransposeBOp
4141MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4142 TypeRange resultTensorTypes, ValueRange inputs,
4143 ValueRange outputs, Attribute cast,
4144 ArrayRef<NamedAttribute> attributes) {
4145 OperationState state(location, getOperationName());
4146 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4147 auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4148 assert(res && "builder didn't return the right type");
4149 return res;
4150}
4151
4152bool MatmulTransposeBOp::classof(Operation *op) {
4153 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4154 MatmulTransposeBOp::isDefaultIndexingMaps(
4155 op->getAttr("indexing_maps"));
4156}
4157
4158SmallVector<AffineMap>
4159BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4160 AffineExpr d0, d1, d2, d3;
4161 MLIRContext *context = builder.getContext();
4162 bindDims(context, d0, d1, d2, d3);
4163 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4164 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4165 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4166 return {mapLHS, mapRHS, mapOut};
4167}
4168
4169bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
4170 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4171 if (!maps)
4172 return false;
4173 if (maps.size() != 3)
4174 return false;
4175 auto positions = getAffineResultPositions(maps);
4176 if (failed(positions))
4177 return false;
4178 return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4179 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4180 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4181}
4182
4183void BatchMatmulTransposeAOp::build(
4184 OpBuilder &builder, OperationState &result, ValueRange inputs,
4185 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4186 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4187 BatchMatmulOp::getRegionBuilder(),
4188 getDefaultIndexingMaps(builder));
4189}
4190
4191BatchMatmulTransposeAOp
4192BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4193 ValueRange inputs, ValueRange outputs,
4194 ArrayRef<NamedAttribute> attributes) {
4195 OperationState state(location, getOperationName());
4196 build(builder, state, inputs, outputs, attributes);
4197 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4198 assert(res && "builder didn't return the right type");
4199 return res;
4200}
4201
4202void BatchMatmulTransposeAOp::build(
4203 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4204 ValueRange inputs, ValueRange outputs,
4205 ArrayRef<NamedAttribute> attributes) {
4206 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4207 BatchMatmulOp::getRegionBuilder(),
4208 getDefaultIndexingMaps(builder));
4209}
4210
4211BatchMatmulTransposeAOp
4212BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4213 TypeRange resultTensorTypes, ValueRange inputs,
4214 ValueRange outputs,
4215 ArrayRef<NamedAttribute> attributes) {
4216 OperationState state(location, getOperationName());
4217 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4218 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4219 assert(res && "builder didn't return the right type");
4220 return res;
4221}
4222
4223void BatchMatmulTransposeAOp::build(
4224 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4225 ValueRange inputs, ValueRange outputs, Attribute cast,
4226 ArrayRef<NamedAttribute> attributes) {
4227 result.addAttribute("cast", cast);
4228 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4229 BatchMatmulOp::getRegionBuilder(),
4230 getDefaultIndexingMaps(builder));
4231}
4232
4233BatchMatmulTransposeAOp
4234BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4235 TypeRange resultTensorTypes, ValueRange inputs,
4236 ValueRange outputs, Attribute cast,
4237 ArrayRef<NamedAttribute> attributes) {
4238 OperationState state(location, getOperationName());
4239 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4240 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4241 assert(res && "builder didn't return the right type");
4242 return res;
4243}
4244
4245bool BatchMatmulTransposeAOp::classof(Operation *op) {
4246 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4247 BatchMatmulTransposeAOp::isDefaultIndexingMaps(
4248 op->getAttr("indexing_maps"));
4249}
4250
4251SmallVector<AffineMap>
4252BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4253 AffineExpr d0, d1, d2, d3;
4254 MLIRContext *context = builder.getContext();
4255 bindDims(context, d0, d1, d2, d3);
4256 AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4257 AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4258 AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4259 return {mapLHS, mapRHS, mapOut};
4260}
4261
4262bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4263 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4264 if (!maps)
4265 return false;
4266 if (maps.size() != 3)
4267 return false;
4268 auto positions = getAffineResultPositions(maps);
4269 if (failed(positions))
4270 return false;
4271 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4272 (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4273 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4274}
4275
4276void BatchMatmulTransposeBOp::build(
4277 OpBuilder &builder, OperationState &result, ValueRange inputs,
4278 ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4279 buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4280 BatchMatmulOp::getRegionBuilder(),
4281 getDefaultIndexingMaps(builder));
4282}
4283
4284BatchMatmulTransposeBOp
4285BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4286 ValueRange inputs, ValueRange outputs,
4287 ArrayRef<NamedAttribute> attributes) {
4288 OperationState state(location, getOperationName());
4289 build(builder, state, inputs, outputs, attributes);
4290 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4291 assert(res && "builder didn't return the right type");
4292 return res;
4293}
4294
4295void BatchMatmulTransposeBOp::build(
4296 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4297 ValueRange inputs, ValueRange outputs,
4298 ArrayRef<NamedAttribute> attributes) {
4299 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4300 BatchMatmulOp::getRegionBuilder(),
4301 getDefaultIndexingMaps(builder));
4302}
4303
4304BatchMatmulTransposeBOp
4305BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4306 TypeRange resultTensorTypes, ValueRange inputs,
4307 ValueRange outputs,
4308 ArrayRef<NamedAttribute> attributes) {
4309 OperationState state(location, getOperationName());
4310 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4311 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4312 assert(res && "builder didn't return the right type");
4313 return res;
4314}
4315
4316void BatchMatmulTransposeBOp::build(
4317 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4318 ValueRange inputs, ValueRange outputs, Attribute cast,
4319 ArrayRef<NamedAttribute> attributes) {
4320 result.addAttribute("cast", cast);
4321 buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4322 BatchMatmulOp::getRegionBuilder(),
4323 getDefaultIndexingMaps(builder));
4324}
4325
4326BatchMatmulTransposeBOp
4327BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4328 TypeRange resultTensorTypes, ValueRange inputs,
4329 ValueRange outputs, Attribute cast,
4330 ArrayRef<NamedAttribute> attributes) {
4331 OperationState state(location, getOperationName());
4332 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4333 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4334 assert(res && "builder didn't return the right type");
4335 return res;
4336}
4337
4338bool BatchMatmulTransposeBOp::classof(Operation *op) {
4339 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4340 BatchMatmulTransposeBOp::isDefaultIndexingMaps(
4341 op->getAttr("indexing_maps"));
4342}
4343
4344//===----------------------------------------------------------------------===//
4345// ContractOp
4346//===----------------------------------------------------------------------===//
4347
4348SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4349 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4350 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4351 // domains are all the same, and each implements a projected permutation.
4352 // Each iteration space dim must occur for at least one operand and either
4353 // takes part in a contraction/reduction or else has parallel iteration type.
4354 // We have that a dim is a contraction/reduction dim if and only if the dim
4355 // occurs for the output operand. We use this fact for fast inference:
4356 // NB: In case we allow dims to occur solely for one input, the above still
4357 // holds: per the einsum semantics, these are reduction dims as well.
4358 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4359 for (auto result : outAffineMap.getResults()) {
4360 auto dimExpr = dyn_cast<AffineDimExpr>(result);
4361 assert(dimExpr && "affine_map is a projected permutation");
4362 dimsInOutput[dimExpr.getPosition()] = true;
4363 }
4364
4365 SmallVector<utils::IteratorType> iteratorTypes;
4366 for (auto dimOccursInOutput : dimsInOutput)
4367 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4368 : utils::IteratorType::reduction);
4369
4370 return iteratorTypes;
4371}
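
// For example (illustrative only): a plain matmul expressed as a contraction
// uses the ins maps (d0, d2) and (d2, d1) with the out map (d0, d1); d0 and
// d1 occur in the output and are hence parallel, while d2 does not and is a
// reduction dim.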
4372
4373unsigned ContractOp::getNumRegionArgs() { return 3; }
4374
4375/// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
4376void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4377 ArrayRef<NamedAttribute> attrs,
4378 function_ref<InFlightDiagnostic()> emitError) {
4379 if (emitError && block.getNumArguments() != 3) {
4380 emitError() << "ContractOp regionBuilder expects 3 args, got "
4381 << block.getNumArguments();
4382 return;
4383 }
4384 assert(block.getNumArguments() == 3 &&
4385 "ContractOp regionBuilder expects 3 args");
4386 RegionBuilderHelper helper(b, block);
4387
4388 TypeFn castSignedness = TypeFn::cast_signed;
4389 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4390 return attr.getName() == "cast";
4391 });
4392 if (castIter != attrs.end()) {
4393 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4394 castSignedness = attr.getValue();
4395 }
4396
4397 // TODO: Support fields with operators besides mult & add.
4398 Type outType = block.getArgument(2).getType();
4399 Value lhsAtOutType =
4400 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4401 Value rhsAtOutType =
4402 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4403 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4404 rhsAtOutType, emitError);
4405 if (!productAtOutType)
4406 return;
4407 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4408 productAtOutType, emitError);
4409 if (!result)
4410 return;
4411 helper.yieldOutputs({result});
4412}
4413
4414ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4415 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4416 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4417 return parser.emitError(parser.getCurrentLocation(),
4418 "expected 'indexing_maps' attribute");
4419 result.addAttribute("indexing_maps", *indexingMapsAttr);
4420
4421 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4422 regionBuilder);
4423}
4424
4425void ContractOp::print(OpAsmPrinter &p) {
4426 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4428 p, getOperation(), getInputs(), getOutputs(),
4429 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4430}
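
// A sketch of the printed form (names and shapes are illustrative, not taken
// from this file):
//
//   linalg.contract indexing_maps = [
//       affine_map<(d0, d1, d2) -> (d0, d2)>,
//       affine_map<(d0, d1, d2) -> (d2, d1)>,
//       affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%lhs, %rhs : tensor<3x5xf32>, tensor<5x7xf32>)
//     outs(%acc : tensor<3x7xf32>) -> tensor<3x7xf32>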
4431
4432LogicalResult ContractOp::verify() {
4433 int iterationSpaceDims = -1;
4434 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4435 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4436 // access an input operand (so occurrence count can be at most 2) and
4437 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4438 SmallVector<size_t> inOccurrences;
4439 SmallVector<size_t> outOccurrences;
4440
4441 // A helper so that for each operand's affine_map and type we check that ...
4442 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4443 bool isInput) -> LogicalResult {
4444 // ... the affine_map is a projected permutation;
4445 if (!affineMap.isProjectedPermutation())
4446 return emitError("provided affine_map is not a projected permutation");
4447
4448 // ... the rank of the affine_map's results and corresponding type match;
4449 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4450 if (affineMap.getNumResults() != shapedType.getRank())
4451 return emitError("ranks of shaped operand and results of corresponding "
4452 "affine_map differ");
4453 } else if (affineMap.getNumResults() != 0) {
4454 return emitError("affine_map specifies shaped access while operand has "
4455 "non-shaped type");
4456 }
4457
4458 // ... the rank of the affine_map's domain is the same as those seen prior;
4459 if (iterationSpaceDims == -1) {
4460 iterationSpaceDims = affineMap.getNumDims();
4461 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4462 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4463 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4464 return emitError("iteration spaces of provided affine_maps differ");
4465 }
4466
4467 // ... update counts of dims used to access either an input or the output.
4468 for (AffineExpr affineExpr : affineMap.getResults()) {
4469 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4470 if (!affineDimExpr)
4471 llvm_unreachable("affine_map is a projected permutation");
4472
4473 if (isInput)
4474 inOccurrences[affineDimExpr.getPosition()] += 1;
4475 else
4476 outOccurrences[affineDimExpr.getPosition()] += 1;
4477 }
4478
4479 return success();
4480 };
4481
4482 for (auto &&[affineMap, operandType, isInput] :
4483 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4484 SmallVector<bool>{true, true, false})) {
4485 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4486 return failure(); // NB: checkAffineMapAndType will emit relevant error.
4487 }
4488
4489 bool hasContractingDim = false;
4490 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4491 size_t inOccCount = inOccurrences[dimIndex];
4492 size_t outOccCount = outOccurrences[dimIndex];
4493
4494 // We have a contracting dim if and only if ...
4495 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4496
4497 if (inOccCount == 0 && outOccCount == 0)
4498 return emitError() << "iteration space dim at index " << dimIndex
4499 << " not used to access any operand";
4500
4501 // NB: We disallow a dim which occurs for only one input operand and not
4502 // for the output. In terms of einsum semantics such dims have a
4503 // sensible meaning - namely an additional reduction per such dim.
4504 // By contrast, the ContractionOpInterface does not know about this
4505 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4506 // while vector.contract's verifier accepts dims of this kind, many of
4507 // its lowerings give up on encountering these dims.
4508 // TODO: Remove following once we have comprehensive support for input-only
4509 // reduction dims, at both the linalg- and vector-dialect levels.
4510 if (inOccCount == 1 && outOccCount != 1)
4511 return emitError()
4512 << "iteration space dim at index " << dimIndex
4513 << " is neither a contracting dim nor of parallel iteration type";
4514 }
4515
4516 if (!hasContractingDim)
4517 return emitError("'indexing_maps' do not specify a contracting dimension");
4518
4519 return success();
4520}
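
// For illustration, a minimal sketch (hypothetical SSA names and shapes) of a
// linalg.contract that satisfies the rules above: d2 appears in both inputs
// but not in the output, making it the required contracting dimension.
// ```mlir
// %0 = linalg.contract
//     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                      affine_map<(d0, d1, d2) -> (d2, d1)>,
//                      affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//     outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
// ```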
4521
4522LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4523 return memref::foldMemRefCast(*this);
4524}
4525
4526void ContractOp::getEffects(
4527 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4528 &effects) {
4529 if (hasPureTensorSemantics())
4530 return;
4531 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4532}
4533
4534Speculation::Speculatability ContractOp::getSpeculatability() {
4535 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4536}
4537
4538//===----------------------------------------------------------------------===//
4539// Implementation of BatchMatmulOp
4540//===----------------------------------------------------------------------===//
4541SmallVector<AffineMap>
4542BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4543 AffineExpr d0, d1, d2, d3;
4544 SmallVector<AffineMap> indexingMaps;
4545 bindDims(context, d0, d1, d2, d3);
4546 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4547 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4548 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4549 return indexingMaps;
4550}
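
// Spelled out in affine-map syntax, with (d0, d1, d2, d3) = (batch, m, n, k),
// the defaults built above are:
//   lhs: affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
//   rhs: affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
//   out: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>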
4551
4552bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4553 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4554 if (!maps)
4555 return false;
4556 if (maps.size() != 3)
4557 return false;
4558 auto positions = getAffineResultPositions(maps);
4559 if (failed(positions))
4560 return false;
4561 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4562 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4563 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4564}
4565
4566SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4567 return SmallVector<utils::IteratorType>{
4568 utils::IteratorType::parallel, utils::IteratorType::parallel,
4569 utils::IteratorType::parallel, utils::IteratorType::reduction};
4570}
4571
4572unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4573
4574std::string BatchMatmulOp::getLibraryCallName() {
4575 return generateLibraryCallName(getOperation());
4576}
4577
4578/// Check if the op has broadcast and/or transpose semantics. Returns true if
4579/// the user-defined indexing maps are not equal to the default maps.
4580bool BatchMatmulOp::hasUserDefinedMaps() {
4581 SmallVector<AffineMap, 3> defaultMaps =
4582 getDefaultIndexingMaps(this->getContext());
4583 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4584 return defaultMaps != explicitMaps;
4585}
4586
4587/// Returns true if the given `bcastMap` is a valid broadcast map. A valid
4588/// broadcast map must include the K dimension.
4589/// TODO: Strict inclusion of the K dimension in the broadcast map is not
4590/// necessary for both input matrices simultaneously. We can relax this
4591/// condition to require the K dimension in only one input matrix map and
4592/// infer it for the other input matrix map from the one that already has it.
4594bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4595 assert(bcastMap.getNumResults() < 3 &&
4596 "Expected less than 3 result dim expr.");
4597 bool isValid = false;
4598 enum Indices { batchPos, mPos, nPos, kPos };
4599 if (bcastMap.getNumResults() == 1) {
4600 AffineExpr expr = bcastMap.getResult(0);
4601 isValid = expr.isFunctionOfDim(kPos);
4602 } else if (bcastMap.getNumResults() == 2) {
4603 AffineExpr expr0 = bcastMap.getResult(0);
4604 AffineExpr expr1 = bcastMap.getResult(1);
4605 isValid =
4606 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4607 expr0.isFunctionOfDim(mPos)) &&
4608 expr1.isFunctionOfDim(kPos))
4609 : ((expr0.isFunctionOfDim(batchPos) &&
4610 expr1.isFunctionOfDim(kPos)) ||
4611 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4612 }
4613 return isValid;
4614}
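
// For illustration (sketch, not exhaustive): with iteration dims
// (d0, d1, d2, d3) = (batch, m, n, k), a one-result map such as
// affine_map<(d0, d1, d2, d3) -> (d3)> is a valid broadcast map for either
// input. With two results, the LHS accepts (d0, d3) or (d1, d3), while the
// RHS accepts (d0, d3) or (d3, d2).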
4615
4616void BatchMatmulOp::regionBuilder(
4617 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4618 function_ref<InFlightDiagnostic()> emitError) {
4619 if (emitError && block.getNumArguments() != 3) {
4620 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4621 << block.getNumArguments();
4622 return;
4623 }
4624 assert(block.getNumArguments() == 3 &&
4625 "BatchMatmulOp regionBuilder expects 3 args");
4626 RegionBuilderHelper helper(b, block);
4627 SmallVector<Value> yields;
4628
4629 TypeFn castVal = TypeFn::cast_signed;
4630 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4631 return attr.getName() == "cast";
4632 });
4633 if (castIter != attrs.end()) {
4634 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4635 castVal = attr.getValue();
4636 }
4637
4638 auto toType = block.getArgument(2).getType();
4639 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4640 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4641 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4642 Value addVal =
4643 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4644 yields.push_back(addVal);
4645 helper.yieldOutputs(yields);
4646}
4647
4648ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4649 SmallVector<Attribute, 3> indexingMapsAttr;
4650 Attribute mapAttr;
4651 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4652 if (parser.parseEqual())
4653 return failure();
4654
4655 if (parser.parseLSquare())
4656 return failure();
4657
4658 do {
4659 if (parser.parseAttribute(mapAttr))
4660 return failure();
4661 if (!isa<AffineMapAttr>(mapAttr)) {
4662 return parser.emitError(parser.getCurrentLocation(),
4663 "expected affine map attribute");
4664 }
4665 indexingMapsAttr.push_back(mapAttr);
4666
4667 if (parser.parseOptionalComma())
4668 break;
4669 } while (true);
4670
4671 if (parser.parseRSquare())
4672 return failure();
4673 }
4674 // Initialize indexingMaps, if not supplied explicitly.
4675 if (indexingMapsAttr.empty()) {
4676 indexingMapsAttr = llvm::map_to_vector(
4677 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4678 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4679 }
4680 result.addAttribute("indexing_maps",
4681 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4682
4683 return ::parseNamedStructuredOp(parser, result,
4684 BatchMatmulOp::getNumRegionArgs(),
4685 BatchMatmulOp::getRegionBuilder());
4686}
4687
4688void BatchMatmulOp::print(OpAsmPrinter &p) {
4689 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4690 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4691 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4692 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4693 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4694
4695 std::array<StringRef, 3> elidedAttrs = {
4696 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4697 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4698 elidedAttrs);
4699}
4700
4701/// Verify the user defined indexing maps.
4702LogicalResult BatchMatmulOp::verify() {
4703 // Verification of pure batch_matmul is handled by
4704 // verifyStructuredOpInterface().
4705 if (!hasUserDefinedMaps())
4706 return success();
4707
4708 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4709 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4710 return failure();
4711 }
4712 return success();
4713}
4714
4715LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4716 SmallVectorImpl<OpFoldResult> &) {
4717 return memref::foldMemRefCast(*this);
4718}
4719
4720void BatchMatmulOp::getEffects(
4721 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4722 &effects) {
4723 if (hasPureTensorSemantics())
4724 return;
4725 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4726}
4727
4728Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4729 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4730}
4731
4732//===----------------------------------------------------------------------===//
4733// ElementwiseOp
4734//===----------------------------------------------------------------------===//
4735//
4736namespace {
4737struct ArityGroupAndKind {
4738 // The enum class {Unary, Binary, Ternary, ..}
4739 ElementwiseArityGroup arityGroup;
4740
4741 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4742 union Kind {
4743 UnaryFn unaryFn;
4744 BinaryFn binaryFn;
4745 TernaryFn ternaryFn;
4746 } kind;
4747};
4748
4749unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4750 return static_cast<unsigned>(arityGroup);
4751}
4752} // namespace
4753
4754static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4755 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4756 constexpr int lastBinary =
4757 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4758 constexpr int lastTernary =
4759 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4760
4761 int val = static_cast<int>(kind);
4762 ArityGroupAndKind result;
4763
4764 if (val < lastUnary) {
4765 result.arityGroup = ElementwiseArityGroup::Unary;
4766 result.kind.unaryFn = static_cast<UnaryFn>(val);
4767 return result;
4768 }
4769 if (val < lastBinary) {
4770 result.arityGroup = ElementwiseArityGroup::Binary;
4771 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4772 return result;
4773 }
4774 if (val >= lastTernary) {
4775 llvm_unreachable("unhandled ElementwiseFn");
4776 }
4777 result.arityGroup = ElementwiseArityGroup::Ternary;
4778 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4779 return result;
4780}
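
// Worked example (hypothetical limit values): if LastUnary == 10 and
// LastBinary == 25, a kind with raw value 12 falls in the Binary group and
// maps to the BinaryFn with raw value 12 - 10 = 2, i.e. the third binary
// function in enum order.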
4781
4782SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4783 auto rank = getResultRank();
4784 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4785}
4786
4787SmallVector<AffineMap>
4788ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4789 MLIRContext *context) {
4790 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4791 return SmallVector<AffineMap>(numMaps, map);
4792}
4793
4794ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4795 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4796 Attribute attr;
4797 mlir::linalg::ElementwiseKind elemwiseKindVal;
4798 if (parser.parseKeyword("kind") || parser.parseEqual())
4799 return failure();
4800
4801 if (succeeded(parser.parseAttribute(attr))) {
4802 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4803 if (!elemwiseKindAttr)
4804 return parser.emitError(parser.getCurrentLocation(),
4805 "expected ElementwiseKind attribute");
4806 elemwiseKindVal = elemwiseKindAttr.getValue();
4807 } else {
4808 return parser.emitError(parser.getCurrentLocation(),
4809 "expected operation 'kind' attribute");
4810 }
4811 result.addAttribute(
4812 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4813
4814 // Parse optional `indexing_maps`
4815 SmallVector<Attribute, 3> indexingMapsAttr;
4816 Attribute mapAttr;
4817 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4818 if (parser.parseEqual())
4819 return failure();
4820 if (parser.parseLSquare())
4821 return failure();
4822 do {
4823 if (parser.parseAttribute(mapAttr))
4824 return failure();
4825 if (!isa<AffineMapAttr>(mapAttr))
4826 return parser.emitError(parser.getCurrentLocation(),
4827 "expected affine map attribute");
4828 indexingMapsAttr.push_back(mapAttr);
4829 if (parser.parseOptionalComma())
4830 break;
4831 } while (true);
4832 if (parser.parseRSquare())
4833 return failure();
4834 }
4835 // At this stage of parsing, the only way to infer the number of region
4836 // args is through the op kind, as input/output tensors are not parsed yet.
4837 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4838 int numRegionArgs =
4839 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4840 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4841 ElementwiseOp::getRegionBuilder())) {
4842 return parser.emitError(parser.getCurrentLocation(),
4843 "unable to parse elemwise op");
4844 }
4845
4846 // Initialize indexingMaps, if not supplied explicitly.
4847 if (indexingMapsAttr.empty()) {
4848 // We need to infer the numDims of the indexing maps from the output
4849 // type, which has already been parsed by now.
4850 auto resultType = result.operands[result.operands.size() - 1].getType();
4851 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4852 if (!shapedType)
4853 return parser.emitError(parser.getCurrentLocation(),
4854 "return type needs to be shaped type");
4855 auto numDims = shapedType.getRank();
4856 indexingMapsAttr = llvm::map_to_vector(
4857 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4858 parser.getContext()),
4859 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4860 }
4861
4862 result.addAttribute("indexing_maps",
4863 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4864 return success();
4865}
4866
4867void ElementwiseOp::print(OpAsmPrinter &p) {
4868 p << " kind=";
4869 p.printAttribute(getKindAttr());
4870 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4871 "indexing_maps"};
4872 unsigned arity =
4873 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4874 unsigned numDims = getResultRank();
4875
4876 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4877 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4878 getContext()),
4879 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4880
4881 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4882 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4883
4884 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4885 elidedAttrs);
4886}
4887
4888LogicalResult ElementwiseOp::verify() {
4889 // All necessary checks are done either by
4890 // - EnumAttr (e.g. unknown operation kind)
4891 // - verifyStructuredOpInterface (incorrect map, sizes).
4892 return success();
4893}
4894
4895/// Implements the block region builder for the ElementwiseOp. This is called by
4896/// 'fillStructuredOpRegion'.
4897void ElementwiseOp::regionBuilder(
4898 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4899 function_ref<InFlightDiagnostic()> emitError) {
4900 ElementwiseKind elemwiseKind;
4901 for (auto attr : attrs) {
4902 if (attr.getName() == b.getStringAttr("kind")) {
4903 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4904 assert(kindAttr && "op kind attribute incorrectly set");
4905 elemwiseKind = kindAttr.getValue();
4906 break;
4907 }
4908 }
4909
4910 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4911 auto arityGroup = groupAndKind.arityGroup;
4912 auto kind = groupAndKind.kind;
4913 if (emitError && block.getNumArguments() !=
4914 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4915 emitError() << "Elementwise regionBuilder expects "
4916 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4917 << block.getNumArguments();
4918 return;
4919 }
4920 assert(block.getNumArguments() ==
4921 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4922 && "Elementwise regionBuilder number of block args mismatch");
4923
4924 RegionBuilderHelper helper(b, block);
4925 SmallVector<Value> yields;
4926 Value result;
4927
4928 if (arityGroup == ElementwiseArityGroup::Unary) {
4929 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4930
4931 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4932 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4933 block.getArgument(1));
4934
4935 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4936 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4937 block.getArgument(1), block.getArgument(2));
4938
4939 } else {
4940 assert(false && "found unhandled category in elemwise");
4941 }
4942
4943 yields.push_back(result);
4944 helper.yieldOutputs(yields);
4945}
4946
4947LogicalResult ElementwiseOp::fold(FoldAdaptor,
4948 SmallVectorImpl<OpFoldResult> &) {
4949 return memref::foldMemRefCast(*this);
4950}
4951
4952void ElementwiseOp::getEffects(
4953 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4954 &effects) {
4955 if (hasPureTensorSemantics())
4956 return;
4957 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4958}
4959
4960Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4961 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4962}
4963
4964//===----------------------------------------------------------------------===//
4965// PackOp/UnPackOp Common
4966//===----------------------------------------------------------------------===//
4967
4968template <typename OpTy, typename>
4969SmallVector<int64_t>
4970getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4971 RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4972 ? packOrUnPack.getDestType()
4973 : packOrUnPack.getSourceType();
4974 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4975 ? packOrUnPack.getSourceType()
4976 : packOrUnPack.getDestType();
4977 SmallVector<int64_t> result(
4978 packedType.getShape().take_front(unpackedType.getRank()));
4979 if (!packOrUnPack.getOuterDimsPerm().empty()) {
4980 applyPermutationToVector(
4981 result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4982 }
4983 return result;
4984}
4985template SmallVector<int64_t>
4986getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
4987template SmallVector<int64_t>
4988getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
4989
4990// Given the (potentially) updated packed type, `newPackedTy`, generates an
4991// updated mixed-tile-sizes attribute. A tile size is updated only
4992// when:
4993// * a dim from newPackedTy is static, and
4994// * the corresponding size from mixedTiles is still dynamic.
4995// Otherwise, the original tile size is preserved.
4996// Note - packed-type-dim and mixed-tile-size should always match!
4997static SmallVector<OpFoldResult>
4998getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4999 SmallVector<OpFoldResult> mixedTiles) {
5000 SmallVector<OpFoldResult> newMixedTileSizes;
5001 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5002 .getShape()
5003 .take_back(mixedTiles.size()),
5004 mixedTiles)) {
5005 int64_t shape = std::get<0>(it);
5006 if (shape == ShapedType::kDynamic) {
5007 newMixedTileSizes.push_back(std::get<1>(it));
5008 continue;
5009 }
5010
5011 // If the current result dim is static, update the dynamic mixed-size
5012 // (provided the original value is dynamic).
5013 OpFoldResult tile = std::get<1>(it);
5014 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5015 // Already a constant
5016 newMixedTileSizes.push_back(tile);
5017 } else {
5018 assert(getConstantIntValue(tile).value() == shape &&
5019 "tile size and dim size don't match!");
5020 newMixedTileSizes.push_back(
5021 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
5022 }
5023 }
5024
5025 return newMixedTileSizes;
5026}
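
// Worked example (hypothetical values): if the trailing dims of `newPackedTy`
// are [8, ?] and `mixedTiles` holds two dynamic Values [%t0, %t1], the result
// is [8 : index, %t1] -- the first tile is replaced by the now-static dim
// size, while the second remains dynamic.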
5027
5028template <typename OpTy>
5029static LogicalResult
5030reifyResultShapesImpl(OpTy op, OpBuilder &builder,
5031 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5032 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5033 "applies to only pack or unpack operations");
5034 int64_t destRank = op.getDestRank();
5035 reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5036 reifiedReturnShapes[0] =
5037 tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
5038 return success();
5039}
5040
5041template <typename OpTy>
5042static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
5043 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5044 "applies to only pack or unpack operations");
5045 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5046 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5047 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5048 assert(tiles.size() == dimsToTile.size() &&
5049 "tiles must match indices of dimension to block");
5050 // bind the dimension `i` with the tile factor.
5051 for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5052 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5053 return dimAndTileMapping;
5054}
5055
5056template <typename OpTy>
5057static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
5058 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5059 "applies to only pack or unpack operations");
5060 Builder builder(op);
5061 SmallVector<OpFoldResult> mixedInnerTiles;
5062 unsigned dynamicValIndex = 0;
5063 for (int64_t staticTile : op.getStaticInnerTiles()) {
5064 if (ShapedType::isStatic(staticTile))
5065 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5066 else
5067 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5068 }
5069 return mixedInnerTiles;
5070}
5071
5072template <typename OpTy>
5073static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
5074 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5075 "applies to only pack or unpack operations");
5076 SmallVector<Value> dynamicTiles;
5077 SmallVector<int64_t> staticTiles;
5078 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5079 return staticTiles;
5080}
5081
5082/// Returns true if `dimsPos` is invalid. It is invalid when:
5083/// a) it contains duplicates;
5084/// b) at least one dimension is out of bounds (valid: 0 <= `dimPos` < `rank`);
5085/// c) the number of elements in `dimsPos` is greater than `rank`.
5086static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
5087 size_t rank) {
5088 size_t dimsPosSize = dimsPos.size();
5089 if (dimsPosSize > rank)
5090 return true;
5091 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5092 if (dimsPosSize != uniqued.size())
5093 return true;
5094 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5095 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5096 });
5097}
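
// For illustration: with rank == 3, dimsPos [0, 0] is invalid (duplicate),
// [3] is invalid (out of bounds), [0, 1, 2, 0] is invalid (more entries than
// the rank), and [2, 0] is valid.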
5098
5099template <typename OpTy>
5100static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5101 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5102 "applies to only pack or unpack operations");
5103 Operation *op = packOrUnPack.getOperation();
5104
5105 // Return true if we have a zero-value tile.
5106 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5107 return llvm::any_of(tiles, isZeroInteger);
5108 };
5109
5110 // Verify tiles. Do not allow zero tiles.
5111 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5112 if (hasZeros(mixedTiles))
5113 return op->emitError("invalid zero tile factor");
5114
5115 // Verify inner_dims_pos and outer_dims_perm.
5116 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
5117 ? packOrUnPack.getSourceType()
5118 : packOrUnPack.getDestType();
5119 size_t unpackedRank = unpackedType.getRank();
5120 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5121 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5122 if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5123 return op->emitError("invalid inner_dims_pos vector");
5124 if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5125 return op->emitError("invalid outer_dims_perm vector");
5126 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5127 return op->emitError("outer_dims_perm must be a permutation or empty");
5128
5129 // Tiling factors must be less than or equal to the input rank for pack (or
5130 // output rank for unpack), and must match the number of `inner_dims_pos`.
5131 if (mixedTiles.size() > unpackedRank) {
5132 return op->emitError("tiling factors must be less than or equal to the "
5133 "input rank for pack or output rank for unpack");
5134 }
5135 if (mixedTiles.size() != innerDimsPos.size()) {
5136 return op->emitError(
5137 "tiling factors must equal the number of dimensions to tile");
5138 }
5139
5140 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5141 ? packOrUnPack.getDestType()
5142 : packOrUnPack.getSourceType();
5143 size_t packedRank = packedType.getRank();
5144 // Require output rank to match input rank + number of blocking factors.
5145 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5146 if (expectedPackedRank != packedRank) {
5147 return op->emitError(
5148 "packed rank != (unpacked rank + num tiling factors), got ")
5149 << packedRank << " != " << expectedPackedRank;
5150 }
5151
5152 // Verify result shape is greater than the minimum expected
5153 // by the pack operation, and that the output shape
5154 // represents full tiles.
5155 RankedTensorType expectedPackedType = PackOp::inferPackedType(
5156 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
5157 if (!llvm::all_of(
5158 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5159 mixedTiles),
5160 [](std::tuple<int64_t, OpFoldResult> it) {
5161 int64_t shape = std::get<0>(it);
5162 if (Attribute attr =
5163 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5164 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5165 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5166 return shape == staticTileSize;
5167 }
5168 return ShapedType::isDynamic(shape);
5169 })) {
5170 return op->emitError("mismatch in inner tile sizes specified and shaped of "
5171 "tiled dimension in the packed type");
5172 }
5173 if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
5174 packedType.getShape()))) {
5175 return op->emitError("expected ")
5176 << expectedPackedType << " for the packed domain value, got "
5177 << packedType;
5178 }
5179 return success();
5180}
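
// For illustration, a pack that passes this verifier (hypothetical shapes):
// packed rank = unpacked rank (2) + number of tiles (2), and the trailing
// dims of the packed type equal the inner tiles.
// ```mlir
// %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 32]
//     into %dest : tensor<128x256xf32> -> tensor<8x8x16x32xf32>
// ```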
5181
5182namespace {
5183/// Subset of PackOp/UnPackOp fields used to compute the result of applying
5184/// various permutations to the op.
5185// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5186// these. These may or may not become true foldings / canonicalizations
5187// depending on how aggressive we want to be in automatically folding
5188// transposes.
5189struct PackOrUnPackTransposeResult {
5190 SmallVector<int64_t> innerDimsPos;
5191 SmallVector<OpFoldResult> innerTiles;
5192 SmallVector<int64_t> outerDimsPerm;
5193};
5194} // namespace
5195
5196template <typename OpTy>
5197static PackOrUnPackTransposeResult
5198commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
5199 ArrayRef<int64_t> innerPermutation,
5200 ArrayRef<int64_t> outerPermutation) {
5201 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5202 "applies to only pack or unpack operations");
5203 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5204 "some permutation must be non-empty");
5205 PackOrUnPackTransposeResult metadata;
5206 metadata.innerDimsPos =
5207 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5208 metadata.innerTiles =
5209 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
5210 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5211 ? packOrUnPackOp.getSourceRank()
5212 : packOrUnPackOp.getDestRank();
5213 metadata.outerDimsPerm =
5214 packOrUnPackOp.getOuterDimsPerm().empty()
5215 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5216 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5217 if (!innerPermutation.empty()) {
5218 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5219 isPermutationVector(innerPermutation) &&
5220 "invalid inner permutation");
5221 applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5222 applyPermutationToVector(metadata.innerTiles, innerPermutation);
5223 }
5224 if (!outerPermutation.empty()) {
5225 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5226 isPermutationVector(outerPermutation) &&
5227 "invalid outer permutation");
5228 applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5229 }
5230 return metadata;
5231}
5232
5233//===----------------------------------------------------------------------===//
5234// PackOp
5235//===----------------------------------------------------------------------===//
5236
5237void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5238 setNameFn(getResult(), "pack");
5239}
5240
5241void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5242 Value dest, ArrayRef<int64_t> innerDimsPos,
5243 ArrayRef<OpFoldResult> innerTiles,
5244 std::optional<Value> paddingValue,
5245 ArrayRef<int64_t> outerDimsPerm) {
5246 assert(innerDimsPos.size() == innerTiles.size() &&
5247 "number of tile sizes specified must match the specified number of "
5248 "original dimensions to be tiled");
5249 SmallVector<int64_t> staticTileSizes;
5250 SmallVector<Value> dynamicTileSizes;
5251 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5252 build(builder, state, dest.getType(), source, dest,
5253 paddingValue ? *paddingValue : nullptr,
5254 outerDimsPerm.empty() ? nullptr
5255 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5256 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5257 builder.getDenseI64ArrayAttr(staticTileSizes));
5258}
5259
5260LogicalResult
5261PackOp::reifyResultShapes(OpBuilder &builder,
5262 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5263 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5264}
5265
5266DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5267 return getDimAndTileMappingImpl(*this);
5268}
5269
5270SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5271 return getMixedTilesImpl(*this);
5272}
5273
5274SmallVector<int64_t> PackOp::getStaticTiles() {
5275 return getStaticTilesImpl(*this);
5276}
5277
5278ArrayRef<int64_t> PackOp::getAllOuterDims() {
5279 ShapedType inputType = getSourceType();
5280 int64_t inputRank = inputType.getRank();
5281 return getDestType().getShape().take_front(inputRank);
5282}
5283
5284SmallVector<int64_t> PackOp::getTiledOuterDims() {
5285 auto innerDimsPos = getInnerDimsPos();
5286 SmallVector<int64_t> outerDims(getAllOuterDims());
5287 SmallVector<int64_t> res;
5288
5289 // Recover the original order of the outer dims.
5290 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5291 invertPermutationVector(outerDimPermInv);
5292 if (!outerDimPermInv.empty())
5293 applyPermutationToVector(outerDims, outerDimPermInv);
5294
5295 // Collect the outer dims corresponding to the tiled inner dims.
5296 for (auto index : innerDimsPos)
5297 res.push_back(outerDims[index]);
5298
5299 return res;
5300}
5301
5302bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5303 ArrayRef<int64_t> innerDimsPos,
5304 ArrayRef<int64_t> outputShape,
5305 ArrayRef<int64_t> outerDimsPerm,
5306 ArrayRef<OpFoldResult> innerTiles) {
5307 SmallVector<int64_t> outputTileSizes(
5308 outputShape.take_front(inputShape.size()));
5309 if (!outerDimsPerm.empty()) {
5310 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5311 "expected output and outer_dims_perm to have same size");
5312 applyPermutationToVector(outputTileSizes,
5313 invertPermutationVector(outerDimsPerm));
5314 }
5315 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5316 if (ShapedType::isDynamic(inputShape[pos]))
5317 continue;
5318 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5319
5320 if (!constantTile) {
5321 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5322 (inputShape[pos] % outputTileSizes[pos] != 0))
5323 return true;
5324 } else if (inputShape[pos] % (*constantTile) != 0) {
5325 return true;
5326 }
5327 }
5328 return false;
5329}
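
// Worked example (hypothetical shapes): for inputShape [13, 64], innerDimsPos
// [0] and a constant tile of 8, 13 % 8 != 0, so a padding value is required;
// with inputShape [16, 64] instead, it is not.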
5330
5331bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5332 ArrayRef<int64_t> innerDimsPos,
5333 ArrayRef<int64_t> outputShape,
5334 ArrayRef<int64_t> outerDimsPerm,
5335 ArrayRef<OpFoldResult> innerTiles) {
5336 SmallVector<int64_t> outputTileSizes(
5337 outputShape.take_front(inputShape.size()));
5338 if (!outerDimsPerm.empty()) {
5339 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5340 "expected output and outer_dims_perm to have same size");
5341 applyPermutationToVector(outputTileSizes,
5342 invertPermutationVector(outerDimsPerm));
5343 }
5344 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5345 if (ShapedType::isDynamic(inputShape[pos]) ||
5346 ShapedType::isDynamic(outputTileSizes[pos]))
5347 return true;
5348 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5349 if (!constantTile)
5350 return true;
5351 if (inputShape[pos] % (*constantTile) != 0)
5352 return true;
5353 }
5354 return false;
5355}
5356
5357LogicalResult PackOp::verify() {
5358 if (failed(commonVerifierPackAndUnPackOp(*this)))
5359 return failure();
5360
5361 // Verify padding value, and bail out if the tile does not divide the
5362 // dimension fully. In the case of dynamic tile factors or dimensions, having
5363 // a partial tile is undefined behavior.
5364 auto paddingValue = getPaddingValue();
5365 if (paddingValue &&
5366 paddingValue.getType() != getSourceType().getElementType()) {
5367 return emitOpError("expected padding_value has ")
5368 << getSourceType().getElementType()
5369 << " but got: " << paddingValue.getType();
5370 }
5371
5372 if (!paddingValue &&
5373 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5374 getDestType().getShape(), getOuterDimsPerm(),
5375 getMixedTiles())) {
5376 return emitOpError(
5377 "invalid tile factor or output size provided. Only full tiles are "
5378 "supported when padding_value is not set");
5379 }
5380 return success();
5381}
5382
5383/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
5384/// Value's to kDynamic, even if they are arith.constant values.
5385static SmallVector<int64_t>
5386asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
5387 SmallVector<int64_t> result;
5388 for (auto o : ofrs) {
5389 // Have to do this first, as getConstantIntValue special-cases constants.
5390 if (llvm::dyn_cast_if_present<Value>(o))
5391 result.push_back(ShapedType::kDynamic);
5392 else
5393 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5394 }
5395 return result;
5396}
5397
5398/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
5399/// the packed type. Having a shared helper helps implement these two methods in
5400/// a way that ensures that they agree on which dimensions are dynamic.
5401static SmallVector<int64_t> getPackOpResultTypeShape(
5402 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
5403 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5404 SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
5405 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5406 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5407 continue;
5408 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5409 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5410 continue;
5411 }
5412 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5413 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5414 }
5415
5416 // Swap tile loops if outer_dims_perm is available.
5417 if (!outerDimsPerm.empty())
5418 applyPermutationToVector(resultShape, outerDimsPerm);
5419
5420 // Append the inner tile dimensions.
5421 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5422 return resultShape;
5423}
5424
5425SmallVector<OpFoldResult> PackOp::getResultShape(
5426 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5427 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5428 ArrayRef<int64_t> outerDimsPerm) {
5429 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5430
5431 AffineExpr s0, s1;
5432 bindSymbols(builder.getContext(), s0, s1);
5433 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5434 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5435 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5436 builder, loc, ceilDivExpr,
5437 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5438 }
5439 if (!outerDimsPerm.empty())
5440 applyPermutationToVector(resultDims, outerDimsPerm);
5441 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5442
5443 SmallVector<int64_t> resultTypeShape =
5444 getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
5445 asShapeWithAnyValueAsDynamic(innerTileSizes),
5446 innerDimsPos, outerDimsPerm);
5447
5448 // Fix-up `resultDims` to ensure that they are Value's if and only if the
5449 // result type shape says it's a dynamic dim. This is needed as callers may
5450 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5451 // dynamic dims returned by that.
5452 for (unsigned i = 0; i < resultDims.size(); ++i) {
5453 if (ShapedType::isStatic(resultTypeShape[i]))
5454 continue;
5455 resultDims[i] =
5456 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5457 }
5458
5459 return resultDims;
5460}
5461
5462/// Get the expected packed type based on source type, tile factors, position of
5463/// the inner tiles and permutation of the outer tiled loop.
5464RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
5465 ArrayRef<int64_t> innerTileSizes,
5466 ArrayRef<int64_t> innerDimsPos,
5467 ArrayRef<int64_t> outerDimsPerm) {
5468 SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
5469 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5470 return RankedTensorType::get(resultShape, sourceType.getElementType());
5471}
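
// Worked example: packing tensor<?x256xf32> with innerDimsPos [0, 1] and
// innerTileSizes [16, 32] infers tensor<?x8x16x32xf32> -- the dynamic dim
// stays dynamic, 256 becomes ceilDiv(256, 32) = 8, and the tile sizes are
// appended.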
5472
5473Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5474 ArrayRef<OpFoldResult> innerTileSizes,
5475 ArrayRef<int64_t> innerDimsPos,
5476 ArrayRef<int64_t> outerDimsPerm) {
5477 AffineExpr dim0, dim1;
5478 bindDims(b.getContext(), dim0, dim1);
5479 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5480 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5481 {v1, v2});
5482 };
5483
5484 SmallVector<OpFoldResult> mixedSizes;
5485 for (auto [index, value] : llvm::enumerate(
5486 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5487 if (ShapedType::isDynamic(value))
5488 mixedSizes.push_back(
5489 tensor::DimOp::create(b, loc, source, index).getResult());
5490 else
5491 mixedSizes.push_back(b.getIndexAttr(value));
5492 }
5493 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5494 int64_t dimPos = std::get<0>(it);
5495 OpFoldResult tileSize = std::get<1>(it);
5496 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5497 }
5498 if (!outerDimsPerm.empty())
5499 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5500
5501 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5502 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5503 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5504}
5505
5506PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5507 ArrayRef<int64_t> innerPermutation,
5508 ArrayRef<int64_t> outerPermutation) {
5509 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5510 *this, innerPermutation, outerPermutation);
5511 Value transposedDest =
5512 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5513 metadata.innerDimsPos, metadata.outerDimsPerm);
5514 return PackOp::create(b, loc, getSource(), transposedDest,
5515 metadata.innerDimsPos, metadata.innerTiles,
5516 getPaddingValue(), metadata.outerDimsPerm);
5517}
5518
5519/// Returns true if the tiles and the tiled dims are constant.
5520template <typename OpTy>
5521static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5522 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5523 "applies to only pack or unpack operations");
5524 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5525 ? op.getDestType()
5526 : op.getSourceType();
5527 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5528 for (auto [dimDest, tile] : llvm::zip(
5529 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5530 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5531 if (!constTileSize || ShapedType::isDynamic(dimDest))
5532 return false;
5533 }
5534 return true;
5535}
5536
5537Speculation::Speculatability PackOp::getSpeculatability() {
5538 if (getPaddingValue())
5539 return Speculation::Speculatable;
5540
5541 // The verifier already rejects operations where we can statically prove
5542 // that a tile size does not evenly divide the dimension; thus, only check
5543 // that the tiles and the tiled inner dimensions are constant.
5544 if (!areTilesAndTiledDimsAllConstant(*this))
5545 return Speculation::NotSpeculatable;
5546
5547 return Speculation::Speculatable;
5548}
5549
5550// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5551// dimensions for pack and unpack.
5552static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5553 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5554 return false;
5555 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5556 return true;
5557 // The outer dims permutation is optional.
5558 // To compare an unbalanced pack-unpack pair, treat a missing permutation
5559 // as equal to the identity permutation.
5560 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5561 isIdentityPermutation(unPackOp.getOuterDimsPerm());
5562}
5563
5564// Return true if pack and unpack have the same tiles.
5565// Same SSA values or same integer constants.
5566static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5567 auto packTiles = packOp.getMixedTiles();
5568 auto unPackTiles = unPackOp.getMixedTiles();
5569 if (packTiles.size() != unPackTiles.size())
5570 return false;
5571 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5572 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5573 return false;
5574 }
5575 return true;
5576}
5577
5578/// Returns true if the pack op does not need a padding value.
5579static bool paddingIsNotNeeded(PackOp op) {
5580 auto srcType = op.getSourceType();
5581 if (llvm::any_of(op.getInnerDimsPos(),
5582 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5583 return false;
5584 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5585 return false;
5586 return !PackOp::requirePaddingValue(
5587 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5588 op.getOuterDimsPerm(), op.getMixedTiles());
5589}
5590
5591/// Returns true if the `srcShape` or `destShape` is different from the one in
5592/// `packOp` and populates each with the inferred static shape.
5593static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5594 SmallVectorImpl<int64_t> &destShape) {
5595 bool changeNeeded = false;
5596 srcShape.assign(packOp.getSourceType().getShape().begin(),
5597 packOp.getSourceType().getShape().end());
5598 destShape.assign(packOp.getDestType().getShape().begin(),
5599 packOp.getDestType().getShape().end());
5600 llvm::SmallSetVector<int64_t, 4> innerDims;
5601 innerDims.insert_range(packOp.getInnerDimsPos());
5602 SmallVector<int64_t> inverseOuterDimsPerm;
5603 if (!packOp.getOuterDimsPerm().empty())
5604 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5605 int srcRank = packOp.getSourceRank();
5606 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5607 if (innerDims.contains(i))
5608 continue;
5609 int64_t srcPos = i;
5610 int64_t destPos = i;
5611 if (!inverseOuterDimsPerm.empty())
5612 destPos = inverseOuterDimsPerm[srcPos];
5613 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5614 ShapedType::isDynamic(destShape[destPos])) {
5615 continue;
5616 }
5617 int64_t size = srcShape[srcPos];
5618 if (ShapedType::isDynamic(size))
5619 size = destShape[destPos];
5620 srcShape[srcPos] = size;
5621 destShape[destPos] = size;
5622 changeNeeded = true;
5623 }
5624 return changeNeeded;
5625}
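
// Worked example (hypothetical types): for a pack with source
// tensor<?x32xf32>, dest tensor<2x16x2xf32>, inner_dims_pos = [1] and a tile
// of 2, outer dim 0 is dynamic in the source but static (2) in the dest, so
// srcShape becomes [2, 32] and the function returns true; the tiled dim 1 is
// skipped.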
5626
5627LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5628 // Fold a pack(unpack(x)) to x.
5629 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5630 if (unPackOp.getSourceType() == packOp.getDestType() &&
5631 !packOp.getPaddingValue() &&
5632 hasSameInnerOuterAttribute(packOp, unPackOp) &&
5633 haveSameTiles(packOp, unPackOp)) {
5634 rewriter.replaceOp(packOp, unPackOp.getSource());
5635 return success();
5636 }
5637 }
5638
5639 // Fold optional PaddingValue operand away if padding is not needed.
5640 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5641 rewriter.startOpModification(packOp);
5642 packOp.getPaddingValueMutable().clear();
5643 rewriter.finalizeOpModification(packOp);
5644 return success();
5645 }
5646
5647 // Insert tensor.cast ops if static shape inference is available.
5648 SmallVector<int64_t> srcShape, destShape;
5649 if (inferStaticShape(packOp, srcShape, destShape)) {
5650 Location loc = packOp.getLoc();
5651 Value source = packOp.getSource();
5652 if (srcShape != packOp.getSourceType().getShape()) {
5653 auto newSrcType = packOp.getSourceType().clone(srcShape);
5654 source =
5655 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5656 }
5657 Value dest = packOp.getDest();
5658 RankedTensorType originalResultType = packOp.getDestType();
5659 bool needUpdateDestType = (destShape != originalResultType.getShape());
5660 if (needUpdateDestType) {
5661 auto newDestType = packOp.getDestType().clone(destShape);
5662 dest =
5663 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5664 }
5665 rewriter.modifyOpInPlace(packOp, [&] {
5666 packOp.getSourceMutable().assign(source);
5667 packOp.getDestMutable().assign(dest);
5668 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5669 });
5670 // Insert a cast if needed
5671 if (needUpdateDestType) {
5672 rewriter.setInsertionPointAfter(packOp);
5673 auto castOp =
5674 tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5675 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5676 }
5677 return success();
5678 }
5679
5680 return failure();
5681}
5682
5683template <typename PackOrUnpackOp>
5684static bool isLikePadUnPad(PackOrUnpackOp packOp,
5685 RankedTensorType packedTensorType) {
5686 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5687 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5688 "Function meant for pack/unpack");
5689 // This is a pad if packing only adds ones and we don't transpose dimensions.
5690
5691 // Check that we are not transposing any dimensions.
5692 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5693 int64_t numPackedDims = innerDimsPos.size();
5694 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5695 if (orderedDims != innerDimsPos) {
5696 // Dimensions don't happen in order.
5697 return false;
5698 }
5699
5700 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5701 int64_t packedRank = packedTensorType.getRank();
5702 // At this point we know that we are taking numPackedDims outer
5703 // dimensions and pushing them all the way in as the innermost dimensions.
5704 // What's left on the outermost dimensions is, in this order:
5705 // - the factor of the packed dimensions, then
5706 // - the untouched dimensions.
5707 // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5708 // if all the dimensions that bubble outward are ones.
5709 // Therefore check that all the dimensions but the numPackedDims innermost
5710 // ones are ones.
5711 return llvm::all_of(
5712 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5713 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5714}
5715
5716bool PackOp::isLikePad() {
5717 auto packedTensorType =
5718 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5719 return isLikePadUnPad(*this, packedTensorType);
5720}
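
// For illustration (sketch): a pack with inner_dims_pos = [0, 1] into
// tensor<1x1x8x16xf32>, where every dim other than the two innermost tile
// dims is 1, behaves like a tensor.pad of the source up to 8x16 (given a
// padding value), so isLikePad() returns true.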
5721
5722OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5723 std::optional<Attribute> paddingValue;
5724 if (auto pad = adaptor.getPaddingValue())
5725 paddingValue = pad;
5726 if (OpFoldResult reshapedSource = reshapeConstantSource(
5727 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5728 getDestType(), paddingValue))
5729 return reshapedSource;
5730 return {};
5731}
5732
5733/// Folds a tensor.cast op into a consuming PackOp if the
5734/// `tensor.cast` has a source that is more static than the consuming op.
5735///
5736/// Example:
5737/// ```mlir
5738/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5739/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5740/// ```
5741///
5742/// folds into:
5743///
5744/// ```mlir
5745/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5746/// ```
5747struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5748 using OpRewritePattern<PackOp>::OpRewritePattern;
5749
5750 LogicalResult matchAndRewrite(PackOp op,
5751 PatternRewriter &rewriter) const override {
5752 if (!tensor::hasFoldableTensorCastOperand(op))
5753 return failure();
5754
5755 SmallVector<Type> newResultTypes(op->getResultTypes());
5756 SmallVector<Value> newOperands =
5757 tensor::getUpdatedOperandsAfterCastOp(rewriter, op);
5758
5759 // Get the updated mixed-tile-sizes attribute.
5760 SmallVector<OpFoldResult> newMixedTileSizes =
5761 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5762
5763 // Clone op.
5764 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5765 // this point. However, in practice, we use them for things that we'd like
5766 // to preserve. Implement a better abstraction.
5767 PackOp newOp =
5768 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5769 op.getInnerDimsPos(), newMixedTileSizes,
5770 op.getPaddingValue(), op.getOuterDimsPerm());
5771 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5772
5773 // Replace op.
5774 Value oldResult = op.getResult();
5775 Value newResult = newOp.getResult();
5776 Value replacement =
5777 (newResult.getType() != oldResult.getType())
5778 ? tensor::CastOp::create(rewriter, op->getLoc(),
5779 oldResult.getType(), newResult)
5780 : newResult;
5781
5782 rewriter.replaceOp(op, {replacement});
5783
5784 return success();
5785 }
5786};
5787
5788//===----------------------------------------------------------------------===//
5789// UnPackOp
5790//===----------------------------------------------------------------------===//
5791
5792void UnPackOp::getAsmResultNames(
5793 function_ref<void(Value, StringRef)> setNameFn) {
5794 setNameFn(getResult(), "unpack");
5795}
5796
5797LogicalResult
5798UnPackOp::reifyResultShapes(OpBuilder &builder,
5799 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5800 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5801}
5802
5803DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5804 return getDimAndTileMappingImpl(*this);
5805}
5806
5807SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5808 return getMixedTilesImpl(*this);
5809}
5810
5811SmallVector<int64_t> UnPackOp::getStaticTiles() {
5812 return getStaticTilesImpl(*this);
5813}
5814
5815ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5816 ShapedType destType = getDestType();
5817 int64_t destRank = destType.getRank();
5818 return getSourceType().getShape().take_front(destRank);
5819}
5820
5821SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5822 auto innerDimsPos = getInnerDimsPos();
5823 SmallVector<int64_t> outerDims(getAllOuterDims());
5824 SmallVector<int64_t> res;
5825
5826 // Recover the original order of the outer dims.
5827 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5828 invertPermutationVector(outerDimPermInv);
5829 if (!outerDimPermInv.empty())
5830 applyPermutationToVector(outerDims, outerDimPermInv);
5831
5832 // Collect the outer dims corresponding to the tiled inner dims.
5833 for (auto index : innerDimsPos)
5834 res.push_back(outerDims[index]);
5835
5836 return res;
5837}
5838
5839LogicalResult UnPackOp::verify() {
5840 return commonVerifierPackAndUnPackOp(*this);
5841}
5842
5843Speculation::Speculatability UnPackOp::getSpeculatability() {
5844 // See PackOp::getSpeculatability.
5845 if (!areTilesAndTiledDimsAllConstant(*this))
5846 return Speculation::NotSpeculatable;
5847
5848 return Speculation::Speculatable;
5849}
5850
5851void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5852 Value dest, ArrayRef<int64_t> innerDimsPos,
5853 ArrayRef<OpFoldResult> innerTiles,
5854 ArrayRef<int64_t> outerDimsPerm) {
5855 assert(innerDimsPos.size() == innerTiles.size() &&
5856 "number of tile sizes specified must match the specified number of "
5857 "original dimensions to be tiled");
5858 SmallVector<int64_t> staticTileSizes;
5859 SmallVector<Value> dynamicTileSizes;
5860 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5861 build(builder, state, dest.getType(), source, dest,
5862 outerDimsPerm.empty() ? nullptr
5863 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5864 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5865 builder.getDenseI64ArrayAttr(staticTileSizes));
5866}
5867
5868Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5869 Value source,
5870 ArrayRef<OpFoldResult> innerTileSizes,
5871 ArrayRef<int64_t> innerDimsPos,
5872 ArrayRef<int64_t> outerDimsPerm) {
5873 AffineExpr sym0, sym1;
5874 bindSymbols(b.getContext(), sym0, sym1);
5875 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5876 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5877 };
5878
5879 SmallVector<OpFoldResult> mixedSizes;
5880 auto srcType = llvm::cast<RankedTensorType>(source.getType());
5881 for (auto i :
5882 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5883 if (srcType.isDynamicDim(i))
5884 mixedSizes.push_back(
5885 tensor::DimOp::create(b, loc, source, i).getResult());
5886 else
5887 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5888 }
5889 if (!outerDimsPerm.empty()) {
5890 applyPermutationToVector<OpFoldResult>(
5891 mixedSizes, invertPermutationVector(outerDimsPerm));
5892 }
5893
5894 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5895 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5896
5897 auto elemType = srcType.getElementType();
5898 return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5899}
5900
5901UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5902 Value transposedSource,
5903 ArrayRef<int64_t> innerPermutation,
5904 ArrayRef<int64_t> outerPermutation) {
5905 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5906 *this, innerPermutation, outerPermutation);
5907 return UnPackOp::create(b, loc, transposedSource, getDest(),
5908 metadata.innerDimsPos, metadata.innerTiles,
5909 metadata.outerDimsPerm);
5910}
5911
5912/// Returns true if the `srcShape` or `destShape` is different from the one in
5913/// `op` and populates each with the inferred static shape.
5914static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5915 SmallVectorImpl<int64_t> &destShape) {
5916 bool changeNeeded = false;
5917 srcShape.assign(op.getSourceType().getShape().begin(),
5918 op.getSourceType().getShape().end());
5919 destShape.assign(op.getDestType().getShape().begin(),
5920 op.getDestType().getShape().end());
5921 llvm::SmallSetVector<int64_t, 4> innerDims;
5922 innerDims.insert_range(op.getInnerDimsPos());
5923 SmallVector<int64_t> inverseOuterDimsPerm;
5924 if (!op.getOuterDimsPerm().empty())
5925 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5926 int destRank = op.getDestRank();
5927 for (auto i : llvm::seq<int64_t>(0, destRank)) {
5928 if (innerDims.contains(i))
5929 continue;
5930 int64_t srcPos = i;
5931 int64_t destPos = i;
5932 if (!inverseOuterDimsPerm.empty())
5933 srcPos = inverseOuterDimsPerm[destPos];
5934 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5935 ShapedType::isDynamic(destShape[destPos])) {
5936 continue;
5937 }
5938 int64_t size = srcShape[srcPos];
5939 if (ShapedType::isDynamic(size))
5940 size = destShape[destPos];
5941 srcShape[srcPos] = size;
5942 destShape[destPos] = size;
5943 changeNeeded = true;
5944 }
5945 return changeNeeded;
5946}
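// Illustrative inference (hypothetical shapes): for
//   linalg.unpack %src inner_dims_pos = [1] inner_tiles = [8]
//       into %dest : tensor<4x16x8xf32> -> tensor<?x128xf32>
// dest dim 0 is dynamic while the matching source outer dim is static (4),
// so destShape is refined to [4, 128] and the function returns true.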
5947
5948LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5949 PatternRewriter &rewriter) {
5950 /// unpack(pack(x)) -> x
5951 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5952 if (packOp.getSourceType() != unPackOp.getDestType())
5953 return failure();
5954 if (packOp.getPaddingValue() ||
5955 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5956 !haveSameTiles(packOp, unPackOp))
5957 return failure();
5958 rewriter.replaceOp(unPackOp, packOp.getSource());
5959 return success();
5960 }
5961 /// unpack(destinationStyleOp(x)) -> unpack(x)
5962 if (auto dstStyleOp =
5963 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5964 auto destValue = cast<OpResult>(unPackOp.getDest());
5965 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5966 rewriter.modifyOpInPlace(unPackOp,
5967 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5968 return success();
5969 }
5970 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5971 if (unPackOp->hasOneUse()) {
5972 auto extractSliceUser =
5973 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5974 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5975 OpBuilder::InsertionGuard g(rewriter);
5976 rewriter.setInsertionPoint(unPackOp);
5977 auto newDest = tensor::ExtractSliceOp::create(
5978 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5979 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5980 extractSliceUser.getMixedStrides());
5981 rewriter.modifyOpInPlace(unPackOp, [&]() {
5982 unPackOp.setDpsInitOperand(0, newDest);
5983 unPackOp.getResult().setType(newDest.getType());
5984 });
5985 rewriter.replaceOp(extractSliceUser, unPackOp);
5986 return success();
5987 }
5988 }
5989
5990 // Insert tensor.cast ops if static shape inference is available.
5991 SmallVector<int64_t> srcShape, destShape;
5992 if (inferStaticShape(unPackOp, srcShape, destShape)) {
5993 Location loc = unPackOp.getLoc();
5994 Value source = unPackOp.getSource();
5995 if (srcShape != unPackOp.getSourceType().getShape()) {
5996 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5997 source = tensor::CastOp::create(rewriter, loc, newSrcType,
5998 unPackOp.getSource());
5999 }
6000 Value dest = unPackOp.getDest();
6001 if (destShape != unPackOp.getDestType().getShape()) {
6002 auto newDestType = unPackOp.getDestType().clone(destShape);
6003 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6004 unPackOp.getDest());
6005 }
6006 Value newOp = UnPackOp::create(
6007 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6008 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6009 rewriter.replaceOpWithNewOp<tensor::CastOp>(
6010 unPackOp, unPackOp.getResult().getType(), newOp);
6011 return success();
6012 }
6013
6014 return failure();
6015}
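// Illustrative unpack(pack(x)) -> x fold (hypothetical IR):
//   %p = linalg.pack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %pd : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//   %u = linalg.unpack %p inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %d : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
// canonicalizes %u to %x: the types round-trip, the inner/outer dim
// attributes and tiles match, and there is no padding value.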
6016
6017bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6018 // Rank-reduced folding is not supported.
6019 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6020 return false;
6021 if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
6022 !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
6023 return false;
6024 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6025 SmallVector<int64_t> outerShapeWithoutTranspose =
6026 getPackedOuterShapeWithoutTransposition<UnPackOp>(*this);
6027 for (auto [pos, tileSize] :
6028 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6029 if (unpackedTypeAfterFold.isDynamicDim(pos))
6030 return false;
6031 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6032 return false;
6033 if (ShapedType::isDynamic(tileSize))
6034 return false;
6035 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6036 unpackedTypeAfterFold.getDimSize(pos);
6037 if (paddingSize >= tileSize)
6038 return false;
6039 }
6040 return true;
6041}
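// Illustrative padding check (hypothetical sizes): with an untransposed outer
// size of 2 and a tile size of 4, the unpacked dim spans 2 * 4 = 8 elements.
// Folding a slice down to size 5 implies padding 8 - 5 = 3 < 4, so the fold
// is allowed; slicing to size 4 would imply padding 4 >= tileSize and is
// rejected, since that would drop an entire tile.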
6042
6043bool UnPackOp::isLikeUnPad() {
6044 RankedTensorType packedTensorType = getSourceType();
6045 return isLikePadUnPad(*this, packedTensorType);
6046}
6047
6048OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
6049 if (OpFoldResult reshapedSource = reshapeConstantSource(
6050 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6051 getResult().getType()))
6052 return reshapedSource;
6053 return {};
6054}
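// Illustrative fold (hypothetical IR): unpacking a splat constant needs no
// data movement, so
//   %cst = arith.constant dense<1.0> : tensor<2x2x8x8xf32>
//   %u = linalg.unpack %cst inner_dims_pos = [0, 1] inner_tiles = [8, 8]
//       into %d : tensor<2x2x8x8xf32> -> tensor<16x16xf32>
// folds to dense<1.0> : tensor<16x16xf32>.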
6055
6056/// Folds a tensor.cast op into a consuming UnPackOp if the `tensor.cast`
6057/// has a source that is more static than the consuming op.
6058///
6059/// Example:
6060/// ```mlir
6061/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6062/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6063/// ```
6064///
6065/// folds into:
6066///
6067/// ```mlir
6068/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6069/// ```
6070struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6071 using OpRewritePattern<UnPackOp>::OpRewritePattern;
6072
6073 LogicalResult matchAndRewrite(UnPackOp op,
6074 PatternRewriter &rewriter) const override {
6075 if (!tensor::hasFoldableTensorCastOperand(op))
6076 return failure();
6077
6078 SmallVector<Type> newResultTypes(op->getResultTypes());
6079 SmallVector<Value> newOperands =
6080 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
6081 Value sourceTensor = newOperands[0];
6082
6083 // Get the updated mixed-tile-sizes attribute.
6084 SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
6085 rewriter, sourceTensor.getType(), op.getMixedTiles());
6086
6087 // Clone op.
6088 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6089 // this point. However, in practice, we use them for things that we'd like
6090 // to preserve. Implement a better abstraction.
6091 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6092 newOperands[1], op.getInnerDimsPos(),
6093 newMixedTileSizes, op.getOuterDimsPerm());
6094 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6095
6096 // Replace op.
6097 Value oldResult = op.getResult();
6098 Value newResult = newOp.getResult();
6099 Value replacement =
6100 (newResult.getType() != oldResult.getType())
6101 ? tensor::CastOp::create(rewriter, op->getLoc(),
6102 oldResult.getType(), newResult)
6103 : newResult;
6104
6105 rewriter.replaceOp(op, {replacement});
6106
6107 return success();
6108 }
6109};
6110
6111//===----------------------------------------------------------------------===//
6112// BatchReduceMatmulOp
6113//===----------------------------------------------------------------------===//
6114SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6115 return SmallVector<utils::IteratorType>{
6116 utils::IteratorType::reduction, utils::IteratorType::parallel,
6117 utils::IteratorType::parallel, utils::IteratorType::reduction};
6118}
6119
6120SmallVector<AffineMap>
6121BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6122 AffineExpr d0, d1, d2, d3;
6123 SmallVector<AffineMap> indexingMaps;
6124 bindDims(context, d0, d1, d2, d3);
6125 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6126 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6127 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6128 return indexingMaps;
6129}
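// For reference, the maps built above in affine-map notation, with
// (d0, d1, d2, d3) = (batch, m, n, k):
//   LHS: (d0, d1, d2, d3) -> (d0, d1, d3)   batch x m x k
//   RHS: (d0, d1, d2, d3) -> (d0, d3, d2)   batch x k x n
//   OUT: (d0, d1, d2, d3) -> (d1, d2)       m x n (batch and k are reduced)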
6130
6131bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6132 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6133 if (!maps)
6134 return false;
6135 if (maps.size() != 3)
6136 return false;
6137 auto positions = getAffineResultPositions(maps);
6138 if (failed(positions))
6139 return false;
6140 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6141 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6142 (*positions)[2] == SmallVector<int64_t>{1, 2};
6143}
6144unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6145
6146std::string BatchReduceMatmulOp::getLibraryCallName() {
6147 return generateLibraryCallName(getOperation());
6148}
6149
6150/// Checks if the op has broadcast and/or transpose semantics. Returns true
6151/// if the user-defined indexing maps are not equal to the default maps.
6152bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6153 SmallVector<AffineMap, 3> defaultMaps =
6154 getDefaultIndexingMaps(this->getContext());
6155 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6156 return defaultMaps != explicitMaps;
6157}
6158
6159/// Returns true if the given bcastMap is a valid broadcast map. A valid
6160/// broadcast map must include the K dimension.
6161/// TODO: Strict inclusion of the K dimension in the broadcast map is not
6162/// necessary for both input matrices simultaneously. We can relax this
6163/// condition to require the K dimension in only one of the two input maps
6164/// and infer the K dimension for the other input map from the one that
6165/// already has it.
6166bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6167 bool isLHS) {
6168 assert(bcastMap.getNumResults() < 3 &&
6169 "Expected less than 3 result dim expr.");
6170 bool isValid = false;
6171 enum Indices { batchPos, mPos, nPos, kPos };
6172 if (bcastMap.getNumResults() == 1) {
6173 AffineExpr expr = bcastMap.getResult(0);
6174 isValid = expr.isFunctionOfDim(kPos);
6175 } else if (bcastMap.getNumResults() == 2) {
6176 AffineExpr expr0 = bcastMap.getResult(0);
6177 AffineExpr expr1 = bcastMap.getResult(1);
6178 isValid =
6179 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6180 expr0.isFunctionOfDim(mPos)) &&
6181 expr1.isFunctionOfDim(kPos))
6182 : ((expr0.isFunctionOfDim(batchPos) &&
6183 expr1.isFunctionOfDim(kPos)) ||
6184 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6185 }
6186 return isValid;
6187}
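// Illustrative maps (hypothetical): for the LHS,
//   (d0, d1, d2, d3) -> (d3)       is valid (K only), and
//   (d0, d1, d2, d3) -> (d1, d3)   is valid (M, K);
// for the RHS, (d0, d1, d2, d3) -> (d3, d2) is valid (K, N), while
// (d0, d1, d2, d3) -> (d1, d2) is rejected because it lacks K.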
6188
6189void BatchReduceMatmulOp::regionBuilder(
6190 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6191 function_ref<InFlightDiagnostic()> emitError) {
6192 if (emitError && block.getNumArguments() != 3) {
6193 emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6194 << block.getNumArguments();
6195 return;
6196 }
6197 assert(block.getNumArguments() == 3 &&
6198 "BatchReduceMatmulOp regionBuilder expects 3 args");
6199 RegionBuilderHelper helper(b, block);
6200 SmallVector<Value> yields;
6201
6202 auto toType = block.getArgument(2).getType();
6203 Value castValA =
6204 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6205 Value castValB =
6206 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6207 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6208 Value addVal =
6209 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6210 yields.push_back(addVal);
6211 helper.yieldOutputs(yields);
6212}
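// The region built above is equivalent to this MLIR sketch (hypothetical
// f16 -> f32 element types; cast_signed becomes arith.extf here):
//   ^bb0(%a: f16, %b: f16, %acc: f32):
//     %ae = arith.extf %a : f16 to f32
//     %be = arith.extf %b : f16 to f32
//     %mul = arith.mulf %ae, %be : f32
//     %sum = arith.addf %acc, %mul : f32
//     linalg.yield %sum : f32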
6213
6214ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6215 OperationState &result) {
6216 SmallVector<Attribute, 3> indexingMapsAttr;
6217 Attribute mapAttr;
6218 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6219 if (parser.parseEqual())
6220 return failure();
6221 if (parser.parseLSquare())
6222 return failure();
6223
6224 do {
6225 if (parser.parseAttribute(mapAttr))
6226 return failure();
6227 if (!isa<AffineMapAttr>(mapAttr)) {
6228 return parser.emitError(parser.getCurrentLocation(),
6229 "expected affine map attribute");
6230 }
6231 indexingMapsAttr.push_back(mapAttr);
6232
6233 if (parser.parseOptionalComma())
6234 break;
6235 } while (true);
6236
6237 if (parser.parseRSquare())
6238 return failure();
6239 }
6240 // Initialize indexingMaps, if not supplied explicitly.
6241 if (indexingMapsAttr.empty()) {
6242 indexingMapsAttr = llvm::map_to_vector(
6243 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6244 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6245 }
6246 result.addAttribute("indexing_maps",
6247 parser.getBuilder().getArrayAttr(indexingMapsAttr));
6248 return ::parseNamedStructuredOp(parser, result,
6249 BatchReduceMatmulOp::getNumRegionArgs(),
6250 BatchReduceMatmulOp::getRegionBuilder());
6251}
6252
6253void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6254 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6255 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6256 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6257
6258 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6259 p << " indexing_maps = [";
6260 llvm::interleaveComma(getIndexingMaps(), p,
6261 [&](Attribute attr) { p.printAttribute(attr); });
6262 p << "]";
6263 }
6264
6265 SmallVector<StringRef, 3> elidedAttrs = {
6266 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6267 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6268 elidedAttrs);
6269}
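// Illustrative round-trip (hypothetical operands): with the default maps the
// indexing_maps clause is elided, so the op prints as
//   linalg.batch_reduce_matmul
//       ins(%A, %B : tensor<2x4x8xf32>, tensor<2x8x16xf32>)
//       outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
// and only user-defined (broadcast/transpose) maps are printed explicitly.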
6270
6271/// Verifies the user-defined indexing maps.
6272LogicalResult BatchReduceMatmulOp::verify() {
6273 // Verification of pure batch_reduce_matmul is handled by
6274 // verifyStructuredOpInterface().
6275 if (!hasUserDefinedMaps())
6276 return success();
6277
6278 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6279 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
6280 return failure();
6281 }
6282 return success();
6283}
6284LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6285 SmallVectorImpl<OpFoldResult> &) {
6286 return memref::foldMemRefCast(*this);
6287}
6288void BatchReduceMatmulOp::getEffects(
6289 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6290 &effects) {
6291 if (hasPureTensorSemantics())
6292 return;
6293 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6294}
6295
6296Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6297 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6298}
6299
6300} // namespace linalg
6301} // namespace mlir
6302
6303//===----------------------------------------------------------------------===//
6304// LinalgDialect
6305//===----------------------------------------------------------------------===//
6306
6307void LinalgDialect::getCanonicalizationPatterns(
6308 RewritePatternSet &results) const {
6309 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6310 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6311}
6312
6313Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6314 Attribute value, Type type,
6315 Location loc) {
6316 return arith::ConstantOp::materialize(builder, value, type, loc);
6317}